fix: parse max_age and prompt correctly (and change scope type) (#105)

* fix: parse max_age and prompt correctly (and change scope type)

* remove unnecessary omitempty
This commit is contained in:
Livio Amstutz 2021-06-16 08:34:01 +02:00 committed by GitHub
parent 0591a0d1ef
commit 400f5c4de4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 98 additions and 85 deletions

View file

@ -95,7 +95,7 @@ func (a *AuthRequest) GetScopes() []string {
} }
} }
func (a *AuthRequest) SetCurrentScopes(scopes oidc.Scopes) {} func (a *AuthRequest) SetCurrentScopes(scopes []string) {}
func (a *AuthRequest) GetState() string { func (a *AuthRequest) GetState() string {
return "" return ""
@ -243,7 +243,7 @@ func (s *AuthStorage) SetIntrospectionFromToken(ctx context.Context, introspect
return nil return nil
} }
func (s *AuthStorage) ValidateJWTProfileScopes(ctx context.Context, userID string, scope oidc.Scopes) (oidc.Scopes, error) { func (s *AuthStorage) ValidateJWTProfileScopes(ctx context.Context, userID string, scope []string) ([]string, error) {
return scope, nil return scope, nil
} }

View file

@ -17,8 +17,8 @@ import (
var ( var (
Encoder = func() utils.Encoder { Encoder = func() utils.Encoder {
e := schema.NewEncoder() e := schema.NewEncoder()
e.RegisterEncoder(oidc.Scopes{}, func(value reflect.Value) string { e.RegisterEncoder(oidc.SpaceDelimitedArray{}, func(value reflect.Value) string {
return value.Interface().(oidc.Scopes).Encode() return value.Interface().(oidc.SpaceDelimitedArray).Encode()
}) })
return e return e
}() }()

View file

@ -430,10 +430,10 @@ func WithCodeChallenge(codeChallenge string) AuthURLOpt {
} }
//WithPrompt sets the `prompt` params in the auth request //WithPrompt sets the `prompt` params in the auth request
func WithPrompt(prompt oidc.Prompt) AuthURLOpt { func WithPrompt(prompt ...string) AuthURLOpt {
return func() []oauth2.AuthCodeOption { return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{ return []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("prompt", string(prompt)), oauth2.SetAuthURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()),
} }
} }
} }

View file

@ -44,39 +44,39 @@ const (
//PromptNone (`none`) disallows the Authorization Server to display any authentication or consent user interface pages. //PromptNone (`none`) disallows the Authorization Server to display any authentication or consent user interface pages.
//An error (login_required, interaction_required, ...) will be returned if the user is not already authenticated or consent is needed //An error (login_required, interaction_required, ...) will be returned if the user is not already authenticated or consent is needed
PromptNone Prompt = "none" PromptNone = "none"
//PromptLogin (`login`) directs the Authorization Server to prompt the End-User for reauthentication. //PromptLogin (`login`) directs the Authorization Server to prompt the End-User for reauthentication.
PromptLogin Prompt = "login" PromptLogin = "login"
//PromptConsent (`consent`) directs the Authorization Server to prompt the End-User for consent (of sharing information). //PromptConsent (`consent`) directs the Authorization Server to prompt the End-User for consent (of sharing information).
PromptConsent Prompt = "consent" PromptConsent = "consent"
//PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching) //PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching)
PromptSelectAccount Prompt = "select_account" PromptSelectAccount = "select_account"
) )
//AuthRequest according to: //AuthRequest according to:
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest //https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type AuthRequest struct { type AuthRequest struct {
ID string ID string
Scopes Scopes `schema:"scope"` Scopes SpaceDelimitedArray `schema:"scope"`
ResponseType ResponseType `schema:"response_type"` ResponseType ResponseType `schema:"response_type"`
ClientID string `schema:"client_id"` ClientID string `schema:"client_id"`
RedirectURI string `schema:"redirect_uri"` //TODO: type RedirectURI string `schema:"redirect_uri"` //TODO: type
State string `schema:"state"` State string `schema:"state"`
// ResponseMode TODO: ? // ResponseMode TODO: ?
Nonce string `schema:"nonce"` Nonce string `schema:"nonce"`
Display Display `schema:"display"` Display Display `schema:"display"`
Prompt Prompt `schema:"prompt"` Prompt SpaceDelimitedArray `schema:"prompt"`
MaxAge uint32 `schema:"max_age"` MaxAge *uint `schema:"max_age"`
UILocales Locales `schema:"ui_locales"` UILocales Locales `schema:"ui_locales"`
IDTokenHint string `schema:"id_token_hint"` IDTokenHint string `schema:"id_token_hint"`
LoginHint string `schema:"login_hint"` LoginHint string `schema:"login_hint"`
ACRValues []string `schema:"acr_values"` ACRValues []string `schema:"acr_values"`
CodeChallenge string `schema:"code_challenge"` CodeChallenge string `schema:"code_challenge"`
CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"` CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"`

View file

@ -21,7 +21,7 @@ type IntrospectionResponse interface {
UserInfoSetter UserInfoSetter
SetActive(bool) SetActive(bool)
IsActive() bool IsActive() bool
SetScopes(scopes Scopes) SetScopes(scopes []string)
SetClientID(id string) SetClientID(id string)
} }
@ -30,10 +30,10 @@ func NewIntrospectionResponse() IntrospectionResponse {
} }
type introspectionResponse struct { type introspectionResponse struct {
Active bool `json:"active"` Active bool `json:"active"`
Scope Scopes `json:"scope,omitempty"` Scope SpaceDelimitedArray `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"` ClientID string `json:"client_id,omitempty"`
Subject string `json:"sub,omitempty"` Subject string `json:"sub,omitempty"`
userInfoProfile userInfoProfile
userInfoEmail userInfoEmail
userInfoPhone userInfoPhone
@ -46,7 +46,7 @@ func (u *introspectionResponse) IsActive() bool {
return u.Active return u.Active
} }
func (u *introspectionResponse) SetScopes(scope Scopes) { func (u *introspectionResponse) SetScopes(scope []string) {
u.Scope = scope u.Scope = scope
} }

View file

@ -1,9 +1,9 @@
package oidc package oidc
type JWTProfileGrantRequest struct { type JWTProfileGrantRequest struct {
Assertion string `schema:"assertion"` Assertion string `schema:"assertion"`
Scope Scopes `schema:"scope"` Scope SpaceDelimitedArray `schema:"scope"`
GrantType GrantType `schema:"grant_type"` GrantType GrantType `schema:"grant_type"`
} }
//NewJWTProfileGrantRequest creates an oauth2 `JSON Web Token (JWT) Profile` Grant //NewJWTProfileGrantRequest creates an oauth2 `JSON Web Token (JWT) Profile` Grant

View file

@ -58,12 +58,12 @@ func (a *AccessTokenRequest) SetClientSecret(clientSecret string) {
} }
type RefreshTokenRequest struct { type RefreshTokenRequest struct {
RefreshToken string `schema:"refresh_token"` RefreshToken string `schema:"refresh_token"`
Scopes Scopes `schema:"scope"` Scopes SpaceDelimitedArray `schema:"scope"`
ClientID string `schema:"client_id"` ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"` ClientSecret string `schema:"client_secret"`
ClientAssertion string `schema:"client_assertion"` ClientAssertion string `schema:"client_assertion"`
ClientAssertionType string `schema:"client_assertion_type"` ClientAssertionType string `schema:"client_assertion_type"`
} }
func (a *RefreshTokenRequest) GrantType() GrantType { func (a *RefreshTokenRequest) GrantType() GrantType {
@ -81,12 +81,12 @@ func (a *RefreshTokenRequest) SetClientSecret(clientSecret string) {
} }
type JWTTokenRequest struct { type JWTTokenRequest struct {
Issuer string `json:"iss"` Issuer string `json:"iss"`
Subject string `json:"sub"` Subject string `json:"sub"`
Scopes Scopes `json:"-"` Scopes SpaceDelimitedArray `json:"-"`
Audience Audience `json:"aud"` Audience Audience `json:"aud"`
IssuedAt Time `json:"iat"` IssuedAt Time `json:"iat"`
ExpiresAt Time `json:"exp"` ExpiresAt Time `json:"exp"`
} }
//GetIssuer implements the Claims interface //GetIssuer implements the Claims interface
@ -143,12 +143,12 @@ func (j *JWTTokenRequest) GetScopes() []string {
} }
type TokenExchangeRequest struct { type TokenExchangeRequest struct {
subjectToken string `schema:"subject_token"` subjectToken string `schema:"subject_token"`
subjectTokenType string `schema:"subject_token_type"` subjectTokenType string `schema:"subject_token_type"`
actorToken string `schema:"actor_token"` actorToken string `schema:"actor_token"`
actorTokenType string `schema:"actor_token_type"` actorTokenType string `schema:"actor_token_type"`
resource []string `schema:"resource"` resource []string `schema:"resource"`
audience Audience `schema:"audience"` audience Audience `schema:"audience"`
Scope Scopes `schema:"scope"` Scope SpaceDelimitedArray `schema:"scope"`
requestedTokenType string `schema:"requested_token_type"` requestedTokenType string `schema:"requested_token_type"`
} }

View file

@ -54,30 +54,36 @@ func (l *Locales) UnmarshalText(text []byte) error {
return nil return nil
} }
type Prompt string type MaxAge *uint
func NewMaxAge(i uint) MaxAge {
return &i
}
type SpaceDelimitedArray []string
type Prompt SpaceDelimitedArray
type ResponseType string type ResponseType string
type Scopes []string func (s SpaceDelimitedArray) Encode() string {
func (s Scopes) Encode() string {
return strings.Join(s, " ") return strings.Join(s, " ")
} }
func (s *Scopes) UnmarshalText(text []byte) error { func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error {
*s = strings.Split(string(text), " ") *s = strings.Split(string(text), " ")
return nil return nil
} }
func (s *Scopes) MarshalText() ([]byte, error) { func (s SpaceDelimitedArray) MarshalText() ([]byte, error) {
return []byte(s.Encode()), nil return []byte(s.Encode()), nil
} }
func (s *Scopes) MarshalJSON() ([]byte, error) { func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) {
return json.Marshal((*s).Encode()) return json.Marshal((s).Encode())
} }
func (s *Scopes) UnmarshalJSON(data []byte) error { func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error {
var str string var str string
if err := json.Unmarshal(data, &str); err != nil { if err := json.Unmarshal(data, &str); err != nil {
return err return err

View file

@ -220,7 +220,7 @@ func TestScopes_UnmarshalText(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
var scopes Scopes var scopes SpaceDelimitedArray
if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr { if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
} }
@ -230,7 +230,7 @@ func TestScopes_UnmarshalText(t *testing.T) {
} }
func TestScopes_MarshalText(t *testing.T) { func TestScopes_MarshalText(t *testing.T) {
type args struct { type args struct {
scopes Scopes scopes SpaceDelimitedArray
} }
type res struct { type res struct {
scopes []byte scopes []byte
@ -244,7 +244,7 @@ func TestScopes_MarshalText(t *testing.T) {
{ {
"unknown value", "unknown value",
args{ args{
Scopes{"unknown"}, SpaceDelimitedArray{"unknown"},
}, },
res{ res{
[]byte("unknown"), []byte("unknown"),
@ -254,7 +254,7 @@ func TestScopes_MarshalText(t *testing.T) {
{ {
"struct", "struct",
args{ args{
Scopes{`{"unknown":"value"}`}, SpaceDelimitedArray{`{"unknown":"value"}`},
}, },
res{ res{
[]byte(`{"unknown":"value"}`), []byte(`{"unknown":"value"}`),
@ -264,7 +264,7 @@ func TestScopes_MarshalText(t *testing.T) {
{ {
"openid", "openid",
args{ args{
Scopes{"openid"}, SpaceDelimitedArray{"openid"},
}, },
res{ res{
[]byte("openid"), []byte("openid"),
@ -274,7 +274,7 @@ func TestScopes_MarshalText(t *testing.T) {
{ {
"multiple scopes", "multiple scopes",
args{ args{
Scopes{"openid", "email", "custom:scope"}, SpaceDelimitedArray{"openid", "email", "custom:scope"},
}, },
res{ res{
[]byte("openid email custom:scope"), []byte("openid email custom:scope"),

View file

@ -106,7 +106,11 @@ func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRe
} }
//ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed //ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (string, error) { func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) {
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil {
return "", err
}
client, err := storage.GetClientByClientID(ctx, authReq.ClientID) client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
if err != nil { if err != nil {
return "", ErrServerError(err.Error()) return "", ErrServerError(err.Error())
@ -124,6 +128,19 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier) return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier)
} }
//ValidateAuthReqPrompt validates the passed prompt values and sets max_age to 0 if prompt login is present
func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) {
for _, prompt := range prompts {
if prompt == oidc.PromptNone && len(prompts) > 1 {
return nil, ErrInvalidRequest("The prompt parameter `none` must only be used as a single value")
}
if prompt == oidc.PromptLogin {
maxAge = oidc.NewMaxAge(0)
}
}
return maxAge, nil
}
//ValidateAuthReqScopes validates the passed scopes //ValidateAuthReqScopes validates the passed scopes
func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
if len(scopes) == 0 { if len(scopes) == 0 {

View file

@ -123,7 +123,7 @@ func TestParseAuthorizeRequest(t *testing.T) {
}(), }(),
}, },
res{ res{
&oidc.AuthRequest{Scopes: oidc.Scopes{"openid"}}, &oidc.AuthRequest{Scopes: oidc.SpaceDelimitedArray{"openid"}},
false, false,
}, },
}, },

View file

@ -316,10 +316,10 @@ func (mr *MockStorageMockRecorder) TokenRequestByRefreshToken(arg0, arg1 interfa
} }
// ValidateJWTProfileScopes mocks base method. // ValidateJWTProfileScopes mocks base method.
func (m *MockStorage) ValidateJWTProfileScopes(arg0 context.Context, arg1 string, arg2 oidc.Scopes) (oidc.Scopes, error) { func (m *MockStorage) ValidateJWTProfileScopes(arg0 context.Context, arg1 string, arg2 []string) ([]string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ValidateJWTProfileScopes", arg0, arg1, arg2) ret := m.ctrl.Call(m, "ValidateJWTProfileScopes", arg0, arg1, arg2)
ret0, _ := ret[0].(oidc.Scopes) ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }

View file

@ -140,10 +140,10 @@ func (c *ConfClient) GetID() string {
} }
func (c *ConfClient) AccessTokenLifetime() time.Duration { func (c *ConfClient) AccessTokenLifetime() time.Duration {
return time.Duration(5 * time.Minute) return 5 * time.Minute
} }
func (c *ConfClient) IDTokenLifetime() time.Duration { func (c *ConfClient) IDTokenLifetime() time.Duration {
return time.Duration(5 * time.Minute) return 5 * time.Minute
} }
func (c *ConfClient) AccessTokenType() op.AccessTokenType { func (c *ConfClient) AccessTokenType() op.AccessTokenType {
return c.accessTokenType return c.accessTokenType

View file

@ -83,13 +83,3 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
} }
return nil, ErrInvalidRequest("post_logout_redirect_uri invalid") return nil, ErrInvalidRequest("post_logout_redirect_uri invalid")
} }
func NeedsExistingSession(authRequest *oidc.AuthRequest) bool {
if authRequest == nil {
return true
}
if authRequest.Prompt == oidc.PromptNone {
return true
}
return false
}

View file

@ -34,7 +34,7 @@ type OPStorage interface {
SetIntrospectionFromToken(ctx context.Context, userinfo oidc.IntrospectionResponse, tokenID, subject, clientID string) error SetIntrospectionFromToken(ctx context.Context, userinfo oidc.IntrospectionResponse, tokenID, subject, clientID string) error
GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error)
GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) GetKeyByIDAndUserID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
ValidateJWTProfileScopes(ctx context.Context, userID string, scope oidc.Scopes) (oidc.Scopes, error) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error)
} }
type Storage interface { type Storage interface {

View file

@ -17,7 +17,7 @@ type RefreshTokenRequest interface {
GetClientID() string GetClientID() string
GetScopes() []string GetScopes() []string
GetSubject() string GetSubject() string
SetCurrentScopes(scopes oidc.Scopes) SetCurrentScopes(scopes []string)
} }
//RefreshTokenExchange handles the OAuth 2.0 refresh_token grant, including //RefreshTokenExchange handles the OAuth 2.0 refresh_token grant, including
@ -72,7 +72,7 @@ func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshToke
//ValidateRefreshTokenScopes validates that the requested scope is a subset of the original auth request scope //ValidateRefreshTokenScopes validates that the requested scope is a subset of the original auth request scope
//it will set the requested scopes as current scopes onto RefreshTokenRequest //it will set the requested scopes as current scopes onto RefreshTokenRequest
//if empty the original scopes will be used //if empty the original scopes will be used
func ValidateRefreshTokenScopes(requestedScopes oidc.Scopes, authRequest RefreshTokenRequest) error { func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTokenRequest) error {
if len(requestedScopes) == 0 { if len(requestedScopes) == 0 {
return nil return nil
} }