diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go index 22c0295..c04877f 100644 --- a/example/server/storage/oidc.go +++ b/example/server/storage/oidc.go @@ -164,6 +164,15 @@ func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthReques } } +type AuthRequestWithSessionState struct { + *AuthRequest + SessionState string +} + +func (a *AuthRequestWithSessionState) GetSessionState() string { + return a.SessionState +} + type OIDCCodeChallenge struct { Challenge string Method string diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 1100f73..d93cf44 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -133,6 +133,7 @@ type Error struct { ErrorType errorType `json:"error" schema:"error"` Description string `json:"error_description,omitempty" schema:"error_description,omitempty"` State string `json:"state,omitempty" schema:"state,omitempty"` + SessionState string `json:"session_state,omitempty" schema:"session_state,omitempty"` redirectDisabled bool `schema:"-"` returnParent bool `schema:"-"` } @@ -142,11 +143,13 @@ func (e *Error) MarshalJSON() ([]byte, error) { Error errorType `json:"error"` ErrorDescription string `json:"error_description,omitempty"` State string `json:"state,omitempty"` + SessionState string `json:"session_state,omitempty"` Parent string `json:"parent,omitempty"` }{ Error: e.ErrorType, ErrorDescription: e.Description, State: e.State, + SessionState: e.SessionState, } if e.returnParent { m.Parent = e.Parent.Error() @@ -176,7 +179,8 @@ func (e *Error) Is(target error) bool { } return e.ErrorType == t.ErrorType && (e.Description == t.Description || t.Description == "") && - (e.State == t.State || t.State == "") + (e.State == t.State || t.State == "") && + (e.SessionState == t.SessionState || t.SessionState == "") } func (e *Error) WithParent(err error) *Error { @@ -242,6 +246,9 @@ func (e *Error) LogValue() slog.Value { if e.State != "" { attrs = append(attrs, slog.String("state", e.State)) } + if e.SessionState != "" { + attrs = append(attrs, slog.String("session_state", e.SessionState)) + } if e.redirectDisabled { attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled)) } diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index d6db62b..82f1b58 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -38,6 +38,13 @@ type AuthRequest interface { Done() bool } +// AuthRequestSessionState should be implemented if [OpenID Connect Session Management](https://openid.net/specs/openid-connect-session-1_0.html) is supported +type AuthRequestSessionState interface { + // GetSessionState returns session_state. + // session_state is related to OpenID Connect Session Management. + GetSessionState() string +} + type Authorizer interface { Storage() Storage Decoder() httphelper.Decoder @@ -103,8 +110,8 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { } return ValidateAuthRequestClient(ctx, authReq, client, verifier) } - if validater, ok := authorizer.(AuthorizeValidator); ok { - validation = validater.ValidateAuthRequest + if validator, ok := authorizer.(AuthorizeValidator); ok { + validation = validator.ValidateAuthRequest } userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx)) if err != nil { @@ -481,12 +488,19 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques AuthRequestError(w, r, authReq, err, authorizer) return } + var sessionState string + authRequestSessionState, ok := authReq.(AuthRequestSessionState) + if ok { + sessionState = authRequestSessionState.GetSessionState() + } codeResponse := struct { - Code string `schema:"code"` - State string `schema:"state,omitempty"` + Code string `schema:"code"` + State string `schema:"state,omitempty"` + SessionState string `schema:"session_state,omitempty"` }{ - Code: code, - State: authReq.GetState(), + Code: code, + State: authReq.GetState(), + SessionState: sessionState, } if authReq.GetResponseMode() == oidc.ResponseModeFormPost { diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 765e602..4878f5e 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -1090,6 +1090,34 @@ func TestAuthResponseCode(t *testing.T) { wantBody: "", }, }, + { + name: "success with state and session_state", + args: args{ + authReq: &storage.AuthRequestWithSessionState{ + AuthRequest: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + SessionState: "session_state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantCode: http.StatusFound, + wantLocationHeader: "/auth/callback/?code=id1&session_state=session_state1&state=state1", + wantBody: "", + }, + }, { name: "success without state", // reproduce issue #415 args: args{ diff --git a/pkg/op/config.go b/pkg/op/config.go index 2fcede0..b271765 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -30,6 +30,7 @@ type Configuration interface { EndSessionEndpoint() *Endpoint KeysEndpoint() *Endpoint DeviceAuthorizationEndpoint() *Endpoint + CheckSessionIframe() *Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index e30a5a4..7aa7cf7 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -45,6 +45,7 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer), JwksURI: config.KeysEndpoint().Absolute(issuer), DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer), + CheckSessionIframe: config.CheckSessionIframe().Absolute(issuer), ScopesSupported: Scopes(config), ResponseTypesSupported: ResponseTypes(config), GrantTypesSupported: GrantTypes(config), diff --git a/pkg/op/error.go b/pkg/op/error.go index 44b1798..d57da83 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -46,6 +46,12 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq return } e.State = authReq.GetState() + var sessionState string + authRequestSessionState, ok := authReq.(AuthRequestSessionState) + if ok { + sessionState = authRequestSessionState.GetSessionState() + } + e.SessionState = sessionState var responseMode oidc.ResponseMode if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { responseMode = rm.GetResponseMode() @@ -92,6 +98,12 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, } e.State = authReq.GetState() + var sessionState string + authRequestSessionState, ok := authReq.(AuthRequestSessionState) + if ok { + sessionState = authRequestSessionState.GetSessionState() + } + e.SessionState = sessionState var responseMode oidc.ResponseMode if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { responseMode = rm.GetResponseMode() diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index 137c09d..0ef9d92 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -106,6 +106,20 @@ func (mr *MockConfigurationMockRecorder) BackChannelLogoutSupported() *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSupported)) } +// CheckSessionIframe mocks base method. +func (m *MockConfiguration) CheckSessionIframe() *op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckSessionIframe") + ret0, _ := ret[0].(*op.Endpoint) + return ret0 +} + +// CheckSessionIframe indicates an expected call of CheckSessionIframe. +func (mr *MockConfigurationMockRecorder) CheckSessionIframe() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckSessionIframe", reflect.TypeOf((*MockConfiguration)(nil).CheckSessionIframe)) +} + // CodeMethodS256Supported mocks base method. func (m *MockConfiguration) CodeMethodS256Supported() bool { m.ctrl.T.Helper() diff --git a/pkg/op/op.go b/pkg/op/op.go index 190c2c4..58ae838 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -339,6 +339,10 @@ func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint { return o.endpoints.DeviceAuthorization } +func (o *Provider) CheckSessionIframe() *Endpoint { + return o.endpoints.CheckSessionIframe +} + func (o *Provider) KeysEndpoint() *Endpoint { return o.endpoints.JwksURI }