From 310220d38ed6d428e67575250e1bfefb482fdf92 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Fri, 6 Dec 2019 10:42:17 +0100 Subject: [PATCH] refactoring --- example/internal/mock/storage.go | 32 +++++-- pkg/oidc/authorization.go | 2 + pkg/oidc/token.go | 35 ++++--- pkg/op/authrequest.go | 35 +++---- pkg/op/authrequest_test.go | 2 +- pkg/op/default_op.go | 5 +- pkg/op/discovery.go | 10 +- pkg/op/discovery_test.go | 158 +++++++++++++++++++++++++++++-- pkg/op/error.go | 100 ++++++++----------- pkg/op/mock/authorizer.mock.go | 14 +++ pkg/op/op.go | 2 - pkg/op/storage.go | 2 + pkg/op/token.go | 46 +++++++++ pkg/op/tokenrequest.go | 35 +------ pkg/rp/default_rp.go | 5 +- pkg/rp/default_verifier.go | 2 +- pkg/utils/http.go | 10 ++ 17 files changed, 346 insertions(+), 149 deletions(-) create mode 100644 pkg/op/token.go diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 9b8ba5e..c8eae4b 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -41,7 +41,9 @@ func (a *AuthRequest) GetACR() string { } func (a *AuthRequest) GetAMR() []string { - return []string{} + return []string{ + "password", + } } func (a *AuthRequest) GetAudience() []string { @@ -55,7 +57,11 @@ func (a *AuthRequest) GetAuthTime() time.Time { } func (a *AuthRequest) GetClientID() string { - return "" + return a.ID +} + +func (a *AuthRequest) GetCode() string { + return "code" } func (a *AuthRequest) GetID() string { @@ -63,23 +69,31 @@ func (a *AuthRequest) GetID() string { } func (a *AuthRequest) GetNonce() string { - return "" + return "nonce" } func (a *AuthRequest) GetRedirectURI() string { - return "" + return "http://localhost:5556/auth/callback" } func (a *AuthRequest) GetResponseType() oidc.ResponseType { return a.ResponseType } +func (a *AuthRequest) GetScopes() []string { + return []string{ + "openid", + "profile", + "email", + } +} + func (a *AuthRequest) GetState() string { return "" } func (a *AuthRequest) GetSubject() string { - return "" + return "sub" } func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) { @@ -132,11 +146,14 @@ func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) { func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) { return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil } +func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) { + return s.key, nil +} func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) { pubkey := s.key.Public() return jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ - jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256"}, + jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"}, }, }, nil } @@ -151,6 +168,9 @@ func (c *ConfClient) RedirectURIs() []string { "http://localhost:9999/callback", "http://localhost:5556/auth/callback", "custom://callback", + "https://localhost:8443/test/a/instructions-example/callback", + "https://op.certification.openid.net:62054/authz_cb", + "https://op.certification.openid.net:62054/authz_post", } } diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index d3b8505..0ebd26d 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -25,6 +25,8 @@ const ( PromptSelectAccount = "select_account" GrantTypeCode GrantType = "authorization_code" + + BearerToken = "Bearer" ) var displayValues = map[string]Display{ diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 4a71847..248f671 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -14,17 +14,18 @@ import ( ) type IDTokenClaims struct { - Issuer string `json:"iss,omitempty"` - Subject string `json:"sub,omitempty"` - Audiences []string `json:"aud,omitempty"` - Expiration time.Time `json:"exp,omitempty"` - IssuedAt time.Time `json:"iat,omitempty"` - AuthTime time.Time `json:"auth_time,omitempty"` - Nonce string `json:"nonce,omitempty"` - AuthenticationContextClassReference string `json:"acr,omitempty"` - AuthenticationMethodsReferences []string `json:"amr,omitempty"` - AuthorizedParty string `json:"azp,omitempty"` - AccessTokenHash string `json:"at_hash,omitempty"` + Issuer string + Subject string + Audiences []string + Expiration time.Time + IssuedAt time.Time + AuthTime time.Time + Nonce string + AuthenticationContextClassReference string + AuthenticationMethodsReferences []string + AuthorizedParty string + AccessTokenHash string + CodeHash string Signature jose.SignatureAlgorithm //TODO: ??? } @@ -46,6 +47,7 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error { t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences t.AuthorizedParty = i.AuthorizedParty t.AccessTokenHash = i.AccessTokenHash + t.CodeHash = i.CodeHash return nil } @@ -63,6 +65,7 @@ func (t *IDTokenClaims) MarshalJSON() ([]byte, error) { AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, AuthorizedParty: t.AuthorizedParty, AccessTokenHash: t.AccessTokenHash, + CodeHash: t.CodeHash, } return json.Marshal(j) } @@ -81,21 +84,23 @@ type jsonIDToken struct { AuthenticationMethodsReferences []string `json:"amr,omitempty"` AuthorizedParty string `json:"azp,omitempty"` AccessTokenHash string `json:"at_hash,omitempty"` + CodeHash string `json:"c_hash,omitempty"` } type Tokens struct { *oauth2.Token IDTokenClaims *IDTokenClaims + IDToken string } -func AccessTokenHash(accessToken string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { - tokenHash, err := getHashAlgorithm(sigAlgorithm) +func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { + hash, err := getHashAlgorithm(sigAlgorithm) if err != nil { return "", err } - tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error - sum := tokenHash.Sum(nil)[:tokenHash.Size()/2] + hash.Write([]byte(claim)) // hash documents that Write will never return an error + sum := hash.Sum(nil)[:hash.Size()/2] return base64.RawURLEncoding.EncodeToString(sum), nil } diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index 6c5610c..8c45e4e 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -3,7 +3,6 @@ package op import ( "fmt" "net/http" - "net/url" "strings" "time" @@ -19,6 +18,7 @@ type Authorizer interface { Decoder() *schema.Decoder Encoder() *schema.Encoder Signer() Signer + Issuer() string // ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) } @@ -37,7 +37,7 @@ type ValidationAuthorizer interface { func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { err := r.ParseForm() if err != nil { - AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form")) + AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder()) // AuthRequestError(w, r, nil, ) return } @@ -45,7 +45,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { err = authorizer.Decoder().Decode(authReq, r.Form) if err != nil { - AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err))) + AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder()) return } @@ -54,19 +54,19 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { validation = validater.ValidateAuthRequest } if err := validation(authReq, authorizer.Storage()); err != nil { - AuthRequestError(w, r, authReq, err) + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } req, err := authorizer.Storage().CreateAuthRequest(authReq) if err != nil { - AuthRequestError(w, r, authReq, err) + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } client, err := authorizer.Storage().GetClientByClientID(req.GetClientID()) if err != nil { - AuthRequestError(w, r, req, err) + AuthRequestError(w, r, req, err, authorizer.Encoder()) return } RedirectToLogin(req.GetID(), client, w, r) @@ -100,7 +100,7 @@ func ValidateAuthReqScopes(scopes []string) error { return nil } -func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage Storage) error { +func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error { if uri == "" { return ErrInvalidRequest("redirect_uri must not be empty") } @@ -144,7 +144,7 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author authReq, err := authorizer.Storage().AuthRequestByID(id) if err != nil { - AuthRequestError(w, r, nil, err) + AuthRequestError(w, r, nil, err, authorizer.Encoder()) return } AuthResponse(authReq, authorizer, w, r) @@ -153,29 +153,32 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { var callback string if authReq.GetResponseType() == oidc.ResponseTypeCode { - callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test") + callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), authReq.GetCode()) } else { var accessToken string var err error + var exp uint64 if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly { - accessToken, err = CreateAccessToken() + accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer()) if err != nil { } } - idToken, err := CreateIDToken("", authReq, time.Duration(0), accessToken, authorizer.Signer()) + idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer()) if err != nil { } resp := &oidc.AccessTokenResponse{ AccessToken: accessToken, IDToken: idToken, - TokenType: "Bearer", + TokenType: oidc.BearerToken, + ExpiresIn: exp, } - values := make(map[string][]string) - authorizer.Encoder().Encode(resp, values) - v := url.Values(values) - callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode()) + params, err := utils.URLEncodeResponse(resp, authorizer.Encoder()) + if err != nil { + + } + callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params) } http.Redirect(w, r, callback, http.StatusFound) } diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go index 2374aa6..3c4c1e6 100644 --- a/pkg/op/authrequest_test.go +++ b/pkg/op/authrequest_test.go @@ -148,7 +148,7 @@ func TestValidateAuthReqRedirectURI(t *testing.T) { uri string clientID string responseType oidc.ResponseType - storage op.Storage + storage op.OPStorage } tests := []struct { name string diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index 98affbc..1890620 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -18,6 +18,8 @@ const ( authMethodBasic = "client_secret_basic" authMethodPost = "client_secret_post" + + DefaultIDTokenValidity = time.Duration(5 * time.Minute) ) var ( @@ -28,7 +30,6 @@ var ( Userinfo: defaultUserinfoEndpoint, JwksURI: defaultKeysEndpoint, } - DefaultIDTokenValidity = time.Duration(5 * time.Minute) ) type DefaultOP struct { @@ -250,5 +251,5 @@ func (p *DefaultOP) handleTokenExchange(w http.ResponseWriter, r *http.Request) } func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) { - + w.Write([]byte("ok")) } diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 411d668..7c7eae5 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -27,7 +27,7 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati // ClaimsSupported: oidc.SupportedClaims, IDTokenSigningAlgValuesSupported: sigAlgorithms(s), SubjectTypesSupported: subjectTypes(c), - TokenEndpointAuthMethodsSupported: authMethods(c), + TokenEndpointAuthMethodsSupported: authMethods(c.AuthMethodBasicSupported(), c.AuthMethodPostSupported()), } } @@ -68,12 +68,14 @@ func subjectTypes(c Configuration) []string { return []string{"public"} //TODO: config } -func authMethods(c Configuration) []string { +func authMethods(basic, post bool) []string { authMethods := make([]string, 0, 2) - if c.AuthMethodBasicSupported() { + if basic { + // if c.AuthMethodBasicSupported() { authMethods = append(authMethods, authMethodBasic) } - if c.AuthMethodPostSupported() { + if post { + // if c.AuthMethodPostSupported() { authMethods = append(authMethods, authMethodPost) } return authMethods diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index bfc2360..d4ed4f1 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -1,4 +1,4 @@ -package op_test +package op import ( "net/http" @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/op" ) func TestDiscover(t *testing.T) { @@ -31,7 +30,7 @@ func TestDiscover(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - op.Discover(tt.args.w, tt.args.config) + Discover(tt.args.w, tt.args.config) rec := tt.args.w.(*httptest.ResponseRecorder) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String()) @@ -41,8 +40,8 @@ func TestDiscover(t *testing.T) { func TestCreateDiscoveryConfig(t *testing.T) { type args struct { - c op.Configuration - s op.Signer + c Configuration + s Signer } tests := []struct { name string @@ -53,9 +52,156 @@ func TestCreateDiscoveryConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) { + if got := CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) { t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want) } }) } } + +func Test_scopes(t *testing.T) { + type args struct { + c Configuration + } + tests := []struct { + name string + args args + want []string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("scopes() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_responseTypes(t *testing.T) { + type args struct { + c Configuration + } + tests := []struct { + name string + args args + want []string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := responseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("responseTypes() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_grantTypes(t *testing.T) { + type args struct { + c Configuration + } + tests := []struct { + name string + args args + want []string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := grantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("grantTypes() = %v, want %v", got, tt.want) + } + }) + } +} + +// func Test_sigAlgorithms(t *testing.T) { +// type args struct { +// s Signer +// } +// tests := []struct { +// name string +// args args +// want []string +// }{ +// { +// "", +// args{}, +// []string{"RS256"}, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// if got := sigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) { +// t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want) +// } +// }) +// } +// } + +// func Test_subjectTypes(t *testing.T) { +// type args struct { +// c Configuration +// } +// tests := []struct { +// name string +// args args +// want []string +// }{ +// { +// "none", +// args{func()} +// } +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// if got := subjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) { +// t.Errorf("subjectTypes() = %v, want %v", got, tt.want) +// } +// }) +// } +// } + +func Test_authMethods(t *testing.T) { + type args struct { + basic bool + post bool + } + tests := []struct { + name string + args args + want []string + }{ + { + "none", + args{false, false}, + []string{}, + }, + { + "basic", + args{true, false}, + []string{authMethodBasic}, + }, + { + "post", + args{false, true}, + []string{authMethodPost}, + }, + { + "basic and post", + args{true, true}, + []string{authMethodBasic, authMethodPost}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := authMethods(tt.args.basic, tt.args.post); !reflect.DeepEqual(got, tt.want) { + t.Errorf("authMethods() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/op/error.go b/pkg/op/error.go index 3a9fa2f..bd9bc00 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -1,8 +1,10 @@ package op import ( + "fmt" "net/http" - "net/url" + + "github.com/gorilla/schema" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" @@ -13,6 +15,21 @@ const ( ServerError errorType = "server_error" ) +var ( + ErrInvalidRequest = func(description string) *OAuthError { + return &OAuthError{ + ErrorType: InvalidRequest, + Description: description, + } + } + ErrServerError = func(description string) *OAuthError { + return &OAuthError{ + ErrorType: ServerError, + Description: description, + } + } +) + type errorType string type ErrAuthRequest interface { @@ -21,7 +38,7 @@ type ErrAuthRequest interface { GetState() string } -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error) { +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) { if authReq == nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -30,27 +47,23 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq http.Error(w, err.Error(), http.StatusBadRequest) return } + e, ok := err.(*OAuthError) + if !ok { + e = new(OAuthError) + e.ErrorType = ServerError + e.Description = err.Error() + } + e.state = authReq.GetState() + params, err := utils.URLEncodeResponse(e, encoder) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } url := authReq.GetRedirectURI() if authReq.GetResponseType() == oidc.ResponseTypeCode { - url += "?" + url += "?" + params } else { - url += "#" - } - var errorType errorType - var description string - if e, ok := err.(*OAuthError); ok { - errorType = e.ErrorType - description = e.Description - } else { - errorType = ServerError - description = err.Error() - } - url += "error=" + string(errorType) - if description != "" { - url += "&error_description=" + description - } - if authReq.GetState() != "" { - url += "&state=" + authReq.GetState() + url += "#" + params } http.Redirect(w, r, url, http.StatusFound) } @@ -67,50 +80,11 @@ func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) { } type OAuthError struct { - ErrorType errorType `json:"error"` - Description string `json:"description"` -} - -var ( - ErrInvalidRequest = func(description string, args ...interface{}) *OAuthError { - return &OAuthError{ - ErrorType: InvalidRequest, - Description: description, - } - } - ErrServerError = func(description string, args ...interface{}) *OAuthError { - return &OAuthError{ - ErrorType: ServerError, - Description: description, - } - } -) - -func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest) { - if authReq == nil { - http.Error(w, e.Error(), http.StatusBadRequest) - return - } - if authReq.GetRedirectURI() == "" { - http.Error(w, e.Error(), http.StatusBadRequest) - return - } - callback := authReq.GetRedirectURI() - if authReq.GetResponseType() == oidc.ResponseTypeCode { - callback += "?" - } else { - callback += "#" - } - callback += "error=" + string(e.ErrorType) - if e.Description != "" { - callback += "&error_description=" + url.QueryEscape(e.Description) - } - if authReq.GetState() != "" { - callback += "&state=" + authReq.GetState() - } - http.Redirect(w, r, callback, http.StatusFound) + ErrorType errorType `json:"error" schema:"error"` + Description string `json:"description" schema:"description"` + state string `json:"state" schema:"state"` } func (e *OAuthError) Error() string { - return "" + return fmt.Sprintf("%s: %s", e.ErrorType, e.Description) } diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index d051b71..55c8c21 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -62,6 +62,20 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder)) } +// Issuer mocks base method +func (m *MockAuthorizer) Issuer() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Issuer") + ret0, _ := ret[0].(string) + return ret0 +} + +// Issuer indicates an expected call of Issuer +func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer)) +} + // Signer mocks base method func (m *MockAuthorizer) Signer() op.Signer { m.ctrl.T.Helper() diff --git a/pkg/op/op.go b/pkg/op/op.go index d5d96ab..7db2ff4 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -12,14 +12,12 @@ import ( type OpenIDProvider interface { Configuration - // Storage() Storage HandleDiscovery(w http.ResponseWriter, r *http.Request) HandleAuthorize(w http.ResponseWriter, r *http.Request) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) HandleExchange(w http.ResponseWriter, r *http.Request) HandleUserinfo(w http.ResponseWriter, r *http.Request) HandleKeys(w http.ResponseWriter, r *http.Request) - // Storage() Storage HttpHandler() *http.Server } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index f4fa274..e90d5e7 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -36,9 +36,11 @@ type AuthRequest interface { GetAudience() []string GetAuthTime() time.Time GetClientID() string + GetCode() string GetNonce() string GetRedirectURI() string GetResponseType() oidc.ResponseType + GetScopes() []string GetState() string GetSubject() string } diff --git a/pkg/op/token.go b/pkg/op/token.go new file mode 100644 index 0000000..fd759b2 --- /dev/null +++ b/pkg/op/token.go @@ -0,0 +1,46 @@ +package op + +import ( + "fmt" + "time" + + "github.com/caos/oidc/pkg/oidc" +) + +func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) { + var err error + accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes()) + exp := time.Duration(5 * time.Minute) + return accessToken, uint64(exp.Seconds()), err +} + +func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) { + var err error + exp := time.Now().UTC().Add(validity) + claims := &oidc.IDTokenClaims{ + Issuer: issuer, + Subject: authReq.GetSubject(), + Audiences: authReq.GetAudience(), + Expiration: exp, + IssuedAt: time.Now().UTC(), + AuthTime: authReq.GetAuthTime(), + Nonce: authReq.GetNonce(), + AuthenticationContextClassReference: authReq.GetACR(), + AuthenticationMethodsReferences: authReq.GetAMR(), + AuthorizedParty: authReq.GetClientID(), + } + if accessToken != "" { + claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) + if err != nil { + return "", err + } + } + if code != "" { + claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm()) + if err != nil { + return "", err + } + } + + return signer.SignIDToken(claims) +} diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index db00746..00ec3e6 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -54,12 +54,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { ExchangeRequestError(w, r, err) return } - accessToken, err := CreateAccessToken() + accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer()) if err != nil { ExchangeRequestError(w, r, err) return } - idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, exchanger.Signer()) + idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer()) if err != nil { ExchangeRequestError(w, r, err) return @@ -68,39 +68,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { resp := &oidc.AccessTokenResponse{ AccessToken: accessToken, IDToken: idToken, + TokenType: oidc.BearerToken, + ExpiresIn: exp, } utils.MarshalJSON(w, resp) } -func CreateAccessToken() (string, error) { - return "accessToken", nil -} - -func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken string, signer Signer) (string, error) { - var err error - exp := time.Now().UTC().Add(validity) - claims := &oidc.IDTokenClaims{ - Issuer: issuer, - Subject: authReq.GetSubject(), - Audiences: authReq.GetAudience(), - Expiration: exp, - IssuedAt: time.Now().UTC(), - AuthTime: authReq.GetAuthTime(), - Nonce: authReq.GetNonce(), - AuthenticationContextClassReference: authReq.GetACR(), - AuthenticationMethodsReferences: authReq.GetAMR(), - AuthorizedParty: authReq.GetClientID(), - } - if accessToken != "" { - claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, signer.SignatureAlgorithm()) - if err != nil { - return "", err - } - } - - return signer.SignIDToken(claims) -} - func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) { if tokenReq.ClientID == "" { if !exchanger.AuthMethodBasicSupported() { diff --git a/pkg/rp/default_rp.go b/pkg/rp/default_rp.go index dce6285..1529ea5 100644 --- a/pkg/rp/default_rp.go +++ b/pkg/rp/default_rp.go @@ -64,7 +64,7 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc } if p.verifier == nil { - p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint + p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) } return p, nil @@ -110,6 +110,7 @@ func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc { //handling the oauth2 code exchange, extracting and validating the id_token //returning it paresed together with the oauth2 tokens (access, refresh) func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) { + ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient) token, err := p.oauthConfig.Exchange(ctx, code) if err != nil { return nil, err //TODO: our error @@ -124,7 +125,7 @@ func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc return nil, err //TODO: err } - return &oidc.Tokens{Token: token, IDTokenClaims: idToken}, nil + return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil } //AuthURL is the `RelayingParty` interface implementation diff --git a/pkg/rp/default_verifier.go b/pkg/rp/default_verifier.go index 431c984..6d25411 100644 --- a/pkg/rp/default_verifier.go +++ b/pkg/rp/default_verifier.go @@ -443,7 +443,7 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor return nil //TODO: return error } - actual, err := oidc.AccessTokenHash(accessToken, sigAlgorithm) + actual, err := oidc.ClaimHash(accessToken, sigAlgorithm) if err != nil { return err } diff --git a/pkg/utils/http.go b/pkg/utils/http.go index b3d3434..6ad7083 100644 --- a/pkg/utils/http.go +++ b/pkg/utils/http.go @@ -55,3 +55,13 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e } return nil } + +func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) { + values := make(map[string][]string) + err := encoder.Encode(resp, values) + if err != nil { + return "", err + } + v := url.Values(values) + return v.Encode(), nil +}