feat: support for session_state (#712)

* add default signature algorithm

* implements session_state in auth_request.go

* add test

* Update pkg/op/auth_request.go

link to the standard

Co-authored-by: Tim Möhlmann <muhlemmer@gmail.com>

* add check_session_iframe

---------

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
Co-authored-by: Tim Möhlmann <muhlemmer@gmail.com>
This commit is contained in:
minami yoshihiko 2025-02-24 19:50:38 +09:00 committed by GitHub
parent eb98343a65
commit 4ef9529012
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 97 additions and 7 deletions

View file

@ -164,6 +164,15 @@ func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthReques
} }
} }
type AuthRequestWithSessionState struct {
*AuthRequest
SessionState string
}
func (a *AuthRequestWithSessionState) GetSessionState() string {
return a.SessionState
}
type OIDCCodeChallenge struct { type OIDCCodeChallenge struct {
Challenge string Challenge string
Method string Method string

View file

@ -133,6 +133,7 @@ type Error struct {
ErrorType errorType `json:"error" schema:"error"` ErrorType errorType `json:"error" schema:"error"`
Description string `json:"error_description,omitempty" schema:"error_description,omitempty"` Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"` State string `json:"state,omitempty" schema:"state,omitempty"`
SessionState string `json:"session_state,omitempty" schema:"session_state,omitempty"`
redirectDisabled bool `schema:"-"` redirectDisabled bool `schema:"-"`
returnParent bool `schema:"-"` returnParent bool `schema:"-"`
} }
@ -142,11 +143,13 @@ func (e *Error) MarshalJSON() ([]byte, error) {
Error errorType `json:"error"` Error errorType `json:"error"`
ErrorDescription string `json:"error_description,omitempty"` ErrorDescription string `json:"error_description,omitempty"`
State string `json:"state,omitempty"` State string `json:"state,omitempty"`
SessionState string `json:"session_state,omitempty"`
Parent string `json:"parent,omitempty"` Parent string `json:"parent,omitempty"`
}{ }{
Error: e.ErrorType, Error: e.ErrorType,
ErrorDescription: e.Description, ErrorDescription: e.Description,
State: e.State, State: e.State,
SessionState: e.SessionState,
} }
if e.returnParent { if e.returnParent {
m.Parent = e.Parent.Error() m.Parent = e.Parent.Error()
@ -176,7 +179,8 @@ func (e *Error) Is(target error) bool {
} }
return e.ErrorType == t.ErrorType && return e.ErrorType == t.ErrorType &&
(e.Description == t.Description || t.Description == "") && (e.Description == t.Description || t.Description == "") &&
(e.State == t.State || t.State == "") (e.State == t.State || t.State == "") &&
(e.SessionState == t.SessionState || t.SessionState == "")
} }
func (e *Error) WithParent(err error) *Error { func (e *Error) WithParent(err error) *Error {
@ -242,6 +246,9 @@ func (e *Error) LogValue() slog.Value {
if e.State != "" { if e.State != "" {
attrs = append(attrs, slog.String("state", e.State)) attrs = append(attrs, slog.String("state", e.State))
} }
if e.SessionState != "" {
attrs = append(attrs, slog.String("session_state", e.SessionState))
}
if e.redirectDisabled { if e.redirectDisabled {
attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled)) attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled))
} }

View file

@ -38,6 +38,13 @@ type AuthRequest interface {
Done() bool Done() bool
} }
// AuthRequestSessionState should be implemented if [OpenID Connect Session Management](https://openid.net/specs/openid-connect-session-1_0.html) is supported
type AuthRequestSessionState interface {
// GetSessionState returns session_state.
// session_state is related to OpenID Connect Session Management.
GetSessionState() string
}
type Authorizer interface { type Authorizer interface {
Storage() Storage Storage() Storage
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
@ -103,8 +110,8 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
} }
return ValidateAuthRequestClient(ctx, authReq, client, verifier) return ValidateAuthRequestClient(ctx, authReq, client, verifier)
} }
if validater, ok := authorizer.(AuthorizeValidator); ok { if validator, ok := authorizer.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest validation = validator.ValidateAuthRequest
} }
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx)) userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
if err != nil { if err != nil {
@ -481,12 +488,19 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
AuthRequestError(w, r, authReq, err, authorizer) AuthRequestError(w, r, authReq, err, authorizer)
return return
} }
var sessionState string
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
if ok {
sessionState = authRequestSessionState.GetSessionState()
}
codeResponse := struct { codeResponse := struct {
Code string `schema:"code"` Code string `schema:"code"`
State string `schema:"state,omitempty"` State string `schema:"state,omitempty"`
SessionState string `schema:"session_state,omitempty"`
}{ }{
Code: code, Code: code,
State: authReq.GetState(), State: authReq.GetState(),
SessionState: sessionState,
} }
if authReq.GetResponseMode() == oidc.ResponseModeFormPost { if authReq.GetResponseMode() == oidc.ResponseModeFormPost {

View file

@ -1090,6 +1090,34 @@ func TestAuthResponseCode(t *testing.T) {
wantBody: "", wantBody: "",
}, },
}, },
{
name: "success with state and session_state",
args: args{
authReq: &storage.AuthRequestWithSessionState{
AuthRequest: &storage.AuthRequest{
ID: "id1",
TransferState: "state1",
},
SessionState: "session_state1",
},
authorizer: func(t *testing.T) op.Authorizer {
ctrl := gomock.NewController(t)
storage := mock.NewMockStorage(ctrl)
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
authorizer := mock.NewMockAuthorizer(ctrl)
authorizer.EXPECT().Storage().Return(storage)
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
return authorizer
},
},
res: res{
wantCode: http.StatusFound,
wantLocationHeader: "/auth/callback/?code=id1&session_state=session_state1&state=state1",
wantBody: "",
},
},
{ {
name: "success without state", // reproduce issue #415 name: "success without state", // reproduce issue #415
args: args{ args: args{

View file

@ -30,6 +30,7 @@ type Configuration interface {
EndSessionEndpoint() *Endpoint EndSessionEndpoint() *Endpoint
KeysEndpoint() *Endpoint KeysEndpoint() *Endpoint
DeviceAuthorizationEndpoint() *Endpoint DeviceAuthorizationEndpoint() *Endpoint
CheckSessionIframe() *Endpoint
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
CodeMethodS256Supported() bool CodeMethodS256Supported() bool

View file

@ -45,6 +45,7 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer), EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
JwksURI: config.KeysEndpoint().Absolute(issuer), JwksURI: config.KeysEndpoint().Absolute(issuer),
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer), DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
CheckSessionIframe: config.CheckSessionIframe().Absolute(issuer),
ScopesSupported: Scopes(config), ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config), ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config), GrantTypesSupported: GrantTypes(config),

View file

@ -46,6 +46,12 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
return return
} }
e.State = authReq.GetState() e.State = authReq.GetState()
var sessionState string
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
if ok {
sessionState = authRequestSessionState.GetSessionState()
}
e.SessionState = sessionState
var responseMode oidc.ResponseMode var responseMode oidc.ResponseMode
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode() responseMode = rm.GetResponseMode()
@ -92,6 +98,12 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error,
} }
e.State = authReq.GetState() e.State = authReq.GetState()
var sessionState string
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
if ok {
sessionState = authRequestSessionState.GetSessionState()
}
e.SessionState = sessionState
var responseMode oidc.ResponseMode var responseMode oidc.ResponseMode
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode() responseMode = rm.GetResponseMode()

View file

@ -106,6 +106,20 @@ func (mr *MockConfigurationMockRecorder) BackChannelLogoutSupported() *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSupported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSupported))
} }
// CheckSessionIframe mocks base method.
func (m *MockConfiguration) CheckSessionIframe() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckSessionIframe")
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
// CheckSessionIframe indicates an expected call of CheckSessionIframe.
func (mr *MockConfigurationMockRecorder) CheckSessionIframe() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckSessionIframe", reflect.TypeOf((*MockConfiguration)(nil).CheckSessionIframe))
}
// CodeMethodS256Supported mocks base method. // CodeMethodS256Supported mocks base method.
func (m *MockConfiguration) CodeMethodS256Supported() bool { func (m *MockConfiguration) CodeMethodS256Supported() bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -339,6 +339,10 @@ func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
return o.endpoints.DeviceAuthorization return o.endpoints.DeviceAuthorization
} }
func (o *Provider) CheckSessionIframe() *Endpoint {
return o.endpoints.CheckSessionIframe
}
func (o *Provider) KeysEndpoint() *Endpoint { func (o *Provider) KeysEndpoint() *Endpoint {
return o.endpoints.JwksURI return o.endpoints.JwksURI
} }