feat: add rp.RevokeToken (#231)

* feat: add rp.RevokeToken

* add missing lines after conflict resolving

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
David Sharnoff 2022-11-14 22:35:16 -08:00 committed by GitHub
parent 0847a5985a
commit 39852f6021
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 88 additions and 5 deletions

6
NEXT_RELEASE.md Normal file
View file

@ -0,0 +1,6 @@
# Backwards-incompatible changes to be made in the next major release
- Add `rp/RelyingParty.GetRevokeEndpoint`
- Rename `op/OpStorage.GetKeyByIDAndUserID` to `op/OpStorage.GetKeyByIDAndClientID`

View file

@ -255,11 +255,11 @@ func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID
// RevokeToken implements the op.Storage interface // RevokeToken implements the op.Storage interface
// it will be called after parsing and validation of the token revocation request // it will be called after parsing and validation of the token revocation request
func (s *Storage) RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error { func (s *Storage) RevokeToken(ctx context.Context, tokenIDOrToken string, userID string, clientID string) *oidc.Error {
// a single token was requested to be removed // a single token was requested to be removed
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
accessToken, ok := s.tokens[token] accessToken, ok := s.tokens[tokenIDOrToken] // tokenID
if ok { if ok {
if accessToken.ApplicationID != clientID { if accessToken.ApplicationID != clientID {
return oidc.ErrInvalidClient().WithDescription("token was not issued for this client") return oidc.ErrInvalidClient().WithDescription("token was not issued for this client")
@ -269,7 +269,7 @@ func (s *Storage) RevokeToken(ctx context.Context, token string, userID string,
delete(s.tokens, accessToken.ID) delete(s.tokens, accessToken.ID)
return nil return nil
} }
refreshToken, ok := s.refreshTokens[token] refreshToken, ok := s.refreshTokens[tokenIDOrToken] // token
if !ok { if !ok {
// if the token is neither an access nor a refresh token, just ignore it, the expected behaviour of // if the token is neither an access nor a refresh token, just ignore it, the expected behaviour of
// being not valid (anymore) is achieved // being not valid (anymore) is achieved

View file

@ -109,6 +109,47 @@ func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndS
return location, nil return location, nil
} }
type RevokeCaller interface {
GetRevokeEndpoint() string
HttpClient() *http.Client
}
type RevokeRequest struct {
Token string `schema:"token"`
TokenTypeHint string `schema:"token_type_hint"`
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"`
}
func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCaller) error {
req, err := httphelper.FormRequest(caller.GetRevokeEndpoint(), request, Encoder, authFn)
if err != nil {
return err
}
client := caller.HttpClient()
client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
// According to RFC7009 in section 2.2:
// "The content of the response body is ignored by the client as all
// necessary information is conveyed in the response code."
if resp.StatusCode != 200 {
// TODO: switch to io.ReadAll when go1.15 support is retired
body, err := ioutil.ReadAll(resp.Body)
if err == nil {
return fmt.Errorf("revoke returned status %d and text: %s", resp.StatusCode, string(body))
} else {
return fmt.Errorf("revoke returned status %d", resp.StatusCode)
}
}
return nil
}
func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) { func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) {
privateKey, err := crypto.BytesToPrivateKey(key) privateKey, err := crypto.BytesToPrivateKey(key)
if err != nil { if err != nil {

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -52,6 +53,9 @@ type RelyingParty interface {
// GetEndSessionEndpoint returns the endpoint to sign out on a IDP // GetEndSessionEndpoint returns the endpoint to sign out on a IDP
GetEndSessionEndpoint() string GetEndSessionEndpoint() string
// GetRevokeEndpoint returns the endpoint to revoke a specific token
// "GetRevokeEndpoint() string" will be added in a future release
// UserinfoEndpoint returns the userinfo // UserinfoEndpoint returns the userinfo
UserinfoEndpoint() string UserinfoEndpoint() string
@ -121,6 +125,10 @@ func (rp *relyingParty) GetEndSessionEndpoint() string {
return rp.endpoints.EndSessionURL return rp.endpoints.EndSessionURL
} }
func (rp *relyingParty) GetRevokeEndpoint() string {
return rp.endpoints.RevokeURL
}
func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier { func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier {
if rp.idTokenVerifier == nil { if rp.idTokenVerifier == nil {
rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...) rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
@ -491,6 +499,7 @@ type Endpoints struct {
UserinfoURL string UserinfoURL string
JKWsURL string JKWsURL string
EndSessionURL string EndSessionURL string
RevokeURL string
} }
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
@ -504,6 +513,7 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
UserinfoURL: discoveryConfig.UserinfoEndpoint, UserinfoURL: discoveryConfig.UserinfoEndpoint,
JKWsURL: discoveryConfig.JwksURI, JKWsURL: discoveryConfig.JwksURI,
EndSessionURL: discoveryConfig.EndSessionEndpoint, EndSessionURL: discoveryConfig.EndSessionEndpoint,
RevokeURL: discoveryConfig.RevocationEndpoint,
} }
} }
@ -584,3 +594,21 @@ func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState str
} }
return client.CallEndSessionEndpoint(request, nil, rp) return client.CallEndSessionEndpoint(request, nil, rp)
} }
// RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty
// returned by NewRelyingPartyOIDC() meets that criteria, but the one returned by
// NewRelyingPartyOAuth() does not.
//
// tokenTypeHint should be either "id_token" or "refresh_token".
func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error {
request := client.RevokeRequest{
Token: token,
TokenTypeHint: tokenTypeHint,
ClientID: rp.OAuthConfig().ClientID,
ClientSecret: rp.OAuthConfig().ClientSecret,
}
if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" {
return client.CallRevokeEndpoint(request, nil, rc)
}
return fmt.Errorf("RelyingParty does not support RevokeCaller")
}

View file

@ -39,7 +39,12 @@ type AuthStorage interface {
TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (RefreshTokenRequest, error) TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (RefreshTokenRequest, error)
TerminateSession(ctx context.Context, userID string, clientID string) error TerminateSession(ctx context.Context, userID string, clientID string) error
RevokeToken(ctx context.Context, tokenID string, userID string, clientID string) *oidc.Error
// RevokeToken should revoke a token. In the situation that the original request was to
// revoke an access token, then tokenOrTokenID will be a tokenID and userID will be set
// but if the original request was for a refresh token, then userID will be empty and
// tokenOrTokenID will be the refresh token, not its ID.
RevokeToken(ctx context.Context, tokenOrTokenID string, userID string, clientID string) *oidc.Error
GetSigningKey(context.Context, chan<- jose.SigningKey) GetSigningKey(context.Context, chan<- jose.SigningKey)
GetKeySet(context.Context) (*jose.JSONWebKeySet, error) GetKeySet(context.Context) (*jose.JSONWebKeySet, error)

View file

@ -113,8 +113,11 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token
func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) {
e := oidc.DefaultToServerError(err, err.Error()) e := oidc.DefaultToServerError(err, err.Error())
status := http.StatusBadRequest status := http.StatusBadRequest
if e.ErrorType == oidc.InvalidClient { switch e.ErrorType {
case oidc.InvalidClient:
status = 401 status = 401
case oidc.ServerError:
status = 500
} }
httphelper.MarshalJSONWithStatus(w, e, status) httphelper.MarshalJSONWithStatus(w, e, status)
} }