From 32ff570f7c3cf6e34a6bad7894b8fae6b2a91ef9 Mon Sep 17 00:00:00 2001 From: "minami.yoshihiko" Date: Sat, 15 Feb 2025 17:58:14 +0900 Subject: [PATCH] implements session_state in auth_request.go --- pkg/oidc/error.go | 9 ++++++++- pkg/op/auth_request.go | 26 ++++++++++++++++++++------ pkg/op/error.go | 12 ++++++++++++ 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 1100f73..d93cf44 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -133,6 +133,7 @@ type Error struct { ErrorType errorType `json:"error" schema:"error"` Description string `json:"error_description,omitempty" schema:"error_description,omitempty"` State string `json:"state,omitempty" schema:"state,omitempty"` + SessionState string `json:"session_state,omitempty" schema:"session_state,omitempty"` redirectDisabled bool `schema:"-"` returnParent bool `schema:"-"` } @@ -142,11 +143,13 @@ func (e *Error) MarshalJSON() ([]byte, error) { Error errorType `json:"error"` ErrorDescription string `json:"error_description,omitempty"` State string `json:"state,omitempty"` + SessionState string `json:"session_state,omitempty"` Parent string `json:"parent,omitempty"` }{ Error: e.ErrorType, ErrorDescription: e.Description, State: e.State, + SessionState: e.SessionState, } if e.returnParent { m.Parent = e.Parent.Error() @@ -176,7 +179,8 @@ func (e *Error) Is(target error) bool { } return e.ErrorType == t.ErrorType && (e.Description == t.Description || t.Description == "") && - (e.State == t.State || t.State == "") + (e.State == t.State || t.State == "") && + (e.SessionState == t.SessionState || t.SessionState == "") } func (e *Error) WithParent(err error) *Error { @@ -242,6 +246,9 @@ func (e *Error) LogValue() slog.Value { if e.State != "" { attrs = append(attrs, slog.String("state", e.State)) } + if e.SessionState != "" { + attrs = append(attrs, slog.String("session_state", e.SessionState)) + } if e.redirectDisabled { attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled)) } diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index d6db62b..e77502b 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -38,6 +38,13 @@ type AuthRequest interface { Done() bool } +// AuthRequestSessionState should be implemented if OpenID Connect Session Management is supported +type AuthRequestSessionState interface { + // GetSessionState returns session_state. + // session_state is related to OpenID Connect Session Management. + GetSessionState() string +} + type Authorizer interface { Storage() Storage Decoder() httphelper.Decoder @@ -103,8 +110,8 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { } return ValidateAuthRequestClient(ctx, authReq, client, verifier) } - if validater, ok := authorizer.(AuthorizeValidator); ok { - validation = validater.ValidateAuthRequest + if validator, ok := authorizer.(AuthorizeValidator); ok { + validation = validator.ValidateAuthRequest } userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx)) if err != nil { @@ -481,12 +488,19 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques AuthRequestError(w, r, authReq, err, authorizer) return } + var sessionState string + authRequestSessionState, ok := authReq.(AuthRequestSessionState) + if ok { + sessionState = authRequestSessionState.GetSessionState() + } codeResponse := struct { - Code string `schema:"code"` - State string `schema:"state,omitempty"` + Code string `schema:"code"` + State string `schema:"state,omitempty"` + SessionState string `schema:"session_state,omitempty"` }{ - Code: code, - State: authReq.GetState(), + Code: code, + State: authReq.GetState(), + SessionState: sessionState, } if authReq.GetResponseMode() == oidc.ResponseModeFormPost { diff --git a/pkg/op/error.go b/pkg/op/error.go index 44b1798..d57da83 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -46,6 +46,12 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq return } e.State = authReq.GetState() + var sessionState string + authRequestSessionState, ok := authReq.(AuthRequestSessionState) + if ok { + sessionState = authRequestSessionState.GetSessionState() + } + e.SessionState = sessionState var responseMode oidc.ResponseMode if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { responseMode = rm.GetResponseMode() @@ -92,6 +98,12 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, } e.State = authReq.GetState() + var sessionState string + authRequestSessionState, ok := authReq.(AuthRequestSessionState) + if ok { + sessionState = authRequestSessionState.GetSessionState() + } + e.SessionState = sessionState var responseMode oidc.ResponseMode if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { responseMode = rm.GetResponseMode()