improve Loopback check

This commit is contained in:
Livio Amstutz 2021-04-29 12:43:21 +02:00
parent 72fc86164c
commit 540a7bd7be
2 changed files with 33 additions and 15 deletions

View file

@ -3,6 +3,7 @@ package op
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -184,7 +185,7 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
//ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type //ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type
func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error { func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error {
parsedURL, isLoopback := LoopbackOrLocalhost(uri) parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri)
isCustomSchema := !strings.HasPrefix(uri, "http://") isCustomSchema := !strings.HasPrefix(uri, "http://")
if utils.Contains(client.RedirectURIs(), uri) { if utils.Contains(client.RedirectURIs(), uri) {
if isLoopback || isCustomSchema { if isLoopback || isCustomSchema {
@ -196,7 +197,7 @@ func validateAuthReqRedirectURINative(client Client, uri string, responseType oi
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.")
} }
for _, uri := range client.RedirectURIs() { for _, uri := range client.RedirectURIs() {
redirectURI, ok := LoopbackOrLocalhost(uri) redirectURI, ok := HTTPLoopbackOrLocalhost(uri)
if ok && equalURI(parsedURL, redirectURI) { if ok && equalURI(parsedURL, redirectURI) {
return nil return nil
} }
@ -208,16 +209,16 @@ func equalURI(url1, url2 *url.URL) bool {
return url1.Path == url2.Path && url1.RawQuery == url2.RawQuery return url1.Path == url2.Path && url1.RawQuery == url2.RawQuery
} }
func LoopbackOrLocalhost(rawurl string) (*url.URL, bool) { func HTTPLoopbackOrLocalhost(rawurl string) (*url.URL, bool) {
parsedURL, err := url.Parse(rawurl) parsedURL, err := url.Parse(rawurl)
if err != nil { if err != nil {
return nil, false return nil, false
} }
if parsedURL.Scheme != "http" {
return nil, false
}
hostName := parsedURL.Hostname() hostName := parsedURL.Hostname()
return parsedURL, parsedURL.Scheme == "http" && return parsedURL, hostName == "localhost" || net.ParseIP(hostName).IsLoopback()
hostName == "localhost" ||
hostName == "127.0.0.1" ||
hostName == "::1"
} }
//ValidateAuthReqResponseType validates the passed response_type to the registered response types //ValidateAuthReqResponseType validates the passed response_type to the registered response types

View file

@ -316,16 +316,16 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
false, false,
}, },
{ {
"code flow registered http not confidential (user agent) fails", "code flow registered http not confidential (native) fails",
args{"http://registered.com/callback", args{"http://registered.com/callback",
mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeUserAgent, nil, false), mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeNative, nil, false),
oidc.ResponseTypeCode}, oidc.ResponseTypeCode},
true, true,
}, },
{ {
"code flow registered http not confidential (native) fails", "code flow registered http not confidential (user agent) fails",
args{"http://registered.com/callback", args{"http://registered.com/callback",
mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeNative, nil, false), mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeUserAgent, nil, false),
oidc.ResponseTypeCode}, oidc.ResponseTypeCode},
true, true,
}, },
@ -344,7 +344,7 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
false, false,
}, },
{ {
"code flow registered http localhost native ok", "code flow registered http loopback v6 native ok",
args{"http://[::1]:4200/callback", args{"http://[::1]:4200/callback",
mock.NewClientWithConfig(t, []string{"http://[::1]/callback"}, op.ApplicationTypeNative, nil, false), mock.NewClientWithConfig(t, []string{"http://[::1]/callback"}, op.ApplicationTypeNative, nil, false),
oidc.ResponseTypeCode}, oidc.ResponseTypeCode},
@ -420,6 +420,13 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
oidc.ResponseTypeIDToken}, oidc.ResponseTypeIDToken},
false, false,
}, },
{
"implicit flow registered http localhost web fails",
args{"http://localhost:9999/callback",
mock.NewClientWithConfig(t, []string{"http://localhost:9999/callback"}, op.ApplicationTypeWeb, nil, false),
oidc.ResponseTypeIDToken},
true,
},
{ {
"implicit flow registered http localhost user agent fails", "implicit flow registered http localhost user agent fails",
args{"http://localhost:9999/callback", args{"http://localhost:9999/callback",
@ -581,10 +588,15 @@ func Test_LoopbackOrLocalhost(t *testing.T) {
true, true,
}, },
{ {
"v6 no port ok", "v6 short no port ok",
args{url: "http://[::1]/test"}, args{url: "http://[::1]/test"},
true, true,
}, },
{
"v6 long no port ok",
args{url: "http://[0:0:0:0:0:0:0:1]/test"},
true,
},
{ {
"locahost no port ok", "locahost no port ok",
args{url: "http://localhost/test"}, args{url: "http://localhost/test"},
@ -596,10 +608,15 @@ func Test_LoopbackOrLocalhost(t *testing.T) {
true, true,
}, },
{ {
"v6 with port ok", "v6 short with port ok",
args{url: "http://[::1]:4200/test"}, args{url: "http://[::1]:4200/test"},
true, true,
}, },
{
"v6 long with port ok",
args{url: "http://[0:0:0:0:0:0:0:1]:4200/test"},
true,
},
{ {
"localhost with port ok", "localhost with port ok",
args{url: "http://localhost:4200/test"}, args{url: "http://localhost:4200/test"},
@ -608,7 +625,7 @@ func Test_LoopbackOrLocalhost(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if _, got := op.LoopbackOrLocalhost(tt.args.url); got != tt.want { if _, got := op.HTTPLoopbackOrLocalhost(tt.args.url); got != tt.want {
t.Errorf("loopbackOrLocalhost() = %v, want %v", got, tt.want) t.Errorf("loopbackOrLocalhost() = %v, want %v", got, tt.want)
} }
}) })