From 9721c25336c9214048880ce3ed9c89bdb2d13781 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Tue, 26 Oct 2021 13:18:39 +0200 Subject: [PATCH] begin revocation --- pkg/oidc/revocation.go | 6 ++ pkg/op/config.go | 1 + pkg/op/discovery.go | 1 + pkg/op/op.go | 19 +++++- pkg/op/storage.go | 3 +- pkg/op/token_revocation.go | 116 +++++++++++++++++++++++++++++++++++++ 6 files changed, 144 insertions(+), 2 deletions(-) create mode 100644 pkg/oidc/revocation.go create mode 100644 pkg/op/token_revocation.go diff --git a/pkg/oidc/revocation.go b/pkg/oidc/revocation.go new file mode 100644 index 0000000..0a56c61 --- /dev/null +++ b/pkg/oidc/revocation.go @@ -0,0 +1,6 @@ +package oidc + +type RevocationRequest struct { + Token string `schema:"token"` + TokenTypeHint string `schema:"token_type_hint"` +} diff --git a/pkg/op/config.go b/pkg/op/config.go index 1ec99c9..b3a1a2a 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -16,6 +16,7 @@ type Configuration interface { TokenEndpoint() Endpoint IntrospectionEndpoint() Endpoint UserinfoEndpoint() Endpoint + RevocationEndpoint() Endpoint EndSessionEndpoint() Endpoint KeysEndpoint() Endpoint diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index c625abc..f897735 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -24,6 +24,7 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()), IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()), UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()), + RevocationEndpoint: c.RevocationEndpoint().Absolute(c.Issuer()), EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()), JwksURI: c.KeysEndpoint().Absolute(c.Issuer()), ScopesSupported: Scopes(c), diff --git a/pkg/op/op.go b/pkg/op/op.go index c3508bd..c2ba032 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -74,6 +74,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router router.Handle(o.TokenEndpoint().Relative(), intercept(tokenHandler(o))) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) + router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o)) router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o))) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) return router @@ -95,6 +96,7 @@ type endpoints struct { Token Endpoint Introspection Endpoint Userinfo Endpoint + Revocation Endpoint EndSession Endpoint CheckSessionIframe Endpoint JwksURI Endpoint @@ -172,6 +174,10 @@ func (o *openidProvider) UserinfoEndpoint() Endpoint { return o.endpoints.Userinfo } +func (o *openidProvider) RevocationEndpoint() Endpoint { + return o.endpoints.Revocation +} + func (o *openidProvider) EndSessionEndpoint() Endpoint { return o.endpoints.EndSession } @@ -352,6 +358,16 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { } } +func WithCustomRevocationEndpoint(endpoint Endpoint) Option { + return func(o *openidProvider) error { + if err := endpoint.Validate(); err != nil { + return err + } + o.endpoints.Revocation = endpoint + return nil + } +} + func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { return func(o *openidProvider) error { if err := endpoint.Validate(); err != nil { @@ -372,11 +388,12 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndpoints(auth, token, userInfo, endSession, keys Endpoint) Option { +func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { return func(o *openidProvider) error { o.endpoints.Authorization = auth o.endpoints.Token = token o.endpoints.Userinfo = userInfo + o.endpoints.Revocation = revocation o.endpoints.EndSession = endSession o.endpoints.JwksURI = keys return nil diff --git a/pkg/op/storage.go b/pkg/op/storage.go index ca9ae7c..94c2a33 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -20,7 +20,8 @@ type AuthStorage interface { CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (RefreshTokenRequest, error) - TerminateSession(context.Context, string, string) error + TerminateSession(ctx context.Context, userID string, clientID string) error + RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error GetSigningKey(context.Context, chan<- jose.SigningKey) GetKeySet(context.Context) (*jose.JSONWebKeySet, error) diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go new file mode 100644 index 0000000..d754503 --- /dev/null +++ b/pkg/op/token_revocation.go @@ -0,0 +1,116 @@ +package op + +import ( + "context" + "errors" + "net/http" + "net/url" + "strings" + + httphelper "github.com/caos/oidc/pkg/http" + "github.com/caos/oidc/pkg/oidc" +) + +type Revoker interface { + Decoder() httphelper.Decoder + Crypto() Crypto + Storage() Storage + AccessTokenVerifier() AccessTokenVerifier +} + +type RevokerJWTProfile interface { + Revoker + JWTProfileVerifier() JWTProfileVerifier +} + +func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + Revoke(w, r, revoker) + } +} + +func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) { + token, _, clientID, err := ParseTokenRevocationRequest(r, revoker) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + tokenID, subject, ok := getTokenIDAndSubjectForRevocation(r.Context(), revoker, token) + if ok { + err := revoker.Storage().RevokeToken(r.Context(), tokenID, subject, clientID) + if err != nil { + RevocationRequestError(w, r, err) + return + } + httphelper.MarshalJSON(w, err) + return + } + //if tokenID != "" { + // token = tokenID + //} + err = revoker.Storage().RevokeToken(r.Context(), token, subject, clientID) + httphelper.MarshalJSON(w, nil) +} + +func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, tokenTypeHint, clientID string, err error) { + err = r.ParseForm() + if err != nil { + return "", "", "", errors.New("unable to parse request") + } + req := new(struct { + oidc.RevocationRequest + oidc.ClientAssertionParams + }) + err = revoker.Decoder().Decode(req, r.Form) + if err != nil { + return "", "", "", errors.New("unable to parse request") + } + if revokerJWTProfile, ok := revoker.(RevokerJWTProfile); ok && req.ClientAssertion != "" { + profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier()) + if err == nil { + return req.Token, req.TokenTypeHint, profile.Issuer, nil + } + return "", "", "", err + } + clientID, clientSecret, ok := r.BasicAuth() + if ok { + clientID, err = url.QueryUnescape(clientID) + if err != nil { + return "", "", "", errors.New("invalid basic auth header") + } + clientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return "", "", "", errors.New("invalid basic auth header") + } + if err := revoker.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil { + return "", "", "", err + } + return req.Token, req.TokenTypeHint, clientID, nil + } + return "", "", "", errors.New("invalid authorization") +} + +func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { + e := oidc.DefaultToServerError(err, err.Error()) + status := http.StatusBadRequest + if e.ErrorType == oidc.InvalidClient { + status = 401 + } + httphelper.MarshalJSONWithStatus(w, e, status) +} + +func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { + tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) + if err == nil { + splitToken := strings.Split(tokenIDSubject, ":") + if len(splitToken) != 2 { + return "", "", false + } + return splitToken[0], splitToken[1], true + } + accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier()) + if err != nil { + return "", "", false + } + return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true +}