diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index d0336c1..a2fb779 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -2,6 +2,7 @@ package mock import ( "errors" + "time" "gopkg.in/square/go-jose.v2" @@ -32,6 +33,10 @@ func (a *AuthRequest) GetAudience() []string { } } +func (a *AuthRequest) GetAuthTime() time.Time { + return time.Now().UTC() +} + func (a *AuthRequest) GetClientID() string { return "" } diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index 3b7ae81..6c5610c 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -163,7 +163,7 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri } } - idToken, err := CreateIDToken("", authReq, accessToken, time.Now(), time.Now(), "", authorizer.Signer()) + idToken, err := CreateIDToken("", authReq, time.Duration(0), accessToken, authorizer.Signer()) if err != nil { } diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index 8f1a325..55874da 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -2,6 +2,7 @@ package op import ( "net/http" + "time" "github.com/gorilla/schema" @@ -22,6 +23,7 @@ var ( IntrospectionEndpoint: defaultIntrospectEndpoint, Userinfo: defaultUserinfoEndpoint, } + DefaultIDTokenValidity = time.Duration(5 * time.Minute) ) type DefaultOP struct { @@ -36,7 +38,8 @@ type DefaultOP struct { } type Config struct { - Issuer string + Issuer string + IDTokenValidity time.Duration // ScopesSupported: oidc.SupportedScopes, // ResponseTypesSupported: responseTypes, // GrantTypesSupported: oidc.SupportedGrantTypes, @@ -172,6 +175,13 @@ func (p *DefaultOP) Signer() Signer { // return } +func (p *DefaultOP) IDTokenValidity() time.Duration { + if p.config.IDTokenValidity == 0 { + p.config.IDTokenValidity = DefaultIDTokenValidity + } + return p.config.IDTokenValidity +} + // func (p *DefaultOP) ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { // return AuthRequestError // } diff --git a/pkg/op/error.go b/pkg/op/error.go index 5c7da30..3a9fa2f 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -58,9 +58,11 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) { e, ok := err.(*OAuthError) if !ok { + e = new(OAuthError) e.ErrorType = ServerError e.Description = err.Error() } + w.WriteHeader(http.StatusBadRequest) utils.MarshalJSON(w, e) } diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index f7dc58f..e90a11e 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -5,6 +5,7 @@ import ( "github.com/golang/mock/gomock" "github.com/gorilla/schema" + "gopkg.in/square/go-jose.v2" oidc "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/op" @@ -69,6 +70,9 @@ type Sig struct{} func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) { return "", nil } +func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm { + return jose.HS256 +} func ExpectStorage(a op.Authorizer, t *testing.T) { mockA := a.(*MockAuthorizer) diff --git a/pkg/op/signer.go b/pkg/op/signer.go index fd36652..39dfdc6 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -10,11 +10,13 @@ import ( type Signer interface { SignIDToken(claims *oidc.IDTokenClaims) (string, error) + SignatureAlgorithm() jose.SignatureAlgorithm } type idTokenSigner struct { - signer jose.Signer - storage Storage + signer jose.Signer + storage Storage + algorithm jose.SignatureAlgorithm } func NewDefaultSigner(storage Storage) (Signer, error) { @@ -36,6 +38,7 @@ func (s *idTokenSigner) initialize() error { if err != nil { return err } + s.algorithm = key.Algorithm return nil } @@ -46,6 +49,7 @@ func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) } return s.Sign(payload) } + func (s *idTokenSigner) Sign(payload []byte) (string, error) { result, err := s.signer.Sign(payload) if err != nil { @@ -53,3 +57,7 @@ func (s *idTokenSigner) Sign(payload []byte) (string, error) { } return result.CompactSerialize() } + +func (s *idTokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm { + return s.algorithm +} diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 7db58d7..08f37e0 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -1,6 +1,8 @@ package op import ( + "time" + "gopkg.in/square/go-jose.v2" "github.com/caos/oidc/pkg/oidc" @@ -22,6 +24,7 @@ type AuthRequest interface { GetACR() string GetAMR() []string GetAudience() []string + GetAuthTime() time.Time GetClientID() string GetNonce() string GetRedirectURI() string diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index 9d41959..83b3cde 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -5,17 +5,15 @@ import ( "net/http" "time" - "gopkg.in/square/go-jose.v2" - - "github.com/caos/oidc/pkg/utils" - "github.com/gorilla/schema" "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/utils" ) type Exchanger interface { Issuer() string + IDTokenValidity() time.Duration Storage() Storage Decoder() *schema.Decoder Signer() Signer @@ -59,7 +57,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { ExchangeRequestError(w, r, err) return } - idToken, err := CreateIDToken(exchanger.Issuer(), authReq, "", time.Now(), time.Now(), "", exchanger.Signer()) + idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, exchanger.Signer()) if err != nil { ExchangeRequestError(w, r, err) return @@ -76,23 +74,23 @@ func CreateAccessToken() (string, error) { return "accessToken", nil } -func CreateIDToken(issuer string, authReq AuthRequest, sub string, exp, authTime time.Time, accessToken string, signer Signer) (string, error) { +func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken string, signer Signer) (string, error) { var err error + exp := time.Now().UTC().Add(validity) claims := &oidc.IDTokenClaims{ Issuer: issuer, Subject: authReq.GetSubject(), Audiences: authReq.GetAudience(), Expiration: exp, IssuedAt: time.Now().UTC(), - AuthTime: authTime, + AuthTime: authReq.GetAuthTime(), Nonce: authReq.GetNonce(), AuthenticationContextClassReference: authReq.GetACR(), AuthenticationMethodsReferences: authReq.GetAMR(), AuthorizedParty: authReq.GetClientID(), } if accessToken != "" { - var alg jose.SignatureAlgorithm - claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, alg) //TODO: alg + claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, signer.SignatureAlgorithm()) if err != nil { return "", err } diff --git a/pkg/rp/default_rp.go b/pkg/rp/default_rp.go index 0537443..dce6285 100644 --- a/pkg/rp/default_rp.go +++ b/pkg/rp/default_rp.go @@ -20,6 +20,12 @@ const ( stateParam = "state" ) +var ( + DefaultErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) { + http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError) + } +) + //DefaultRP impements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface type DefaultRP struct { endpoints Endpoints @@ -30,6 +36,8 @@ type DefaultRP struct { httpClient *http.Client cookieHandler *utils.CookieHandler + errorHandler func(http.ResponseWriter, *http.Request, string, string, string) + verifier Verifier } @@ -51,6 +59,10 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc return nil, err } + if p.errorHandler == nil { + p.errorHandler = DefaultErrorHandler + } + if p.verifier == nil { p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint } @@ -125,15 +137,16 @@ func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http return } params := r.URL.Query() - if params.Get("code") != "" { - tokens, err := p.CodeExchange(r.Context(), params.Get("code")) - if err != nil { - http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) - return - } - callback(w, r, tokens, state) + if params.Get("error") != "" { + p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state) + return } - w.Write([]byte(params.Get("error"))) + tokens, err := p.CodeExchange(r.Context(), params.Get("code")) + if err != nil { + http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) + return + } + callback(w, r, tokens, state) } } @@ -169,18 +182,15 @@ func (p *DefaultRP) DelegationTokenExchange(ctx context.Context, subjectToken st func (p *DefaultRP) discover() error { wellKnown := strings.TrimSuffix(p.config.Issuer, "/") + oidc.DiscoveryEndpoint - req, err := http.NewRequest("GET", wellKnown, nil) if err != nil { return err } discoveryConfig := new(oidc.DiscoveryConfiguration) - err = utils.HttpRequest(p.httpClient, req, &discoveryConfig) if err != nil { return err } - p.endpoints = GetEndpoints(discoveryConfig) p.oauthConfig = oauth2.Config{ ClientID: p.config.ClientID,