diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index faa62f0..febb28c 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -170,7 +170,7 @@ func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) pubkey := s.key.Public() return &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ - jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"}, + {Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"}, }, }, nil } diff --git a/go.sum b/go.sum index b510424..eee02f9 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= -github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -20,20 +18,12 @@ github.com/google/go-github/v31 v31.0.0 h1:JJUxlP9lFK+ziXKimTCprajMApV1ecWD4NB6C github.com/google/go-github/v31 v31.0.0/go.mod h1:NQPZol8/1sMoWYGN2yaALIBytu17gAWfhbweiEed3pM= github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= -github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= -github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/handlers v1.4.2 h1:0QniY0USkHQ1RGCLfKxeNHK9bkDHGRYGNDFBCS+YARg= -github.com/gorilla/handlers v1.4.2/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= github.com/gorilla/handlers v1.5.0 h1:4wjo3sf9azi99c8hTmyaxp9y5S+pFszsy3pP0rAw/lw= github.com/gorilla/handlers v1.5.0/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= -github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc= -github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/schema v1.1.0 h1:CamqUDOFUBqzrvxuz2vEwo8+SUdwsluFh7IlzJh30LY= -github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= @@ -58,8 +48,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= @@ -84,7 +72,6 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191206220618-eeba5f6aabab h1:FvshnhkKW+LO3HWHodML8kuVX8rnJTxKm9dFPuI68UM= golang.org/x/sys v0.0.0-20191206220618-eeba5f6aabab/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -104,13 +91,7 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -rsc.io/quote/v3 v3.1.0 h1:9JKUTTIUgS6kzR9mK1YuGKv6Nl+DijDNIc0ghT58FaY= -rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= -rsc.io/sampler v1.3.0 h1:7uVkIFmeBqHfdjD+gZwtXXI+RODJ2Wc4O7MPEh/QiW4= -rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index 02c5603..3398f51 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -8,10 +8,40 @@ import ( ) const ( + //ScopeOpenID defines the scope `openid` + //OpenID Connect requests MUST contain the `openid` scope value ScopeOpenID = "openid" - ResponseTypeCode ResponseType = "code" - ResponseTypeIDToken ResponseType = "id_token token" + //ScopeProfile defines the scope `profile` + //This (optional) scope value requests access to the End-User's default profile Claims, + //which are: name, family_name, given_name, middle_name, nickname, preferred_username, + //profile, picture, website, gender, birthdate, zoneinfo, locale, and updated_at. + ScopeProfile = "profile" + + //ScopeEmail defines the scope `email` + //This (optional) scope value requests access to the email and email_verified Claims. + ScopeEmail = "email" + + //ScopeAddress defines the scope `address` + //This (optional) scope value requests access to the address Claim. + ScopeAddress = "address" + + //ScopePhone defines the scope `phone` + //This (optional) scope value requests access to the phone_number and phone_number_verified Claims. + ScopePhone = "phone" + + //ScopeOfflineAccess defines the scope `offline_access` + //This (optional) scope value requests that an OAuth 2.0 Refresh Token be issued that can be used to obtain an Access Token + //that grants access to the End-User's UserInfo Endpoint even when the End-User is not present (not logged in). + ScopeOfflineAccess = "offline_access" + + //ResponseTypeCode for the Authorization Code Flow returning a code from the Authorization Server + ResponseTypeCode ResponseType = "code" + + //ResponseTypeIDToken for the Implicit Flow returning id and access tokens directly from the Authorization Server + ResponseTypeIDToken ResponseType = "id_token token" + + //ResponseTypeIDTokenOnly for the Implicit Flow returning only id token directly from the Authorization Server ResponseTypeIDTokenOnly ResponseType = "id_token" DisplayPage Display = "page" @@ -19,13 +49,23 @@ const ( DisplayTouch Display = "touch" DisplayWAP Display = "wap" - PromptNone Prompt = "none" - PromptLogin Prompt = "login" - PromptConsent Prompt = "consent" + //PromptNone (`none`) disallows the Authorization Server to display any authentication or consent user interface pages. + //An error (login_required, interaction_required, ...) will be returned if the user is not already authenticated or consent is needed + PromptNone Prompt = "none" + + //PromptLogin (`login`) directs the Authorization Server to prompt the End-User for reauthentication. + PromptLogin Prompt = "login" + + //PromptConsent (`consent`) directs the Authorization Server to prompt the End-User for consent (of sharing information). + PromptConsent Prompt = "consent" + + //PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching) PromptSelectAccount Prompt = "select_account" + //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow GrantTypeCode GrantType = "authorization_code" + //BearerToken defines the token_type `Bearer`, which is returned in a successful token response BearerToken = "Bearer" ) @@ -38,7 +78,6 @@ var displayValues = map[string]Display{ //AuthRequest according to: //https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest -// type AuthRequest struct { ID string Scopes Scopes `schema:"scope"` @@ -63,12 +102,17 @@ type AuthRequest struct { CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"` } +//GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface func (a *AuthRequest) GetRedirectURI() string { return a.RedirectURI } + +//GetResponseType returns the response_type value for the ErrAuthRequest interface func (a *AuthRequest) GetResponseType() ResponseType { return a.ResponseType } + +//GetState returns the optional state value for the ErrAuthRequest interface func (a *AuthRequest) GetState() string { return a.State } diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 8f2afc2..c468670 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -5,10 +5,11 @@ import ( "strings" "time" - "github.com/caos/oidc/pkg/utils" "golang.org/x/oauth2" "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/utils" ) type Tokens struct { @@ -61,7 +62,7 @@ type IDTokenClaims struct { type jsonToken struct { Issuer string `json:"iss,omitempty"` Subject string `json:"sub,omitempty"` - Audiences []string `json:"aud,omitempty"` + Audiences interface{} `json:"aud,omitempty"` Expiration int64 `json:"exp,omitempty"` NotBefore int64 `json:"nbf,omitempty"` IssuedAt int64 `json:"iat,omitempty"` @@ -110,13 +111,9 @@ func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &j); err != nil { return err } - audience := j.Audiences - if len(audience) == 1 { - audience = strings.Split(audience[0], " ") - } t.Issuer = j.Issuer t.Subject = j.Subject - t.Audiences = audience + t.Audiences = audienceFromJSON(j.Audiences) t.Expiration = time.Unix(j.Expiration, 0).UTC() t.NotBefore = time.Unix(j.NotBefore, 0).UTC() t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() @@ -161,13 +158,9 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &i); err != nil { return err } - audience := i.Audiences - if len(audience) == 1 { - audience = strings.Split(audience[0], " ") - } t.Issuer = i.Issuer t.Subject = i.Subject - t.Audiences = audience + t.Audiences = audienceFromJSON(i.Audiences) t.Expiration = time.Unix(i.Expiration, 0).UTC() t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC() t.AuthTime = time.Unix(i.AuthTime, 0).UTC() @@ -247,3 +240,13 @@ func timeToJSON(t time.Time) int64 { } return t.Unix() } + +func audienceFromJSON(audience interface{}) []string { + switch aud := audience.(type) { + case []string: + return aud + case string: + return []string{aud} + } + return nil +} diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index 41a7b44..743da68 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -2,13 +2,11 @@ package op import ( "context" - "errors" "fmt" "net/http" "strings" "github.com/gorilla/mux" - "github.com/gorilla/schema" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/rp" @@ -17,33 +15,37 @@ import ( type Authorizer interface { Storage() Storage - Decoder() *schema.Decoder - Encoder() *schema.Encoder + Decoder() utils.Decoder + Encoder() utils.Encoder Signer() Signer IDTokenVerifier() rp.Verifier Crypto() Crypto Issuer() string } -type ValidationAuthorizer interface { +//AuthorizeValidator is an extension of Authorizer interface +//implementing it's own validation mechanism for the auth request +type AuthorizeValidator interface { Authorizer ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, rp.Verifier) (string, error) } +//ValidationAuthorizer is an extension of Authorizer interface +//implementing it's own validation mechanism for the auth request +// +//Deprecated: ValidationAuthorizer exists for historical compatibility. Use ValidationAuthorizer itself +type ValidationAuthorizer AuthorizeValidator + +//Authorize handles the authorization request, including +//parsing, validating, storing and finally redirecting to the login handler func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { - err := r.ParseForm() + authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) if err != nil { - AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder()) - return - } - authReq := new(oidc.AuthRequest) - err = authorizer.Decoder().Decode(authReq, r.Form) - if err != nil { - AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } validation := ValidateAuthRequest - if validater, ok := authorizer.(ValidationAuthorizer); ok { + if validater, ok := authorizer.(AuthorizeValidator); ok { validation = validater.ValidateAuthRequest } userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenVerifier()) @@ -64,6 +66,19 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { RedirectToLogin(req.GetID(), client, w, r) } +func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) { + err := r.ParseForm() + if err != nil { + return nil, ErrInvalidRequest("cannot parse form") + } + authReq := new(oidc.AuthRequest) + err = decoder.Decode(authReq, r.Form) + if err != nil { + return nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)) + } + return authReq, nil +} + func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier rp.Verifier) (string, error) { client, err := storage.GetClientByClientID(ctx, authReq.ClientID) if err != nil { @@ -95,7 +110,6 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res if uri == "" { return ErrInvalidRequestRedirectURI("The redirect_uri is missing in the request. Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") } - if !utils.Contains(client.RedirectURIs(), uri) { return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") } @@ -138,7 +152,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie if idTokenHint == "" { return "", nil } - claims, err := verifier.Verify(ctx, "", idTokenHint) + claims, err := verifier.VerifyIDToken(ctx, idTokenHint) if err != nil { return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.") } @@ -160,7 +174,7 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author return } if !authReq.Done() { - AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, ErrInteractionRequired("Unfortunately, the user may is not logged in and/or additional interaction is required."), authorizer.Encoder()) return } AuthResponse(authReq, authorizer, w, r) diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go index 1b31fad..343924f 100644 --- a/pkg/op/authrequest_test.go +++ b/pkg/op/authrequest_test.go @@ -3,66 +3,140 @@ package op_test import ( "net/http" "net/http/httptest" - "strings" + "net/url" + "reflect" "testing" + "github.com/gorilla/schema" "github.com/stretchr/testify/require" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/op" "github.com/caos/oidc/pkg/op/mock" "github.com/caos/oidc/pkg/rp" + "github.com/caos/oidc/pkg/utils" ) -func TestAuthorize(t *testing.T) { - // testCallback := func(t *testing.T, clienID string) callbackHandler { - // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { - // // require.Equal(t, clientID, client.) - // } - // } - // testErr := func(t *testing.T, expected error) errorHandler { - // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { - // require.Equal(t, expected, err) - // } - // } +// +//TOOD: tests will be implemented in branch for service accounts +//func TestAuthorize(t *testing.T) { +// // testCallback := func(t *testing.T, clienID string) callbackHandler { +// // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { +// // // require.Equal(t, clientID, client.) +// // } +// // } +// // testErr := func(t *testing.T, expected error) errorHandler { +// // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { +// // require.Equal(t, expected, err) +// // } +// // } +// type args struct { +// w http.ResponseWriter +// r *http.Request +// authorizer op.Authorizer +// } +// tests := []struct { +// name string +// args args +// }{ +// { +// "parsing fails", +// args{ +// httptest.NewRecorder(), +// &http.Request{Method: "POST", Body: nil}, +// mock.NewAuthorizerExpectValid(t, true), +// // testCallback(t, ""), +// // testErr(t, ErrInvalidRequest("cannot parse form")), +// }, +// }, +// { +// "decoding fails", +// args{ +// httptest.NewRecorder(), +// func() *http.Request { +// r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo")) +// r.Header.Set("Content-Type", "application/x-www-form-urlencoded") +// return r +// }(), +// mock.NewAuthorizerExpectValid(t, true), +// // testCallback(t, ""), +// // testErr(t, ErrInvalidRequest("cannot parse auth request")), +// }, +// }, +// // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}}, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer) +// }) +// } +//} + +func TestParseAuthorizeRequest(t *testing.T) { type args struct { - w http.ResponseWriter - r *http.Request - authorizer op.Authorizer + r *http.Request + decoder utils.Decoder + } + type res struct { + want *oidc.AuthRequest + err bool } tests := []struct { name string args args + res res }{ { - "parsing fails", + "parsing form error", args{ - httptest.NewRecorder(), - &http.Request{Method: "POST", Body: nil}, - mock.NewAuthorizerExpectValid(t, true), - // testCallback(t, ""), - // testErr(t, ErrInvalidRequest("cannot parse form")), + &http.Request{URL: &url.URL{RawQuery: "invalid=%%param"}}, + schema.NewDecoder(), + }, + res{ + nil, + true, }, }, { - "decoding fails", + "decoding error", args{ - httptest.NewRecorder(), - func() *http.Request { - r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo")) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - return r + &http.Request{URL: &url.URL{RawQuery: "unknown=value"}}, + func() utils.Decoder { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(false) + return decoder }(), - mock.NewAuthorizerExpectValid(t, true), - // testCallback(t, ""), - // testErr(t, ErrInvalidRequest("cannot parse auth request")), + }, + res{ + nil, + true, + }, + }, + { + "parsing ok", + args{ + &http.Request{URL: &url.URL{RawQuery: "scope=openid"}}, + func() utils.Decoder { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(false) + return decoder + }(), + }, + res{ + &oidc.AuthRequest{Scopes: oidc.Scopes{"openid"}}, + false, }, }, - // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer) + got, err := op.ParseAuthorizeRequest(tt.args.r, tt.args.decoder) + if (err != nil) != tt.res.err { + t.Errorf("ParseAuthorizeRequest() error = %v, wantErr %v", err, tt.res.err) + } + if !reflect.DeepEqual(got, tt.res.want) { + t.Errorf("ParseAuthorizeRequest() got = %v, want %v", got, tt.res.want) + } }) } } diff --git a/pkg/op/client.go b/pkg/op/client.go index 9e8a7dc..ef9b62e 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -1,8 +1,9 @@ package op import ( - "github.com/caos/oidc/pkg/oidc" "time" + + "github.com/caos/oidc/pkg/oidc" ) const ( diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go index 348872f..79173fb 100644 --- a/pkg/op/config_test.go +++ b/pkg/op/config_test.go @@ -1,8 +1,9 @@ package op -import "testing" - -import "os" +import ( + "os" + "testing" +) func TestValidateIssuer(t *testing.T) { type args struct { @@ -78,7 +79,7 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) { wantErr bool }{ { - "localhost with http ok", + "localhost with http with dev ok", args{"http://localhost:9999"}, false, }, diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index a42da6a..9d18dd0 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -13,6 +13,7 @@ import ( "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/rp" + "github.com/caos/oidc/pkg/utils" ) const ( @@ -254,11 +255,11 @@ func (p *DefaultOP) VerifySignature(ctx context.Context, jws *jose.JSONWebSignat return payload, err } -func (p *DefaultOP) Decoder() *schema.Decoder { +func (p *DefaultOP) Decoder() utils.Decoder { return p.decoder } -func (p *DefaultOP) Encoder() *schema.Encoder { +func (p *DefaultOP) Encoder() utils.Encoder { return p.encoder } diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 39b39bc..c14fac4 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -6,12 +6,13 @@ import ( "reflect" "testing" - "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/op" - "github.com/caos/oidc/pkg/op/mock" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/op" + "github.com/caos/oidc/pkg/op/mock" ) func TestDiscover(t *testing.T) { @@ -147,7 +148,7 @@ func TestSupportedClaims(t *testing.T) { } func Test_SigAlgorithms(t *testing.T) { - m := mock.NewMockSigner(gomock.NewController((t))) + m := mock.NewMockSigner(gomock.NewController(t)) type args struct { s op.Signer } @@ -199,7 +200,7 @@ func Test_SubjectTypes(t *testing.T) { } func Test_AuthMethods(t *testing.T) { - m := mock.NewMockConfiguration(gomock.NewController((t))) + m := mock.NewMockConfiguration(gomock.NewController(t)) type args struct { c op.Configuration } diff --git a/pkg/op/error.go b/pkg/op/error.go index f3c5857..06935c9 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -4,15 +4,15 @@ import ( "fmt" "net/http" - "github.com/gorilla/schema" - "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" ) const ( - InvalidRequest errorType = "invalid_request" - ServerError errorType = "server_error" + InvalidRequest errorType = "invalid_request" + InvalidRequestURI errorType = "invalid_request_uri" + InteractionRequired errorType = "interaction_required" + ServerError errorType = "server_error" ) var ( @@ -24,11 +24,17 @@ var ( } ErrInvalidRequestRedirectURI = func(description string) *OAuthError { return &OAuthError{ - ErrorType: InvalidRequest, + ErrorType: InvalidRequestURI, Description: description, redirectDisabled: true, } } + ErrInteractionRequired = func(description string) *OAuthError { + return &OAuthError{ + ErrorType: InteractionRequired, + Description: description, + } + } ErrServerError = func(description string) *OAuthError { return &OAuthError{ ErrorType: ServerError, @@ -45,7 +51,7 @@ type ErrAuthRequest interface { GetState() string } -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) { +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder utils.Encoder) { if authReq == nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -56,7 +62,7 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq e.ErrorType = ServerError e.Description = err.Error() } - e.state = authReq.GetState() + e.State = authReq.GetState() if authReq.GetRedirectURI() == "" || e.redirectDisabled { http.Error(w, e.Description, http.StatusBadRequest) return @@ -89,9 +95,9 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error) { type OAuthError struct { ErrorType errorType `json:"error" schema:"error"` - Description string `json:"error_description" schema:"error_description"` - state string `json:"state" schema:"state"` - redirectDisabled bool + Description string `json:"error_description,omitempty" schema:"error_description,omitempty"` + State string `json:"state,omitempty" schema:"state,omitempty"` + redirectDisabled bool `json:"-" schema:"-"` } func (e *OAuthError) Error() string { diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index dbfc2a6..5272997 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -7,8 +7,8 @@ package mock import ( op "github.com/caos/oidc/pkg/op" rp "github.com/caos/oidc/pkg/rp" + utils "github.com/caos/oidc/pkg/utils" gomock "github.com/golang/mock/gomock" - schema "github.com/gorilla/schema" reflect "reflect" ) @@ -50,10 +50,10 @@ func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call { } // Decoder mocks base method -func (m *MockAuthorizer) Decoder() *schema.Decoder { +func (m *MockAuthorizer) Decoder() utils.Decoder { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Decoder") - ret0, _ := ret[0].(*schema.Decoder) + ret0, _ := ret[0].(utils.Decoder) return ret0 } @@ -64,10 +64,10 @@ func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call { } // Encoder mocks base method -func (m *MockAuthorizer) Encoder() *schema.Encoder { +func (m *MockAuthorizer) Encoder() utils.Encoder { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Encoder") - ret0, _ := ret[0].(*schema.Encoder) + ret0, _ := ret[0].(utils.Encoder) return ret0 } diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 29c9354..202889c 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -69,6 +69,9 @@ type Verifier struct{} func (v *Verifier) Verify(ctx context.Context, accessToken, idToken string) (*oidc.IDTokenClaims, error) { return nil, nil } +func (v *Verifier) VerifyIDToken(ctx context.Context, idToken string) (*oidc.IDTokenClaims, error) { + return nil, nil +} type Sig struct{} diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go index b0d0dca..eed21d5 100644 --- a/pkg/op/mock/client.go +++ b/pkg/op/mock/client.go @@ -1,12 +1,12 @@ package mock import ( - "github.com/caos/oidc/pkg/oidc" "testing" - gomock "github.com/golang/mock/gomock" + "github.com/golang/mock/gomock" - op "github.com/caos/oidc/pkg/op" + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/op" ) func NewClient(t *testing.T) op.Client { diff --git a/pkg/op/session.go b/pkg/op/session.go index c274bf0..e60f71b 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -6,11 +6,11 @@ import ( "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/rp" - "github.com/gorilla/schema" + "github.com/caos/oidc/pkg/utils" ) type SessionEnder interface { - Decoder() *schema.Decoder + Decoder() utils.Decoder Storage() Storage IDTokenVerifier() rp.Verifier DefaultLogoutRedirectURI() string @@ -39,7 +39,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { http.Redirect(w, r, session.RedirectURI, http.StatusFound) } -func ParseEndSessionRequest(r *http.Request, decoder *schema.Decoder) (*oidc.EndSessionRequest, error) { +func ParseEndSessionRequest(r *http.Request, decoder utils.Decoder) (*oidc.EndSessionRequest, error) { err := r.ParseForm() if err != nil { return nil, ErrInvalidRequest("error parsing form") @@ -57,7 +57,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, if req.IdTokenHint == "" { return session, nil } - claims, err := ender.IDTokenVerifier().Verify(ctx, "", req.IdTokenHint) + claims, err := ender.IDTokenVerifier().VerifyIDToken(ctx, req.IdTokenHint) if err != nil { return nil, ErrInvalidRequest("id_token_hint invalid") } diff --git a/pkg/op/signer.go b/pkg/op/signer.go index b4f770e..a313934 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -8,6 +8,7 @@ import ( "gopkg.in/square/go-jose.v2" "github.com/caos/logging" + "github.com/caos/oidc/pkg/oidc" ) diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index f924531..71bd1ec 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -5,8 +5,6 @@ import ( "errors" "net/http" - "github.com/gorilla/schema" - "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" ) @@ -14,7 +12,7 @@ import ( type Exchanger interface { Issuer() string Storage() Storage - Decoder() *schema.Decoder + Decoder() utils.Decoder Signer() Signer Crypto() Crypto AuthMethodPostSupported() bool @@ -42,7 +40,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { utils.MarshalJSON(w, resp) } -func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) { +func ParseAccessTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.AccessTokenRequest, error) { err := r.ParseForm() if err != nil { return nil, ErrInvalidRequest("error parsing form") diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 8f55b15..88ba955 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -5,14 +5,12 @@ import ( "net/http" "strings" - "github.com/gorilla/schema" - "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" ) type UserinfoProvider interface { - Decoder() *schema.Decoder + Decoder() utils.Decoder Crypto() Crypto Storage() Storage } @@ -37,7 +35,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP utils.MarshalJSON(w, info) } -func getAccessToken(r *http.Request, decoder *schema.Decoder) (string, error) { +func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) { authHeader := r.Header.Get("authorization") if authHeader != "" { parts := strings.Split(authHeader, "Bearer ") diff --git a/pkg/rp/default_verifier.go b/pkg/rp/default_verifier.go index db599e3..dfdf134 100644 --- a/pkg/rp/default_verifier.go +++ b/pkg/rp/default_verifier.go @@ -158,14 +158,8 @@ func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString return idToken, nil } -func (v *DefaultVerifier) now() time.Time { - if v.config.now.IsZero() { - v.config.now = time.Now().UTC().Round(time.Second) - } - return v.config.now -} - -//VerifyIDToken: https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation +//Verify implements the `VerifyIDToken` method of the `Verifier` interface +//according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) { //1. if encrypted --> decrypt decrypted, err := v.decryptToken(idTokenString) @@ -227,6 +221,13 @@ func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString strin return claims, nil } +func (v *DefaultVerifier) now() time.Time { + if v.config.now.IsZero() { + v.config.now = time.Now().UTC().Round(time.Second) + } + return v.config.now +} + func (v *DefaultVerifier) parseToken(tokenString string) (*oidc.IDTokenClaims, []byte, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { @@ -372,7 +373,7 @@ func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) { } func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { - if accessToken == "" { + if atHash == "" { return nil } diff --git a/pkg/rp/mock/generate.go b/pkg/rp/mock/generate.go new file mode 100644 index 0000000..71bc3be --- /dev/null +++ b/pkg/rp/mock/generate.go @@ -0,0 +1,3 @@ +package mock + +//go:generate mockgen -package mock -destination ./verifier.mock.go github.com/caos/oidc/pkg/rp Verifier diff --git a/pkg/rp/mock/verifier.mock.go b/pkg/rp/mock/verifier.mock.go new file mode 100644 index 0000000..acd7d77 --- /dev/null +++ b/pkg/rp/mock/verifier.mock.go @@ -0,0 +1,65 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/caos/oidc/pkg/rp (interfaces: Verifier) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + oidc "github.com/caos/oidc/pkg/oidc" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockVerifier is a mock of Verifier interface +type MockVerifier struct { + ctrl *gomock.Controller + recorder *MockVerifierMockRecorder +} + +// MockVerifierMockRecorder is the mock recorder for MockVerifier +type MockVerifierMockRecorder struct { + mock *MockVerifier +} + +// NewMockVerifier creates a new mock instance +func NewMockVerifier(ctrl *gomock.Controller) *MockVerifier { + mock := &MockVerifier{ctrl: ctrl} + mock.recorder = &MockVerifierMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockVerifier) EXPECT() *MockVerifierMockRecorder { + return m.recorder +} + +// Verify mocks base method +func (m *MockVerifier) Verify(arg0 context.Context, arg1, arg2 string) (*oidc.IDTokenClaims, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Verify", arg0, arg1, arg2) + ret0, _ := ret[0].(*oidc.IDTokenClaims) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Verify indicates an expected call of Verify +func (mr *MockVerifierMockRecorder) Verify(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockVerifier)(nil).Verify), arg0, arg1, arg2) +} + +// VerifyIDToken mocks base method +func (m *MockVerifier) VerifyIDToken(arg0 context.Context, arg1 string) (*oidc.IDTokenClaims, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "VerifyIDToken", arg0, arg1) + ret0, _ := ret[0].(*oidc.IDTokenClaims) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// VerifyIDToken indicates an expected call of VerifyIDToken +func (mr *MockVerifierMockRecorder) VerifyIDToken(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyIDToken", reflect.TypeOf((*MockVerifier)(nil).VerifyIDToken), arg0, arg1) +} diff --git a/pkg/rp/mock/verifier.mock.impl.go b/pkg/rp/mock/verifier.mock.impl.go new file mode 100644 index 0000000..53b2f03 --- /dev/null +++ b/pkg/rp/mock/verifier.mock.impl.go @@ -0,0 +1,37 @@ +package mock + +import ( + "errors" + "testing" + + "github.com/golang/mock/gomock" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" +) + +func NewVerifier(t *testing.T) rp.Verifier { + return NewMockVerifier(gomock.NewController(t)) +} + +func NewMockVerifierExpectInvalid(t *testing.T) rp.Verifier { + m := NewVerifier(t) + ExpectVerifyInvalid(m) + return m +} + +func ExpectVerifyInvalid(v rp.Verifier) { + mock := v.(*MockVerifier) + mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(nil, errors.New("invalid")) +} + +func NewMockVerifierExpectValid(t *testing.T) rp.Verifier { + m := NewVerifier(t) + ExpectVerifyValid(m) + return m +} + +func ExpectVerifyValid(v rp.Verifier) { + mock := v.(*MockVerifier) + mock.EXPECT().VerifyIDToken(gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil) +} diff --git a/pkg/rp/verifier.go b/pkg/rp/verifier.go index b82e6c2..5add60f 100644 --- a/pkg/rp/verifier.go +++ b/pkg/rp/verifier.go @@ -12,4 +12,7 @@ type Verifier interface { //Verify checks the access_token and id_token and returns the `id token claims` Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) + + //VerifyIDToken checks the id_token only and returns its `id token claims` + VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) } diff --git a/pkg/utils/http.go b/pkg/utils/http.go index 6ad7083..b3ed631 100644 --- a/pkg/utils/http.go +++ b/pkg/utils/http.go @@ -18,6 +18,13 @@ var ( } ) +type Decoder interface { + Decode(dst interface{}, src map[string][]string) error +} +type Encoder interface { + Encode(src interface{}, dst map[string][]string) error +} + func FormRequest(endpoint string, request interface{}) (*http.Request, error) { form := make(map[string][]string) encoder := schema.NewEncoder() @@ -56,7 +63,7 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e return nil } -func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) { +func URLEncodeResponse(resp interface{}, encoder Encoder) (string, error) { values := make(map[string][]string) err := encoder.Encode(resp, values) if err != nil {