diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 5899af0..00033d5 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -9,7 +9,7 @@ import ( "net/url" "time" - jose "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3" "github.com/google/uuid" "github.com/zitadel/logging" "golang.org/x/exp/slog" @@ -67,19 +67,26 @@ type RelyingParty interface { // IDTokenVerifier returns the verifier used for oidc id_token verification IDTokenVerifier() *IDTokenVerifier - // ErrorHandler returns the handler used for callback errors + // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) + // UnauthorizedHandler returns the handler used for unauthorized errors + UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) + // Logger from the context, or a fallback if set. Logger(context.Context) (logger *slog.Logger, ok bool) } type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) +type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string) var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) { http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError) } +var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter, r *http.Request, desc string, state string) { + http.Error(w, desc, http.StatusUnauthorized) +} type relyingParty struct { issuer string @@ -92,11 +99,12 @@ type relyingParty struct { httpClient *http.Client cookieHandler *httphelper.CookieHandler - errorHandler func(http.ResponseWriter, *http.Request, string, string, string) - idTokenVerifier *IDTokenVerifier - verifierOpts []VerifierOption - signer jose.Signer - logger *slog.Logger + errorHandler func(http.ResponseWriter, *http.Request, string, string, string) + unauthorizedHandler func(http.ResponseWriter, *http.Request, string, string) + idTokenVerifier *IDTokenVerifier + verifierOpts []VerifierOption + signer jose.Signer + logger *slog.Logger } func (rp *relyingParty) OAuthConfig() *oauth2.Config { @@ -157,6 +165,13 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, return rp.errorHandler } +func (rp *relyingParty) UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) { + if rp.unauthorizedHandler == nil { + rp.unauthorizedHandler = DefaultUnauthorizedHandler + } + return rp.unauthorizedHandler +} + func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) { logger, ok = logging.FromContext(ctx) if ok { @@ -269,6 +284,13 @@ func WithErrorHandler(errorHandler ErrorHandler) Option { } } +func WithUnauthorizedHandler(unauthorizedHandler UnauthorizedHandler) Option { + return func(rp *relyingParty) error { + rp.unauthorizedHandler = unauthorizedHandler + return nil + } +} + func WithVerifierOpts(opts ...VerifierOption) Option { return func(rp *relyingParty) error { rp.verifierOpts = opts @@ -356,13 +378,13 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam state := stateFn() if err := trySetStateCookie(w, state, rp); err != nil { - http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized) + rp.UnauthorizedHandler()(w, r, "failed to create state cookie: "+err.Error(), state) return } if rp.IsPKCE() { codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp) if err != nil { - http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized) + rp.UnauthorizedHandler()(w, r, "failed to create code challenge: "+err.Error(), state) return } opts = append(opts, WithCodeChallenge(codeChallenge)) @@ -449,7 +471,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R return func(w http.ResponseWriter, r *http.Request) { state, err := tryReadStateCookie(w, r, rp) if err != nil { - http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized) + rp.UnauthorizedHandler()(w, r, "failed to get state: "+err.Error(), state) return } params := r.URL.Query() @@ -465,7 +487,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R if rp.IsPKCE() { codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode) if err != nil { - http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized) + rp.UnauthorizedHandler()(w, r, "failed to get code verifier: "+err.Error(), state) return } codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier)) @@ -474,14 +496,14 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R if rp.Signer() != nil { assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer()) if err != nil { - http.Error(w, "failed to build assertion: "+err.Error(), http.StatusUnauthorized) + rp.UnauthorizedHandler()(w, r, "failed to build assertion: "+err.Error(), state) return } codeOpts = append(codeOpts, WithClientAssertionJWT(assertion)) } tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...) if err != nil { - http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) + rp.UnauthorizedHandler()(w, r, "failed to exchange token: "+err.Error(), state) return } callback(w, r, tokens, state, rp) @@ -501,7 +523,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { - http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) + rp.UnauthorizedHandler()(w, r, "userinfo failed: "+err.Error(), state) return } f(w, r, tokens, state, rp, info)