feat: add CanTerminateSessionFromRequest interface

This commit is contained in:
Livio Spring 2023-07-17 16:32:12 +02:00
parent 4c844da05e
commit f4660b6b57
No known key found for this signature in database
GPG key ID: 26BB1C2FA5952CF0
2 changed files with 21 additions and 6 deletions

View file

@ -34,12 +34,17 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID) redirect := session.RedirectURI
if fromRequest, ok := ender.Storage().(CanTerminateSessionFromRequest); ok {
redirect, err = fromRequest.TerminateSessionFromRequest(r.Context(), session)
} else {
err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID)
}
if err != nil { if err != nil {
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session")) RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"))
return return
} }
http.Redirect(w, r, session.RedirectURI, http.StatusFound) http.Redirect(w, r, redirect, http.StatusFound)
} }
func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.EndSessionRequest, error) { func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.EndSessionRequest, error) {
@ -60,11 +65,12 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
RedirectURI: ender.DefaultLogoutRedirectURI(), RedirectURI: ender.DefaultLogoutRedirectURI(),
} }
if req.IdTokenHint != "" { if req.IdTokenHint != "" {
claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx)) claims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx))
if err != nil { if err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
} }
session.UserID = claims.GetSubject() session.UserID = claims.GetSubject()
session.IDTokenHintClaims = claims
if req.ClientID != "" && req.ClientID != claims.GetAuthorizedParty() { if req.ClientID != "" && req.ClientID != claims.GetAuthorizedParty() {
return nil, oidc.ErrInvalidRequest().WithDescription("client_id does not match azp of id_token_hint") return nil, oidc.ErrInvalidRequest().WithDescription("client_id does not match azp of id_token_hint")
} }

View file

@ -62,6 +62,14 @@ type AuthStorage interface {
KeySet(context.Context) ([]Key, error) KeySet(context.Context) ([]Key, error)
} }
// CanTerminateSessionFromRequest is an optional additional interface that may be implemented by
// implementors of Storage as an alternative to TerminateSession of the AuthStorage.
// It passes the complete parsed EndSessionRequest to the implementation, which allows access to additional data.
// It also allows to modify the uri, which will be used for redirection, (e.g. a UI where the user can consent to the logout)
type CanTerminateSessionFromRequest interface {
TerminateSessionFromRequest(ctx context.Context, endSessionRequest *EndSessionRequest) (string, error)
}
type ClientCredentialsStorage interface { type ClientCredentialsStorage interface {
ClientCredentials(ctx context.Context, clientID, clientSecret string) (Client, error) ClientCredentials(ctx context.Context, clientID, clientSecret string) (Client, error)
ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (TokenRequest, error) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (TokenRequest, error)
@ -152,9 +160,10 @@ type StorageNotFoundError interface {
} }
type EndSessionRequest struct { type EndSessionRequest struct {
UserID string UserID string
ClientID string ClientID string
RedirectURI string IDTokenHintClaims *oidc.IDTokenClaims
RedirectURI string
} }
var ErrDuplicateUserCode = errors.New("user code already exists") var ErrDuplicateUserCode = errors.New("user code already exists")