feat(rp): Add UnauthorizedHandler (#503)
* RP: Add UnauthorizedHandler Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de> * remove race condition Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de> * Use optional interface Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de> --------- Signed-off-by: Jan-Otto Kröpke <mail@jkroepke.de>
This commit is contained in:
parent
5dcf6de055
commit
984e31a9e2
1 changed files with 46 additions and 16 deletions
|
@ -66,19 +66,28 @@ type RelyingParty interface {
|
||||||
|
|
||||||
// IDTokenVerifier returns the verifier used for oidc id_token verification
|
// IDTokenVerifier returns the verifier used for oidc id_token verification
|
||||||
IDTokenVerifier() *IDTokenVerifier
|
IDTokenVerifier() *IDTokenVerifier
|
||||||
// ErrorHandler returns the handler used for callback errors
|
|
||||||
|
|
||||||
|
// ErrorHandler returns the handler used for callback errors
|
||||||
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
|
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
|
||||||
|
|
||||||
// Logger from the context, or a fallback if set.
|
// Logger from the context, or a fallback if set.
|
||||||
Logger(context.Context) (logger *slog.Logger, ok bool)
|
Logger(context.Context) (logger *slog.Logger, ok bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HasUnauthorizedHandler interface {
|
||||||
|
// UnauthorizedHandler returns the handler used for unauthorized errors
|
||||||
|
UnauthorizedHandler() func(w http.ResponseWriter, r *http.Request, desc string, state string)
|
||||||
|
}
|
||||||
|
|
||||||
type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
|
type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
|
||||||
|
type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string)
|
||||||
|
|
||||||
var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
|
var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
|
||||||
http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
|
http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter, r *http.Request, desc string, state string) {
|
||||||
|
http.Error(w, desc, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
type relyingParty struct {
|
type relyingParty struct {
|
||||||
issuer string
|
issuer string
|
||||||
|
@ -92,6 +101,7 @@ type relyingParty struct {
|
||||||
cookieHandler *httphelper.CookieHandler
|
cookieHandler *httphelper.CookieHandler
|
||||||
|
|
||||||
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
|
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
|
||||||
|
unauthorizedHandler func(http.ResponseWriter, *http.Request, string, string)
|
||||||
idTokenVerifier *IDTokenVerifier
|
idTokenVerifier *IDTokenVerifier
|
||||||
verifierOpts []VerifierOption
|
verifierOpts []VerifierOption
|
||||||
signer jose.Signer
|
signer jose.Signer
|
||||||
|
@ -156,6 +166,10 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
|
||||||
return rp.errorHandler
|
return rp.errorHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rp *relyingParty) UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) {
|
||||||
|
return rp.unauthorizedHandler
|
||||||
|
}
|
||||||
|
|
||||||
func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
|
func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
|
||||||
logger, ok = logging.FromContext(ctx)
|
logger, ok = logging.FromContext(ctx)
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -172,6 +186,7 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart
|
||||||
oauthConfig: config,
|
oauthConfig: config,
|
||||||
httpClient: httphelper.DefaultHTTPClient,
|
httpClient: httphelper.DefaultHTTPClient,
|
||||||
oauth2Only: true,
|
oauth2Only: true,
|
||||||
|
unauthorizedHandler: DefaultUnauthorizedHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, optFunc := range options {
|
for _, optFunc := range options {
|
||||||
|
@ -268,6 +283,13 @@ func WithErrorHandler(errorHandler ErrorHandler) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithUnauthorizedHandler(unauthorizedHandler UnauthorizedHandler) Option {
|
||||||
|
return func(rp *relyingParty) error {
|
||||||
|
rp.unauthorizedHandler = unauthorizedHandler
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithVerifierOpts(opts ...VerifierOption) Option {
|
func WithVerifierOpts(opts ...VerifierOption) Option {
|
||||||
return func(rp *relyingParty) error {
|
return func(rp *relyingParty) error {
|
||||||
rp.verifierOpts = opts
|
rp.verifierOpts = opts
|
||||||
|
@ -355,13 +377,13 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam
|
||||||
|
|
||||||
state := stateFn()
|
state := stateFn()
|
||||||
if err := trySetStateCookie(w, state, rp); err != nil {
|
if err := trySetStateCookie(w, state, rp); err != nil {
|
||||||
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if rp.IsPKCE() {
|
if rp.IsPKCE() {
|
||||||
codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp)
|
codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
|
unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
opts = append(opts, WithCodeChallenge(codeChallenge))
|
opts = append(opts, WithCodeChallenge(codeChallenge))
|
||||||
|
@ -448,7 +470,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
state, err := tryReadStateCookie(w, r, rp)
|
state, err := tryReadStateCookie(w, r, rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized)
|
unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
params := r.URL.Query()
|
params := r.URL.Query()
|
||||||
|
@ -464,7 +486,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
|
||||||
if rp.IsPKCE() {
|
if rp.IsPKCE() {
|
||||||
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
|
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
|
unauthorizedError(w, r, "failed to get code verifier: "+err.Error(), state, rp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
|
codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
|
||||||
|
@ -473,14 +495,14 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
|
||||||
if rp.Signer() != nil {
|
if rp.Signer() != nil {
|
||||||
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer())
|
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "failed to build assertion: "+err.Error(), http.StatusUnauthorized)
|
unauthorizedError(w, r, "failed to build assertion: "+err.Error(), state, rp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
|
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
|
||||||
}
|
}
|
||||||
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
|
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
callback(w, r, tokens, state, rp)
|
callback(w, r, tokens, state, rp)
|
||||||
|
@ -500,7 +522,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa
|
||||||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
||||||
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
|
unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f(w, r, tokens, state, rp, info)
|
f(w, r, tokens, state, rp, info)
|
||||||
|
@ -727,3 +749,11 @@ func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHi
|
||||||
}
|
}
|
||||||
return ErrRelyingPartyNotSupportRevokeCaller
|
return ErrRelyingPartyNotSupportRevokeCaller
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unauthorizedError(w http.ResponseWriter, r *http.Request, desc string, state string, rp RelyingParty) {
|
||||||
|
if rp, ok := rp.(HasUnauthorizedHandler); ok {
|
||||||
|
rp.UnauthorizedHandler()(w, r, desc, state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Error(w, desc, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue