From be04244212e4d8802c5d41398f03511c9f51d41c Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Tue, 11 May 2021 10:26:25 +0200 Subject: [PATCH] amr and scopes --- pkg/op/auth_request.go | 2 +- pkg/op/token.go | 27 ++++++++++++--------------- pkg/op/token_code.go | 2 +- pkg/op/token_refresh.go | 39 ++++++++++++++++++++++++++++----------- 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index dcb5e39..09c633f 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -263,7 +263,7 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques //AuthResponseToken creates the successful token(s) authentication response func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) { createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly - resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "") + resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "") if err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return diff --git a/pkg/op/token.go b/pkg/op/token.go index 8592003..c18aa30 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -21,12 +21,12 @@ type TokenRequest interface { GetScopes() []string } -func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) { - var accessToken, refreshToken string +func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Client, creator TokenCreator, createAccessToken bool, code, refreshToken string) (*oidc.AccessTokenResponse, error) { + var accessToken, newRefreshToken string var validity time.Duration if createAccessToken { var err error - accessToken, refreshToken, validity, err = CreateAccessToken(ctx, request, client.AccessTokenType(), creator, client) + accessToken, newRefreshToken, validity, err = CreateAccessToken(ctx, request, client.AccessTokenType(), creator, client, refreshToken) if err != nil { return nil, err } @@ -47,18 +47,18 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli return &oidc.AccessTokenResponse{ AccessToken: accessToken, IDToken: idToken, - RefreshToken: refreshToken, + RefreshToken: newRefreshToken, TokenType: oidc.BearerToken, ExpiresIn: exp, }, nil } -func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage) (id, refreshToken string, exp time.Time, err error) { +func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string) (id, newRefreshToken string, exp time.Time, err error) { if needsRefreshToken(tokenRequest) { - id, exp, err = storage.CreateToken(ctx, tokenRequest) - return + return storage.CreateTokens(ctx, tokenRequest, refreshToken) } - return storage.CreateTokens(ctx, tokenRequest, "hodor") + id, exp, err = storage.CreateToken(ctx, tokenRequest) + return } func needsRefreshToken(tokenRequest TokenRequest) bool { @@ -72,8 +72,8 @@ func needsRefreshToken(tokenRequest TokenRequest) bool { } } -func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client) (accessToken, refreshToken string, validity time.Duration, err error) { - id, refreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage()) +func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client Client, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) { + id, newRefreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken) if err != nil { return "", "", 0, err } @@ -108,8 +108,7 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex } type IDTokenRequest interface { - //GetACR() string - //GetAMR() []string + GetAMR() []string GetAudience() []string GetAuthTime() time.Time GetClientID() string @@ -120,13 +119,11 @@ type IDTokenRequest interface { func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, client Client) (string, error) { exp := time.Now().UTC().Add(client.ClockSkew()).Add(validity) var acr, nonce string - var amr []string if authRequest, ok := request.(AuthRequest); ok { acr = authRequest.GetACR() - amr = authRequest.GetAMR() nonce = authRequest.GetNonce() } - claims := oidc.NewIDTokenClaims(issuer, request.GetSubject(), request.GetAudience(), exp, request.GetAuthTime(), nonce, acr, amr, request.GetClientID(), client.ClockSkew()) + claims := oidc.NewIDTokenClaims(issuer, request.GetSubject(), request.GetAudience(), exp, request.GetAuthTime(), nonce, acr, request.GetAMR(), request.GetClientID(), client.ClockSkew()) scopes := client.RestrictAdditionalIdTokenScopes()(request.GetScopes()) if accessToken != "" { atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index 0f27104..953181d 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -25,7 +25,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { RequestError(w, r, err) return } - resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code) + resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "") if err != nil { RequestError(w, r, err) return diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index 5f3c3a5..7a8632a 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -11,15 +11,13 @@ import ( ) type RefreshTokenRequest interface { - //GetID() string - //GetACR() string - //GetAMR() []string + GetAMR() []string GetAudience() []string GetAuthTime() time.Time GetClientID() string GetScopes() []string GetSubject() string - //GetRefreshToken() string + SetCurrentScopes(scopes oidc.Scopes) } //RefreshTokenExchange handles the OAuth 2.0 refresh_token grant, including @@ -29,12 +27,12 @@ func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exch if err != nil { RequestError(w, r, err) } - authReq, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger) + validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { RequestError(w, r, err) return } - resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, "") + resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken) if err != nil { RequestError(w, r, err) return @@ -58,14 +56,33 @@ func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshToke if tokenReq.RefreshToken == "" { return nil, nil, ErrInvalidRequest("code missing") } - authReq, client, err := AuthorizeRefreshClient(ctx, tokenReq, exchanger) + request, client, err := AuthorizeRefreshClient(ctx, tokenReq, exchanger) if err != nil { return nil, nil, err } - if client.GetID() != authReq.GetClientID() { + if client.GetID() != request.GetClientID() { return nil, nil, ErrInvalidRequest("invalid auth code") } - return authReq, client, nil + if err = ValidateRefreshTokenScopes(tokenReq.Scopes, request); err != nil { + return nil, nil, err + } + return request, client, nil +} + +//ValidateRefreshTokenScopes validates that requested scope is a subset of the original auth request scope +//it will set the requested scopes as current scopes onto RefreshTokenRequest +//if empty the original scopes will be used +func ValidateRefreshTokenScopes(requestedScopes oidc.Scopes, authRequest RefreshTokenRequest) error { + if len(requestedScopes) == 0 { + return nil + } + for _, scope := range requestedScopes { + if !utils.Contains(authRequest.GetScopes(), scope) { + return errors.New("invalid_scope") + } + } + authRequest.SetCurrentScopes(requestedScopes) + return nil } //AuthorizeCodeClient checks the authorization of the client and that the used method was the one previously registered. @@ -105,9 +122,9 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ //RefreshTokenRequestByRefreshToken returns the RefreshTokenRequest (data representing the original auth request) //corresponding to the refresh_token from Storage or an error func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) { - authReq, err := storage.RefreshTokenRequestByRefreshToken(ctx, refreshToken) + request, err := storage.RefreshTokenRequestByRefreshToken(ctx, refreshToken) if err != nil { return nil, ErrInvalidRequest("invalid refreshToken") } - return authReq, nil + return request, nil }