diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 39e297e..23a3394 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -143,6 +143,9 @@ func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequ func (s *AuthStorage) CreateToken(_ context.Context, authReq op.AuthRequest) (string, time.Time, error) { return authReq.GetID(), time.Now().UTC().Add(5 * time.Minute), nil } +func (s *AuthStorage) TerminateSession(_ context.Context, userID, clientID string) error { + return nil +} func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey, _ chan<- error, _ <-chan time.Time) { keyCh <- jose.SigningKey{Algorithm: jose.RS256, Key: s.key} } @@ -233,6 +236,9 @@ func (c *ConfClient) RedirectURIs() []string { "https://op.certification.openid.net:62064/authz_post", } } +func (c *ConfClient) PostLogoutRedirectURIs() []string { + return []string{} +} func (c *ConfClient) LoginURL(id string) string { return "login?id=" + id diff --git a/pkg/oidc/session.go b/pkg/oidc/session.go new file mode 100644 index 0000000..418439e --- /dev/null +++ b/pkg/oidc/session.go @@ -0,0 +1,7 @@ +package oidc + +type EndSessionRequest struct { + IdTokenHint string `schema:"id_token_hint"` + PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"` + State string `schema:"state"` +} diff --git a/pkg/op/client.go b/pkg/op/client.go index 5422933..a61e31d 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -14,6 +14,7 @@ const ( type Client interface { GetID() string RedirectURIs() []string + PostLogoutRedirectURIs() []string ApplicationType() ApplicationType GetAuthMethod() AuthMethod LoginURL(string) string diff --git a/pkg/op/config.go b/pkg/op/config.go index 9333a5c..1b047db 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -12,6 +12,7 @@ type Configuration interface { AuthorizationEndpoint() Endpoint TokenEndpoint() Endpoint UserinfoEndpoint() Endpoint + EndSessionEndpoint() Endpoint KeysEndpoint() Endpoint AuthMethodPostSupported() bool diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index 95ecd39..3c89563 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -2,6 +2,7 @@ package op import ( "context" + "errors" "net/http" "time" @@ -10,6 +11,7 @@ import ( "github.com/caos/logging" "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" ) const ( @@ -17,6 +19,7 @@ const ( defaulTokenEndpoint = "oauth/token" defaultIntrospectEndpoint = "introspect" defaultUserinfoEndpoint = "userinfo" + defaultEndSessionEndpoint = "end_session" defaultKeysEndpoint = "keys" AuthMethodBasic AuthMethod = "client_secret_basic" @@ -30,6 +33,7 @@ var ( Token: NewEndpoint(defaulTokenEndpoint), IntrospectionEndpoint: NewEndpoint(defaultIntrospectEndpoint), Userinfo: NewEndpoint(defaultUserinfoEndpoint), + EndSessionEndpoint: NewEndpoint(defaultEndSessionEndpoint), JwksURI: NewEndpoint(defaultKeysEndpoint), } ) @@ -39,6 +43,7 @@ type DefaultOP struct { endpoints *endpoints storage Storage signer Signer + verifier rp.Verifier crypto Crypto http *http.Server decoder *schema.Decoder @@ -49,8 +54,9 @@ type DefaultOP struct { } type Config struct { - Issuer string - CryptoKey [32]byte + Issuer string + CryptoKey [32]byte + DefaultLogoutRedirectURI string // ScopesSupported: oidc.SupportedScopes, // ResponseTypesSupported: responseTypes, // GrantTypesSupported: oidc.SupportedGrantTypes, @@ -164,6 +170,8 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts . p.signer = NewDefaultSigner(ctx, storage, keyCh) go p.ensureKey(ctx, storage, keyCh, p.timer) + p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience()) + router := CreateRouter(p, p.interceptor) p.http = &http.Server{ Addr: ":" + config.Port, @@ -195,6 +203,10 @@ func (p *DefaultOP) UserinfoEndpoint() Endpoint { return Endpoint(p.endpoints.Userinfo) } +func (p *DefaultOP) EndSessionEndpoint() Endpoint { + return Endpoint(p.endpoints.EndSessionEndpoint) +} + func (p *DefaultOP) KeysEndpoint() Endpoint { return Endpoint(p.endpoints.JwksURI) } @@ -215,6 +227,23 @@ func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) { Discover(w, CreateDiscoveryConfig(p, p.Signer())) } +func (p *DefaultOP) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { + keyID := "" + for _, sig := range jws.Signatures { + keyID = sig.Header.KeyID + break + } + keySet, err := p.Storage().GetKeySet(ctx) + if err != nil { + return nil, errors.New("error fetching keys") + } + payload, err, ok := rp.CheckKey(keyID, keySet.Keys, jws) + if !ok { + return nil, errors.New("invalid kid") + } + return payload, err +} + func (p *DefaultOP) Decoder() *schema.Decoder { return p.decoder } @@ -257,7 +286,7 @@ func (p *DefaultOP) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Reque func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) { reqType := r.FormValue("grant_type") if reqType == "" { - ExchangeRequestError(w, r, ErrInvalidRequest("grant_type missing")) + RequestError(w, r, ErrInvalidRequest("grant_type missing")) return } if reqType == string(oidc.GrantTypeCode) { @@ -271,6 +300,17 @@ func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) { Userinfo(w, r, p) } +func (p *DefaultOP) HandleEndSession(w http.ResponseWriter, r *http.Request) { + EndSession(w, r, p) +} + +func (p *DefaultOP) DefaultLogoutRedirectURI() string { + return p.config.DefaultLogoutRedirectURI +} +func (p *DefaultOP) IDTokenVerifier() rp.Verifier { + return p.verifier +} + func (p *DefaultOP) ensureKey(ctx context.Context, storage Storage, keyCh chan<- jose.SigningKey, timer <-chan time.Time) { count := 0 timer = time.After(0) diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 3d4ea98..fd6e0a6 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -17,8 +17,8 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()), TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()), // IntrospectionEndpoint: c.Intro().Absolute(c.Issuer()), - UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()), - // EndSessionEndpoint: c.TokenEndpoint().Absolute(c.Issuer())(c.EndSessionEndpoint), + UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()), + EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()), // CheckSessionIframe: c.TokenEndpoint().Absolute(c.Issuer())(c.CheckSessionIframe), JwksURI: c.KeysEndpoint().Absolute(c.Issuer()), ScopesSupported: Scopes(c), diff --git a/pkg/op/error.go b/pkg/op/error.go index c6e702e..f3c5857 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -76,7 +76,7 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq http.Redirect(w, r, url, http.StatusFound) } -func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) { +func RequestError(w http.ResponseWriter, r *http.Request, err error) { e, ok := err.(*OAuthError) if !ok { e = new(OAuthError) diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index b0b9244..e2f1c11 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -118,6 +118,20 @@ func (mr *MockClientMockRecorder) LoginURL(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockClient)(nil).LoginURL), arg0) } +// PostLogoutRedirectURIs mocks base method +func (m *MockClient) PostLogoutRedirectURIs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PostLogoutRedirectURIs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// PostLogoutRedirectURIs indicates an expected call of PostLogoutRedirectURIs +func (mr *MockClientMockRecorder) PostLogoutRedirectURIs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIs", reflect.TypeOf((*MockClient)(nil).PostLogoutRedirectURIs)) +} + // RedirectURIs mocks base method func (m *MockClient) RedirectURIs() []string { m.ctrl.T.Helper() diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index 7148c6d..c6174ff 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -61,6 +61,20 @@ func (mr *MockConfigurationMockRecorder) AuthorizationEndpoint() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).AuthorizationEndpoint)) } +// EndSessionEndpoint mocks base method +func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EndSessionEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// EndSessionEndpoint indicates an expected call of EndSessionEndpoint +func (mr *MockConfigurationMockRecorder) EndSessionEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EndSessionEndpoint", reflect.TypeOf((*MockConfiguration)(nil).EndSessionEndpoint)) +} + // Issuer mocks base method func (m *MockConfiguration) Issuer() string { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 14f106e..ac06842 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -195,3 +195,17 @@ func (mr *MockStorageMockRecorder) SaveNewKeyPair(arg0 interface{}) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNewKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveNewKeyPair), arg0) } + +// TerminateSession mocks base method +func (m *MockStorage) TerminateSession(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TerminateSession", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// TerminateSession indicates an expected call of TerminateSession +func (mr *MockStorageMockRecorder) TerminateSession(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateSession", reflect.TypeOf((*MockStorage)(nil).TerminateSession), arg0, arg1, arg2) +} diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index e4328c7..c9c63c6 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -126,6 +126,9 @@ func (c *ConfClient) RedirectURIs() []string { "custom://callback", } } +func (c *ConfClient) PostLogoutRedirectURIs() []string { + return []string{} +} func (c *ConfClient) LoginURL(id string) string { return "login?id=" + id diff --git a/pkg/op/op.go b/pkg/op/op.go index e6cdeb4..a926d34 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -24,6 +24,7 @@ type OpenIDProvider interface { HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) HandleExchange(w http.ResponseWriter, r *http.Request) HandleUserinfo(w http.ResponseWriter, r *http.Request) + HandleEndSession(w http.ResponseWriter, r *http.Request) HandleKeys(w http.ResponseWriter, r *http.Request) HttpHandler() *http.Server } @@ -49,6 +50,7 @@ func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router { router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", h(o.HandleAuthorizeCallback)) router.HandleFunc(o.TokenEndpoint().Relative(), h(o.HandleExchange)) router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo) + router.HandleFunc(o.EndSessionEndpoint().Relative(), h(o.HandleEndSession)) router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys) return router } diff --git a/pkg/op/session.go b/pkg/op/session.go index 5e19040..96ec1bf 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -1,6 +1,76 @@ package op -import "github.com/caos/oidc/pkg/oidc" +import ( + "context" + "net/http" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" + "github.com/gorilla/schema" +) + +type SessionEnder interface { + Decoder() *schema.Decoder + Storage() Storage + IDTokenVerifier() rp.Verifier + DefaultLogoutRedirectURI() string +} + +func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { + req, err := ParseEndSessionRequest(r, ender.Decoder()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + session, err := ValidateEndSessionRequest(r.Context(), req, ender) + if err != nil { + RequestError(w, r, err) + return + } + err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.Client.GetID()) + if err != nil { + RequestError(w, r, ErrServerError("error terminating session")) + return + } + http.Redirect(w, r, session.RedirectURI, http.StatusFound) +} + +func ParseEndSessionRequest(r *http.Request, decoder *schema.Decoder) (*oidc.EndSessionRequest, error) { + err := r.ParseForm() + if err != nil { + return nil, ErrInvalidRequest("error parsing form") + } + req := new(oidc.EndSessionRequest) + err = decoder.Decode(req, r.Form) + if err != nil { + return nil, ErrInvalidRequest("error decoding form") + } + return req, nil +} + +func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) { + session := new(EndSessionRequest) + claims, err := ender.IDTokenVerifier().Verify(ctx, "", req.IdTokenHint) + if err != nil { + return nil, ErrInvalidRequest("id_token_hint invalid") + } + session.UserID = claims.Subject + session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty) + if err != nil { + return nil, ErrServerError("") + } + if req.PostLogoutRedirectURI == "" { + session.RedirectURI = ender.DefaultLogoutRedirectURI() + return session, nil + } + for _, uri := range session.Client.PostLogoutRedirectURIs() { + if uri == req.PostLogoutRedirectURI { + session.RedirectURI = uri + "?state=" + req.State + return session, nil + } + } + return nil, ErrInvalidRequest("post_logout_redirect_uri invalid") +} func NeedsExistingSession(authRequest *oidc.AuthRequest) bool { if authRequest == nil { diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 5d6725d..b770360 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -16,6 +16,8 @@ type AuthStorage interface { CreateToken(context.Context, AuthRequest) (string, time.Time, error) + TerminateSession(context.Context, string, string) error + GetSigningKey(context.Context, chan<- jose.SigningKey, chan<- error, <-chan time.Time) GetKeySet(context.Context) (*jose.JSONWebKeySet, error) SaveNewKeyPair(context.Context) error @@ -53,3 +55,9 @@ type AuthRequest interface { GetSubject() string Done() bool } + +type EndSessionRequest struct { + UserID string + Client Client + RedirectURI string +} diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index cdd5396..cce3564 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -23,25 +23,25 @@ type Exchanger interface { func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder()) if err != nil { - ExchangeRequestError(w, r, err) + RequestError(w, r, err) } if tokenReq.Code == "" { - ExchangeRequestError(w, r, ErrInvalidRequest("code missing")) + RequestError(w, r, ErrInvalidRequest("code missing")) return } authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - ExchangeRequestError(w, r, err) + RequestError(w, r, err) return } err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID()) if err != nil { - ExchangeRequestError(w, r, err) + RequestError(w, r, err) return } resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code) if err != nil { - ExchangeRequestError(w, r, err) + RequestError(w, r, err) return } utils.MarshalJSON(w, resp) @@ -132,12 +132,12 @@ func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenRequest, err := ParseTokenExchangeRequest(w, r) if err != nil { - ExchangeRequestError(w, r, err) + RequestError(w, r, err) return } err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage()) if err != nil { - ExchangeRequestError(w, r, err) + RequestError(w, r, err) return } } diff --git a/pkg/rp/default_verifier.go b/pkg/rp/default_verifier.go index 58adddb..64ecaa0 100644 --- a/pkg/rp/default_verifier.go +++ b/pkg/rp/default_verifier.go @@ -46,6 +46,13 @@ func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts .. return &DefaultVerifier{config: conf, keySet: keySet} } +//WithIgnoreAudience will turn off audience claim (should only be used for id_token_hints) +func WithIgnoreAudience() func(*verifierConfig) { + return func(conf *verifierConfig) { + conf.ignoreAudience = true + } +} + //WithIgnoreIssuedAt will turn off iat claim verification func WithIgnoreIssuedAt() func(*verifierConfig) { return func(conf *verifierConfig) { @@ -100,6 +107,7 @@ type verifierConfig struct { issuer string clientID string nonce string + ignoreAudience bool iat *iatConfig acr ACRVerifier maxAge time.Duration @@ -233,6 +241,9 @@ func (v *DefaultVerifier) checkIssuer(issuer string) error { } func (v *DefaultVerifier) checkAudience(audiences []string) error { + if v.config.ignoreAudience { + return nil + } if !utils.Contains(audiences, v.config.clientID) { return ErrAudienceMissingClientID(v.config.clientID) } @@ -244,6 +255,9 @@ func (v *DefaultVerifier) checkAudience(audiences []string) error { //4. if multiple aud strings --> check if azp //5. if azp --> check azp == client_id func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error { + if v.config.ignoreAudience { + return nil + } if len(audiences) > 1 { if authorizedParty == "" { return ErrAzpMissing()