feat: support for session_state (#712)
* add default signature algorithm * implements session_state in auth_request.go * add test * Update pkg/op/auth_request.go link to the standard Co-authored-by: Tim Möhlmann <muhlemmer@gmail.com> * add check_session_iframe --------- Co-authored-by: Tim Möhlmann <tim+github@zitadel.com> Co-authored-by: Tim Möhlmann <muhlemmer@gmail.com>
This commit is contained in:
parent
eb98343a65
commit
4ef9529012
9 changed files with 97 additions and 7 deletions
|
@ -164,6 +164,15 @@ func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthReques
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthRequestWithSessionState struct {
|
||||||
|
*AuthRequest
|
||||||
|
SessionState string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequestWithSessionState) GetSessionState() string {
|
||||||
|
return a.SessionState
|
||||||
|
}
|
||||||
|
|
||||||
type OIDCCodeChallenge struct {
|
type OIDCCodeChallenge struct {
|
||||||
Challenge string
|
Challenge string
|
||||||
Method string
|
Method string
|
||||||
|
|
|
@ -133,6 +133,7 @@ type Error struct {
|
||||||
ErrorType errorType `json:"error" schema:"error"`
|
ErrorType errorType `json:"error" schema:"error"`
|
||||||
Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
|
Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
|
||||||
State string `json:"state,omitempty" schema:"state,omitempty"`
|
State string `json:"state,omitempty" schema:"state,omitempty"`
|
||||||
|
SessionState string `json:"session_state,omitempty" schema:"session_state,omitempty"`
|
||||||
redirectDisabled bool `schema:"-"`
|
redirectDisabled bool `schema:"-"`
|
||||||
returnParent bool `schema:"-"`
|
returnParent bool `schema:"-"`
|
||||||
}
|
}
|
||||||
|
@ -142,11 +143,13 @@ func (e *Error) MarshalJSON() ([]byte, error) {
|
||||||
Error errorType `json:"error"`
|
Error errorType `json:"error"`
|
||||||
ErrorDescription string `json:"error_description,omitempty"`
|
ErrorDescription string `json:"error_description,omitempty"`
|
||||||
State string `json:"state,omitempty"`
|
State string `json:"state,omitempty"`
|
||||||
|
SessionState string `json:"session_state,omitempty"`
|
||||||
Parent string `json:"parent,omitempty"`
|
Parent string `json:"parent,omitempty"`
|
||||||
}{
|
}{
|
||||||
Error: e.ErrorType,
|
Error: e.ErrorType,
|
||||||
ErrorDescription: e.Description,
|
ErrorDescription: e.Description,
|
||||||
State: e.State,
|
State: e.State,
|
||||||
|
SessionState: e.SessionState,
|
||||||
}
|
}
|
||||||
if e.returnParent {
|
if e.returnParent {
|
||||||
m.Parent = e.Parent.Error()
|
m.Parent = e.Parent.Error()
|
||||||
|
@ -176,7 +179,8 @@ func (e *Error) Is(target error) bool {
|
||||||
}
|
}
|
||||||
return e.ErrorType == t.ErrorType &&
|
return e.ErrorType == t.ErrorType &&
|
||||||
(e.Description == t.Description || t.Description == "") &&
|
(e.Description == t.Description || t.Description == "") &&
|
||||||
(e.State == t.State || t.State == "")
|
(e.State == t.State || t.State == "") &&
|
||||||
|
(e.SessionState == t.SessionState || t.SessionState == "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Error) WithParent(err error) *Error {
|
func (e *Error) WithParent(err error) *Error {
|
||||||
|
@ -242,6 +246,9 @@ func (e *Error) LogValue() slog.Value {
|
||||||
if e.State != "" {
|
if e.State != "" {
|
||||||
attrs = append(attrs, slog.String("state", e.State))
|
attrs = append(attrs, slog.String("state", e.State))
|
||||||
}
|
}
|
||||||
|
if e.SessionState != "" {
|
||||||
|
attrs = append(attrs, slog.String("session_state", e.SessionState))
|
||||||
|
}
|
||||||
if e.redirectDisabled {
|
if e.redirectDisabled {
|
||||||
attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled))
|
attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled))
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,6 +38,13 @@ type AuthRequest interface {
|
||||||
Done() bool
|
Done() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthRequestSessionState should be implemented if [OpenID Connect Session Management](https://openid.net/specs/openid-connect-session-1_0.html) is supported
|
||||||
|
type AuthRequestSessionState interface {
|
||||||
|
// GetSessionState returns session_state.
|
||||||
|
// session_state is related to OpenID Connect Session Management.
|
||||||
|
GetSessionState() string
|
||||||
|
}
|
||||||
|
|
||||||
type Authorizer interface {
|
type Authorizer interface {
|
||||||
Storage() Storage
|
Storage() Storage
|
||||||
Decoder() httphelper.Decoder
|
Decoder() httphelper.Decoder
|
||||||
|
@ -103,8 +110,8 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||||
}
|
}
|
||||||
return ValidateAuthRequestClient(ctx, authReq, client, verifier)
|
return ValidateAuthRequestClient(ctx, authReq, client, verifier)
|
||||||
}
|
}
|
||||||
if validater, ok := authorizer.(AuthorizeValidator); ok {
|
if validator, ok := authorizer.(AuthorizeValidator); ok {
|
||||||
validation = validater.ValidateAuthRequest
|
validation = validator.ValidateAuthRequest
|
||||||
}
|
}
|
||||||
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
|
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -481,12 +488,19 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
|
||||||
AuthRequestError(w, r, authReq, err, authorizer)
|
AuthRequestError(w, r, authReq, err, authorizer)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var sessionState string
|
||||||
|
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
|
||||||
|
if ok {
|
||||||
|
sessionState = authRequestSessionState.GetSessionState()
|
||||||
|
}
|
||||||
codeResponse := struct {
|
codeResponse := struct {
|
||||||
Code string `schema:"code"`
|
Code string `schema:"code"`
|
||||||
State string `schema:"state,omitempty"`
|
State string `schema:"state,omitempty"`
|
||||||
|
SessionState string `schema:"session_state,omitempty"`
|
||||||
}{
|
}{
|
||||||
Code: code,
|
Code: code,
|
||||||
State: authReq.GetState(),
|
State: authReq.GetState(),
|
||||||
|
SessionState: sessionState,
|
||||||
}
|
}
|
||||||
|
|
||||||
if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
|
if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
|
||||||
|
|
|
@ -1090,6 +1090,34 @@ func TestAuthResponseCode(t *testing.T) {
|
||||||
wantBody: "",
|
wantBody: "",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "success with state and session_state",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequestWithSessionState{
|
||||||
|
AuthRequest: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
TransferState: "state1",
|
||||||
|
},
|
||||||
|
SessionState: "session_state1",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantCode: http.StatusFound,
|
||||||
|
wantLocationHeader: "/auth/callback/?code=id1&session_state=session_state1&state=state1",
|
||||||
|
wantBody: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "success without state", // reproduce issue #415
|
name: "success without state", // reproduce issue #415
|
||||||
args: args{
|
args: args{
|
||||||
|
|
|
@ -30,6 +30,7 @@ type Configuration interface {
|
||||||
EndSessionEndpoint() *Endpoint
|
EndSessionEndpoint() *Endpoint
|
||||||
KeysEndpoint() *Endpoint
|
KeysEndpoint() *Endpoint
|
||||||
DeviceAuthorizationEndpoint() *Endpoint
|
DeviceAuthorizationEndpoint() *Endpoint
|
||||||
|
CheckSessionIframe() *Endpoint
|
||||||
|
|
||||||
AuthMethodPostSupported() bool
|
AuthMethodPostSupported() bool
|
||||||
CodeMethodS256Supported() bool
|
CodeMethodS256Supported() bool
|
||||||
|
|
|
@ -45,6 +45,7 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di
|
||||||
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
|
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
|
||||||
JwksURI: config.KeysEndpoint().Absolute(issuer),
|
JwksURI: config.KeysEndpoint().Absolute(issuer),
|
||||||
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
|
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
|
||||||
|
CheckSessionIframe: config.CheckSessionIframe().Absolute(issuer),
|
||||||
ScopesSupported: Scopes(config),
|
ScopesSupported: Scopes(config),
|
||||||
ResponseTypesSupported: ResponseTypes(config),
|
ResponseTypesSupported: ResponseTypes(config),
|
||||||
GrantTypesSupported: GrantTypes(config),
|
GrantTypesSupported: GrantTypes(config),
|
||||||
|
|
|
@ -46,6 +46,12 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
e.State = authReq.GetState()
|
e.State = authReq.GetState()
|
||||||
|
var sessionState string
|
||||||
|
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
|
||||||
|
if ok {
|
||||||
|
sessionState = authRequestSessionState.GetSessionState()
|
||||||
|
}
|
||||||
|
e.SessionState = sessionState
|
||||||
var responseMode oidc.ResponseMode
|
var responseMode oidc.ResponseMode
|
||||||
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
|
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
|
||||||
responseMode = rm.GetResponseMode()
|
responseMode = rm.GetResponseMode()
|
||||||
|
@ -92,6 +98,12 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error,
|
||||||
}
|
}
|
||||||
|
|
||||||
e.State = authReq.GetState()
|
e.State = authReq.GetState()
|
||||||
|
var sessionState string
|
||||||
|
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
|
||||||
|
if ok {
|
||||||
|
sessionState = authRequestSessionState.GetSessionState()
|
||||||
|
}
|
||||||
|
e.SessionState = sessionState
|
||||||
var responseMode oidc.ResponseMode
|
var responseMode oidc.ResponseMode
|
||||||
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
|
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
|
||||||
responseMode = rm.GetResponseMode()
|
responseMode = rm.GetResponseMode()
|
||||||
|
|
|
@ -106,6 +106,20 @@ func (mr *MockConfigurationMockRecorder) BackChannelLogoutSupported() *gomock.Ca
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSupported))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSupported))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckSessionIframe mocks base method.
|
||||||
|
func (m *MockConfiguration) CheckSessionIframe() *op.Endpoint {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CheckSessionIframe")
|
||||||
|
ret0, _ := ret[0].(*op.Endpoint)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckSessionIframe indicates an expected call of CheckSessionIframe.
|
||||||
|
func (mr *MockConfigurationMockRecorder) CheckSessionIframe() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckSessionIframe", reflect.TypeOf((*MockConfiguration)(nil).CheckSessionIframe))
|
||||||
|
}
|
||||||
|
|
||||||
// CodeMethodS256Supported mocks base method.
|
// CodeMethodS256Supported mocks base method.
|
||||||
func (m *MockConfiguration) CodeMethodS256Supported() bool {
|
func (m *MockConfiguration) CodeMethodS256Supported() bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -339,6 +339,10 @@ func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
|
||||||
return o.endpoints.DeviceAuthorization
|
return o.endpoints.DeviceAuthorization
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Provider) CheckSessionIframe() *Endpoint {
|
||||||
|
return o.endpoints.CheckSessionIframe
|
||||||
|
}
|
||||||
|
|
||||||
func (o *Provider) KeysEndpoint() *Endpoint {
|
func (o *Provider) KeysEndpoint() *Endpoint {
|
||||||
return o.endpoints.JwksURI
|
return o.endpoints.JwksURI
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue