From dca7835b7c5142c79b3cea34ce81a41f603ef814 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan-Otto=20Kr=C3=B6pke?= Date: Fri, 5 Jan 2024 19:29:04 +0100 Subject: [PATCH] Use optional interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jan-Otto Kröpke --- pkg/client/rp/relying_party.go | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index eca4167..5e6949f 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -71,13 +71,15 @@ type RelyingParty interface { // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) - // UnauthorizedHandler returns the handler used for unauthorized errors - UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) - // Logger from the context, or a fallback if set. Logger(context.Context) (logger *slog.Logger, ok bool) } +type HasUnauthorizedHandler interface { + // UnauthorizedHandler returns the handler used for unauthorized errors + UnauthorizedHandler() func(w http.ResponseWriter, r *http.Request, desc string, state string) +} + type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string) @@ -376,13 +378,13 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam state := stateFn() if err := trySetStateCookie(w, state, rp); err != nil { - rp.UnauthorizedHandler()(w, r, "failed to create state cookie: "+err.Error(), state) + unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp) return } if rp.IsPKCE() { codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp) if err != nil { - rp.UnauthorizedHandler()(w, r, "failed to create code challenge: "+err.Error(), state) + unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp) return } opts = append(opts, WithCodeChallenge(codeChallenge)) @@ -469,7 +471,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R return func(w http.ResponseWriter, r *http.Request) { state, err := tryReadStateCookie(w, r, rp) if err != nil { - rp.UnauthorizedHandler()(w, r, "failed to get state: "+err.Error(), state) + unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp) return } params := r.URL.Query() @@ -485,7 +487,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R if rp.IsPKCE() { codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode) if err != nil { - rp.UnauthorizedHandler()(w, r, "failed to get code verifier: "+err.Error(), state) + unauthorizedError(w, r, "failed to get code verifier: "+err.Error(), state, rp) return } codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier)) @@ -494,14 +496,14 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R if rp.Signer() != nil { assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer()) if err != nil { - rp.UnauthorizedHandler()(w, r, "failed to build assertion: "+err.Error(), state) + unauthorizedError(w, r, "failed to build assertion: "+err.Error(), state, rp) return } codeOpts = append(codeOpts, WithClientAssertionJWT(assertion)) } tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...) if err != nil { - rp.UnauthorizedHandler()(w, r, "failed to exchange token: "+err.Error(), state) + unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp) return } callback(w, r, tokens, state, rp) @@ -521,7 +523,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { - rp.UnauthorizedHandler()(w, r, "userinfo failed: "+err.Error(), state) + unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp) return } f(w, r, tokens, state, rp, info) @@ -748,3 +750,11 @@ func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHi } return fmt.Errorf("RelyingParty does not support RevokeCaller") } + +func unauthorizedError(w http.ResponseWriter, r *http.Request, desc string, state string, rp RelyingParty) { + if rp, ok := rp.(HasUnauthorizedHandler); ok { + rp.UnauthorizedHandler()(w, r, desc, state) + return + } + http.Error(w, desc, http.StatusUnauthorized) +}