Use optional interface
Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de>
This commit is contained in:
parent
caab666767
commit
dca7835b7c
1 changed files with 20 additions and 10 deletions
|
@ -71,13 +71,15 @@ type RelyingParty interface {
|
|||
// ErrorHandler returns the handler used for callback errors
|
||||
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
|
||||
|
||||
// UnauthorizedHandler returns the handler used for unauthorized errors
|
||||
UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string)
|
||||
|
||||
// Logger from the context, or a fallback if set.
|
||||
Logger(context.Context) (logger *slog.Logger, ok bool)
|
||||
}
|
||||
|
||||
type HasUnauthorizedHandler interface {
|
||||
// UnauthorizedHandler returns the handler used for unauthorized errors
|
||||
UnauthorizedHandler() func(w http.ResponseWriter, r *http.Request, desc string, state string)
|
||||
}
|
||||
|
||||
type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
|
||||
type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string)
|
||||
|
||||
|
@ -376,13 +378,13 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam
|
|||
|
||||
state := stateFn()
|
||||
if err := trySetStateCookie(w, state, rp); err != nil {
|
||||
rp.UnauthorizedHandler()(w, r, "failed to create state cookie: "+err.Error(), state)
|
||||
unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
if rp.IsPKCE() {
|
||||
codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp)
|
||||
if err != nil {
|
||||
rp.UnauthorizedHandler()(w, r, "failed to create code challenge: "+err.Error(), state)
|
||||
unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
opts = append(opts, WithCodeChallenge(codeChallenge))
|
||||
|
@ -469,7 +471,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := tryReadStateCookie(w, r, rp)
|
||||
if err != nil {
|
||||
rp.UnauthorizedHandler()(w, r, "failed to get state: "+err.Error(), state)
|
||||
unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
params := r.URL.Query()
|
||||
|
@ -485,7 +487,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
|
|||
if rp.IsPKCE() {
|
||||
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
|
||||
if err != nil {
|
||||
rp.UnauthorizedHandler()(w, r, "failed to get code verifier: "+err.Error(), state)
|
||||
unauthorizedError(w, r, "failed to get code verifier: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
|
||||
|
@ -494,14 +496,14 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
|
|||
if rp.Signer() != nil {
|
||||
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer())
|
||||
if err != nil {
|
||||
rp.UnauthorizedHandler()(w, r, "failed to build assertion: "+err.Error(), state)
|
||||
unauthorizedError(w, r, "failed to build assertion: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
|
||||
}
|
||||
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
|
||||
if err != nil {
|
||||
rp.UnauthorizedHandler()(w, r, "failed to exchange token: "+err.Error(), state)
|
||||
unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
callback(w, r, tokens, state, rp)
|
||||
|
@ -521,7 +523,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa
|
|||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
||||
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||
if err != nil {
|
||||
rp.UnauthorizedHandler()(w, r, "userinfo failed: "+err.Error(), state)
|
||||
unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
f(w, r, tokens, state, rp, info)
|
||||
|
@ -748,3 +750,11 @@ func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHi
|
|||
}
|
||||
return fmt.Errorf("RelyingParty does not support RevokeCaller")
|
||||
}
|
||||
|
||||
func unauthorizedError(w http.ResponseWriter, r *http.Request, desc string, state string, rp RelyingParty) {
|
||||
if rp, ok := rp.(HasUnauthorizedHandler); ok {
|
||||
rp.UnauthorizedHandler()(w, r, desc, state)
|
||||
return
|
||||
}
|
||||
http.Error(w, desc, http.StatusUnauthorized)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue