RP: Add UnauthorizedHandler

Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de>
This commit is contained in:
Jan-Otto Kröpke 2023-12-19 11:18:41 +01:00
parent 2b35eeb835
commit 010f41eefa
No known key found for this signature in database

View file

@ -9,7 +9,7 @@ import (
"net/url" "net/url"
"time" "time"
jose "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"golang.org/x/exp/slog" "golang.org/x/exp/slog"
@ -67,19 +67,26 @@ type RelyingParty interface {
// IDTokenVerifier returns the verifier used for oidc id_token verification // IDTokenVerifier returns the verifier used for oidc id_token verification
IDTokenVerifier() *IDTokenVerifier IDTokenVerifier() *IDTokenVerifier
// ErrorHandler returns the handler used for callback errors
// ErrorHandler returns the handler used for callback errors
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
// UnauthorizedHandler returns the handler used for unauthorized errors
UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string)
// Logger from the context, or a fallback if set. // Logger from the context, or a fallback if set.
Logger(context.Context) (logger *slog.Logger, ok bool) Logger(context.Context) (logger *slog.Logger, ok bool)
} }
type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string)
var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) { var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError) http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
} }
var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter, r *http.Request, desc string, state string) {
http.Error(w, desc, http.StatusUnauthorized)
}
type relyingParty struct { type relyingParty struct {
issuer string issuer string
@ -93,6 +100,7 @@ type relyingParty struct {
cookieHandler *httphelper.CookieHandler cookieHandler *httphelper.CookieHandler
errorHandler func(http.ResponseWriter, *http.Request, string, string, string) errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
unauthorizedHandler func(http.ResponseWriter, *http.Request, string, string)
idTokenVerifier *IDTokenVerifier idTokenVerifier *IDTokenVerifier
verifierOpts []VerifierOption verifierOpts []VerifierOption
signer jose.Signer signer jose.Signer
@ -157,6 +165,13 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
return rp.errorHandler return rp.errorHandler
} }
func (rp *relyingParty) UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) {
if rp.unauthorizedHandler == nil {
rp.unauthorizedHandler = DefaultUnauthorizedHandler
}
return rp.unauthorizedHandler
}
func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) { func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
logger, ok = logging.FromContext(ctx) logger, ok = logging.FromContext(ctx)
if ok { if ok {
@ -269,6 +284,13 @@ func WithErrorHandler(errorHandler ErrorHandler) Option {
} }
} }
func WithUnauthorizedHandler(unauthorizedHandler UnauthorizedHandler) Option {
return func(rp *relyingParty) error {
rp.unauthorizedHandler = unauthorizedHandler
return nil
}
}
func WithVerifierOpts(opts ...VerifierOption) Option { func WithVerifierOpts(opts ...VerifierOption) Option {
return func(rp *relyingParty) error { return func(rp *relyingParty) error {
rp.verifierOpts = opts rp.verifierOpts = opts
@ -356,13 +378,13 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam
state := stateFn() state := stateFn()
if err := trySetStateCookie(w, state, rp); err != nil { if err := trySetStateCookie(w, state, rp); err != nil {
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized) rp.UnauthorizedHandler()(w, r, "failed to create state cookie: "+err.Error(), state)
return return
} }
if rp.IsPKCE() { if rp.IsPKCE() {
codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp) codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp)
if err != nil { if err != nil {
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized) rp.UnauthorizedHandler()(w, r, "failed to create code challenge: "+err.Error(), state)
return return
} }
opts = append(opts, WithCodeChallenge(codeChallenge)) opts = append(opts, WithCodeChallenge(codeChallenge))
@ -449,7 +471,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
state, err := tryReadStateCookie(w, r, rp) state, err := tryReadStateCookie(w, r, rp)
if err != nil { if err != nil {
http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized) rp.UnauthorizedHandler()(w, r, "failed to get state: "+err.Error(), state)
return return
} }
params := r.URL.Query() params := r.URL.Query()
@ -465,7 +487,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
if rp.IsPKCE() { if rp.IsPKCE() {
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode) codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
if err != nil { if err != nil {
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized) rp.UnauthorizedHandler()(w, r, "failed to get code verifier: "+err.Error(), state)
return return
} }
codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier)) codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
@ -474,14 +496,14 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
if rp.Signer() != nil { if rp.Signer() != nil {
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer()) assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer())
if err != nil { if err != nil {
http.Error(w, "failed to build assertion: "+err.Error(), http.StatusUnauthorized) rp.UnauthorizedHandler()(w, r, "failed to build assertion: "+err.Error(), state)
return return
} }
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion)) codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
} }
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...) tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
if err != nil { if err != nil {
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) rp.UnauthorizedHandler()(w, r, "failed to exchange token: "+err.Error(), state)
return return
} }
callback(w, r, tokens, state, rp) callback(w, r, tokens, state, rp)
@ -501,7 +523,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
if err != nil { if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) rp.UnauthorizedHandler()(w, r, "userinfo failed: "+err.Error(), state)
return return
} }
f(w, r, tokens, state, rp, info) f(w, r, tokens, state, rp, info)