Merge pull request #14 from caos/feat/end-session
feat: terminate session (front channel logout)
This commit is contained in:
commit
3d46c17fa0
16 changed files with 208 additions and 14 deletions
|
@ -143,6 +143,9 @@ func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequ
|
||||||
func (s *AuthStorage) CreateToken(_ context.Context, authReq op.AuthRequest) (string, time.Time, error) {
|
func (s *AuthStorage) CreateToken(_ context.Context, authReq op.AuthRequest) (string, time.Time, error) {
|
||||||
return authReq.GetID(), time.Now().UTC().Add(5 * time.Minute), nil
|
return authReq.GetID(), time.Now().UTC().Add(5 * time.Minute), nil
|
||||||
}
|
}
|
||||||
|
func (s *AuthStorage) TerminateSession(_ context.Context, userID, clientID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey, _ chan<- error, _ <-chan time.Time) {
|
func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey, _ chan<- error, _ <-chan time.Time) {
|
||||||
keyCh <- jose.SigningKey{Algorithm: jose.RS256, Key: s.key}
|
keyCh <- jose.SigningKey{Algorithm: jose.RS256, Key: s.key}
|
||||||
}
|
}
|
||||||
|
@ -233,6 +236,9 @@ func (c *ConfClient) RedirectURIs() []string {
|
||||||
"https://op.certification.openid.net:62064/authz_post",
|
"https://op.certification.openid.net:62064/authz_post",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func (c *ConfClient) PostLogoutRedirectURIs() []string {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ConfClient) LoginURL(id string) string {
|
func (c *ConfClient) LoginURL(id string) string {
|
||||||
return "login?id=" + id
|
return "login?id=" + id
|
||||||
|
|
7
pkg/oidc/session.go
Normal file
7
pkg/oidc/session.go
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package oidc
|
||||||
|
|
||||||
|
type EndSessionRequest struct {
|
||||||
|
IdTokenHint string `schema:"id_token_hint"`
|
||||||
|
PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"`
|
||||||
|
State string `schema:"state"`
|
||||||
|
}
|
|
@ -14,6 +14,7 @@ const (
|
||||||
type Client interface {
|
type Client interface {
|
||||||
GetID() string
|
GetID() string
|
||||||
RedirectURIs() []string
|
RedirectURIs() []string
|
||||||
|
PostLogoutRedirectURIs() []string
|
||||||
ApplicationType() ApplicationType
|
ApplicationType() ApplicationType
|
||||||
GetAuthMethod() AuthMethod
|
GetAuthMethod() AuthMethod
|
||||||
LoginURL(string) string
|
LoginURL(string) string
|
||||||
|
|
|
@ -12,6 +12,7 @@ type Configuration interface {
|
||||||
AuthorizationEndpoint() Endpoint
|
AuthorizationEndpoint() Endpoint
|
||||||
TokenEndpoint() Endpoint
|
TokenEndpoint() Endpoint
|
||||||
UserinfoEndpoint() Endpoint
|
UserinfoEndpoint() Endpoint
|
||||||
|
EndSessionEndpoint() Endpoint
|
||||||
KeysEndpoint() Endpoint
|
KeysEndpoint() Endpoint
|
||||||
|
|
||||||
AuthMethodPostSupported() bool
|
AuthMethodPostSupported() bool
|
||||||
|
|
|
@ -2,6 +2,7 @@ package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -10,6 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/caos/logging"
|
"github.com/caos/logging"
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
"github.com/caos/oidc/pkg/rp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -17,6 +19,7 @@ const (
|
||||||
defaulTokenEndpoint = "oauth/token"
|
defaulTokenEndpoint = "oauth/token"
|
||||||
defaultIntrospectEndpoint = "introspect"
|
defaultIntrospectEndpoint = "introspect"
|
||||||
defaultUserinfoEndpoint = "userinfo"
|
defaultUserinfoEndpoint = "userinfo"
|
||||||
|
defaultEndSessionEndpoint = "end_session"
|
||||||
defaultKeysEndpoint = "keys"
|
defaultKeysEndpoint = "keys"
|
||||||
|
|
||||||
AuthMethodBasic AuthMethod = "client_secret_basic"
|
AuthMethodBasic AuthMethod = "client_secret_basic"
|
||||||
|
@ -30,6 +33,7 @@ var (
|
||||||
Token: NewEndpoint(defaulTokenEndpoint),
|
Token: NewEndpoint(defaulTokenEndpoint),
|
||||||
IntrospectionEndpoint: NewEndpoint(defaultIntrospectEndpoint),
|
IntrospectionEndpoint: NewEndpoint(defaultIntrospectEndpoint),
|
||||||
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
|
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
|
||||||
|
EndSessionEndpoint: NewEndpoint(defaultEndSessionEndpoint),
|
||||||
JwksURI: NewEndpoint(defaultKeysEndpoint),
|
JwksURI: NewEndpoint(defaultKeysEndpoint),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -39,6 +43,7 @@ type DefaultOP struct {
|
||||||
endpoints *endpoints
|
endpoints *endpoints
|
||||||
storage Storage
|
storage Storage
|
||||||
signer Signer
|
signer Signer
|
||||||
|
verifier rp.Verifier
|
||||||
crypto Crypto
|
crypto Crypto
|
||||||
http *http.Server
|
http *http.Server
|
||||||
decoder *schema.Decoder
|
decoder *schema.Decoder
|
||||||
|
@ -49,8 +54,9 @@ type DefaultOP struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Issuer string
|
Issuer string
|
||||||
CryptoKey [32]byte
|
CryptoKey [32]byte
|
||||||
|
DefaultLogoutRedirectURI string
|
||||||
// ScopesSupported: oidc.SupportedScopes,
|
// ScopesSupported: oidc.SupportedScopes,
|
||||||
// ResponseTypesSupported: responseTypes,
|
// ResponseTypesSupported: responseTypes,
|
||||||
// GrantTypesSupported: oidc.SupportedGrantTypes,
|
// GrantTypesSupported: oidc.SupportedGrantTypes,
|
||||||
|
@ -164,6 +170,8 @@ func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts .
|
||||||
p.signer = NewDefaultSigner(ctx, storage, keyCh)
|
p.signer = NewDefaultSigner(ctx, storage, keyCh)
|
||||||
go p.ensureKey(ctx, storage, keyCh, p.timer)
|
go p.ensureKey(ctx, storage, keyCh, p.timer)
|
||||||
|
|
||||||
|
p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience())
|
||||||
|
|
||||||
router := CreateRouter(p, p.interceptor)
|
router := CreateRouter(p, p.interceptor)
|
||||||
p.http = &http.Server{
|
p.http = &http.Server{
|
||||||
Addr: ":" + config.Port,
|
Addr: ":" + config.Port,
|
||||||
|
@ -195,6 +203,10 @@ func (p *DefaultOP) UserinfoEndpoint() Endpoint {
|
||||||
return Endpoint(p.endpoints.Userinfo)
|
return Endpoint(p.endpoints.Userinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *DefaultOP) EndSessionEndpoint() Endpoint {
|
||||||
|
return Endpoint(p.endpoints.EndSessionEndpoint)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *DefaultOP) KeysEndpoint() Endpoint {
|
func (p *DefaultOP) KeysEndpoint() Endpoint {
|
||||||
return Endpoint(p.endpoints.JwksURI)
|
return Endpoint(p.endpoints.JwksURI)
|
||||||
}
|
}
|
||||||
|
@ -215,6 +227,23 @@ func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||||
Discover(w, CreateDiscoveryConfig(p, p.Signer()))
|
Discover(w, CreateDiscoveryConfig(p, p.Signer()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *DefaultOP) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||||
|
keyID := ""
|
||||||
|
for _, sig := range jws.Signatures {
|
||||||
|
keyID = sig.Header.KeyID
|
||||||
|
break
|
||||||
|
}
|
||||||
|
keySet, err := p.Storage().GetKeySet(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("error fetching keys")
|
||||||
|
}
|
||||||
|
payload, err, ok := rp.CheckKey(keyID, keySet.Keys, jws)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid kid")
|
||||||
|
}
|
||||||
|
return payload, err
|
||||||
|
}
|
||||||
|
|
||||||
func (p *DefaultOP) Decoder() *schema.Decoder {
|
func (p *DefaultOP) Decoder() *schema.Decoder {
|
||||||
return p.decoder
|
return p.decoder
|
||||||
}
|
}
|
||||||
|
@ -257,7 +286,7 @@ func (p *DefaultOP) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Reque
|
||||||
func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) {
|
func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) {
|
||||||
reqType := r.FormValue("grant_type")
|
reqType := r.FormValue("grant_type")
|
||||||
if reqType == "" {
|
if reqType == "" {
|
||||||
ExchangeRequestError(w, r, ErrInvalidRequest("grant_type missing"))
|
RequestError(w, r, ErrInvalidRequest("grant_type missing"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if reqType == string(oidc.GrantTypeCode) {
|
if reqType == string(oidc.GrantTypeCode) {
|
||||||
|
@ -271,6 +300,17 @@ func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
|
||||||
Userinfo(w, r, p)
|
Userinfo(w, r, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *DefaultOP) HandleEndSession(w http.ResponseWriter, r *http.Request) {
|
||||||
|
EndSession(w, r, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DefaultOP) DefaultLogoutRedirectURI() string {
|
||||||
|
return p.config.DefaultLogoutRedirectURI
|
||||||
|
}
|
||||||
|
func (p *DefaultOP) IDTokenVerifier() rp.Verifier {
|
||||||
|
return p.verifier
|
||||||
|
}
|
||||||
|
|
||||||
func (p *DefaultOP) ensureKey(ctx context.Context, storage Storage, keyCh chan<- jose.SigningKey, timer <-chan time.Time) {
|
func (p *DefaultOP) ensureKey(ctx context.Context, storage Storage, keyCh chan<- jose.SigningKey, timer <-chan time.Time) {
|
||||||
count := 0
|
count := 0
|
||||||
timer = time.After(0)
|
timer = time.After(0)
|
||||||
|
|
|
@ -17,8 +17,8 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati
|
||||||
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
|
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
|
||||||
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
|
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
|
||||||
// IntrospectionEndpoint: c.Intro().Absolute(c.Issuer()),
|
// IntrospectionEndpoint: c.Intro().Absolute(c.Issuer()),
|
||||||
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
|
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
|
||||||
// EndSessionEndpoint: c.TokenEndpoint().Absolute(c.Issuer())(c.EndSessionEndpoint),
|
EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()),
|
||||||
// CheckSessionIframe: c.TokenEndpoint().Absolute(c.Issuer())(c.CheckSessionIframe),
|
// CheckSessionIframe: c.TokenEndpoint().Absolute(c.Issuer())(c.CheckSessionIframe),
|
||||||
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
|
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
|
||||||
ScopesSupported: Scopes(c),
|
ScopesSupported: Scopes(c),
|
||||||
|
|
|
@ -76,7 +76,7 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
|
||||||
http.Redirect(w, r, url, http.StatusFound)
|
http.Redirect(w, r, url, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
|
func RequestError(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
e, ok := err.(*OAuthError)
|
e, ok := err.(*OAuthError)
|
||||||
if !ok {
|
if !ok {
|
||||||
e = new(OAuthError)
|
e = new(OAuthError)
|
||||||
|
|
|
@ -118,6 +118,20 @@ func (mr *MockClientMockRecorder) LoginURL(arg0 interface{}) *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockClient)(nil).LoginURL), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockClient)(nil).LoginURL), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PostLogoutRedirectURIs mocks base method
|
||||||
|
func (m *MockClient) PostLogoutRedirectURIs() []string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "PostLogoutRedirectURIs")
|
||||||
|
ret0, _ := ret[0].([]string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostLogoutRedirectURIs indicates an expected call of PostLogoutRedirectURIs
|
||||||
|
func (mr *MockClientMockRecorder) PostLogoutRedirectURIs() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIs", reflect.TypeOf((*MockClient)(nil).PostLogoutRedirectURIs))
|
||||||
|
}
|
||||||
|
|
||||||
// RedirectURIs mocks base method
|
// RedirectURIs mocks base method
|
||||||
func (m *MockClient) RedirectURIs() []string {
|
func (m *MockClient) RedirectURIs() []string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -61,6 +61,20 @@ func (mr *MockConfigurationMockRecorder) AuthorizationEndpoint() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).AuthorizationEndpoint))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).AuthorizationEndpoint))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EndSessionEndpoint mocks base method
|
||||||
|
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "EndSessionEndpoint")
|
||||||
|
ret0, _ := ret[0].(op.Endpoint)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndSessionEndpoint indicates an expected call of EndSessionEndpoint
|
||||||
|
func (mr *MockConfigurationMockRecorder) EndSessionEndpoint() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EndSessionEndpoint", reflect.TypeOf((*MockConfiguration)(nil).EndSessionEndpoint))
|
||||||
|
}
|
||||||
|
|
||||||
// Issuer mocks base method
|
// Issuer mocks base method
|
||||||
func (m *MockConfiguration) Issuer() string {
|
func (m *MockConfiguration) Issuer() string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -195,3 +195,17 @@ func (mr *MockStorageMockRecorder) SaveNewKeyPair(arg0 interface{}) *gomock.Call
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNewKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveNewKeyPair), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNewKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveNewKeyPair), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TerminateSession mocks base method
|
||||||
|
func (m *MockStorage) TerminateSession(arg0 context.Context, arg1, arg2 string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "TerminateSession", arg0, arg1, arg2)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// TerminateSession indicates an expected call of TerminateSession
|
||||||
|
func (mr *MockStorageMockRecorder) TerminateSession(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateSession", reflect.TypeOf((*MockStorage)(nil).TerminateSession), arg0, arg1, arg2)
|
||||||
|
}
|
||||||
|
|
|
@ -126,6 +126,9 @@ func (c *ConfClient) RedirectURIs() []string {
|
||||||
"custom://callback",
|
"custom://callback",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func (c *ConfClient) PostLogoutRedirectURIs() []string {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ConfClient) LoginURL(id string) string {
|
func (c *ConfClient) LoginURL(id string) string {
|
||||||
return "login?id=" + id
|
return "login?id=" + id
|
||||||
|
|
|
@ -24,6 +24,7 @@ type OpenIDProvider interface {
|
||||||
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
|
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
|
||||||
HandleExchange(w http.ResponseWriter, r *http.Request)
|
HandleExchange(w http.ResponseWriter, r *http.Request)
|
||||||
HandleUserinfo(w http.ResponseWriter, r *http.Request)
|
HandleUserinfo(w http.ResponseWriter, r *http.Request)
|
||||||
|
HandleEndSession(w http.ResponseWriter, r *http.Request)
|
||||||
HandleKeys(w http.ResponseWriter, r *http.Request)
|
HandleKeys(w http.ResponseWriter, r *http.Request)
|
||||||
HttpHandler() *http.Server
|
HttpHandler() *http.Server
|
||||||
}
|
}
|
||||||
|
@ -49,6 +50,7 @@ func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router {
|
||||||
router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", h(o.HandleAuthorizeCallback))
|
router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", h(o.HandleAuthorizeCallback))
|
||||||
router.HandleFunc(o.TokenEndpoint().Relative(), h(o.HandleExchange))
|
router.HandleFunc(o.TokenEndpoint().Relative(), h(o.HandleExchange))
|
||||||
router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo)
|
router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo)
|
||||||
|
router.HandleFunc(o.EndSessionEndpoint().Relative(), h(o.HandleEndSession))
|
||||||
router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
|
router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,76 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import "github.com/caos/oidc/pkg/oidc"
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
"github.com/caos/oidc/pkg/rp"
|
||||||
|
"github.com/gorilla/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionEnder interface {
|
||||||
|
Decoder() *schema.Decoder
|
||||||
|
Storage() Storage
|
||||||
|
IDTokenVerifier() rp.Verifier
|
||||||
|
DefaultLogoutRedirectURI() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
|
||||||
|
req, err := ParseEndSessionRequest(r, ender.Decoder())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session, err := ValidateEndSessionRequest(r.Context(), req, ender)
|
||||||
|
if err != nil {
|
||||||
|
RequestError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.Client.GetID())
|
||||||
|
if err != nil {
|
||||||
|
RequestError(w, r, ErrServerError("error terminating session"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, session.RedirectURI, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseEndSessionRequest(r *http.Request, decoder *schema.Decoder) (*oidc.EndSessionRequest, error) {
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrInvalidRequest("error parsing form")
|
||||||
|
}
|
||||||
|
req := new(oidc.EndSessionRequest)
|
||||||
|
err = decoder.Decode(req, r.Form)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrInvalidRequest("error decoding form")
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) {
|
||||||
|
session := new(EndSessionRequest)
|
||||||
|
claims, err := ender.IDTokenVerifier().Verify(ctx, "", req.IdTokenHint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrInvalidRequest("id_token_hint invalid")
|
||||||
|
}
|
||||||
|
session.UserID = claims.Subject
|
||||||
|
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrServerError("")
|
||||||
|
}
|
||||||
|
if req.PostLogoutRedirectURI == "" {
|
||||||
|
session.RedirectURI = ender.DefaultLogoutRedirectURI()
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
for _, uri := range session.Client.PostLogoutRedirectURIs() {
|
||||||
|
if uri == req.PostLogoutRedirectURI {
|
||||||
|
session.RedirectURI = uri + "?state=" + req.State
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrInvalidRequest("post_logout_redirect_uri invalid")
|
||||||
|
}
|
||||||
|
|
||||||
func NeedsExistingSession(authRequest *oidc.AuthRequest) bool {
|
func NeedsExistingSession(authRequest *oidc.AuthRequest) bool {
|
||||||
if authRequest == nil {
|
if authRequest == nil {
|
||||||
|
|
|
@ -16,6 +16,8 @@ type AuthStorage interface {
|
||||||
|
|
||||||
CreateToken(context.Context, AuthRequest) (string, time.Time, error)
|
CreateToken(context.Context, AuthRequest) (string, time.Time, error)
|
||||||
|
|
||||||
|
TerminateSession(context.Context, string, string) error
|
||||||
|
|
||||||
GetSigningKey(context.Context, chan<- jose.SigningKey, chan<- error, <-chan time.Time)
|
GetSigningKey(context.Context, chan<- jose.SigningKey, chan<- error, <-chan time.Time)
|
||||||
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
|
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
|
||||||
SaveNewKeyPair(context.Context) error
|
SaveNewKeyPair(context.Context) error
|
||||||
|
@ -53,3 +55,9 @@ type AuthRequest interface {
|
||||||
GetSubject() string
|
GetSubject() string
|
||||||
Done() bool
|
Done() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EndSessionRequest struct {
|
||||||
|
UserID string
|
||||||
|
Client Client
|
||||||
|
RedirectURI string
|
||||||
|
}
|
||||||
|
|
|
@ -23,25 +23,25 @@ type Exchanger interface {
|
||||||
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
|
tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
}
|
}
|
||||||
if tokenReq.Code == "" {
|
if tokenReq.Code == "" {
|
||||||
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
|
RequestError(w, r, ErrInvalidRequest("code missing"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code)
|
resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
utils.MarshalJSON(w, resp)
|
utils.MarshalJSON(w, resp)
|
||||||
|
@ -132,12 +132,12 @@ func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage
|
||||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
tokenRequest, err := ParseTokenExchangeRequest(w, r)
|
tokenRequest, err := ParseTokenExchangeRequest(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage())
|
err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,6 +46,13 @@ func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ..
|
||||||
return &DefaultVerifier{config: conf, keySet: keySet}
|
return &DefaultVerifier{config: conf, keySet: keySet}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//WithIgnoreAudience will turn off audience claim (should only be used for id_token_hints)
|
||||||
|
func WithIgnoreAudience() func(*verifierConfig) {
|
||||||
|
return func(conf *verifierConfig) {
|
||||||
|
conf.ignoreAudience = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//WithIgnoreIssuedAt will turn off iat claim verification
|
//WithIgnoreIssuedAt will turn off iat claim verification
|
||||||
func WithIgnoreIssuedAt() func(*verifierConfig) {
|
func WithIgnoreIssuedAt() func(*verifierConfig) {
|
||||||
return func(conf *verifierConfig) {
|
return func(conf *verifierConfig) {
|
||||||
|
@ -100,6 +107,7 @@ type verifierConfig struct {
|
||||||
issuer string
|
issuer string
|
||||||
clientID string
|
clientID string
|
||||||
nonce string
|
nonce string
|
||||||
|
ignoreAudience bool
|
||||||
iat *iatConfig
|
iat *iatConfig
|
||||||
acr ACRVerifier
|
acr ACRVerifier
|
||||||
maxAge time.Duration
|
maxAge time.Duration
|
||||||
|
@ -233,6 +241,9 @@ func (v *DefaultVerifier) checkIssuer(issuer string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *DefaultVerifier) checkAudience(audiences []string) error {
|
func (v *DefaultVerifier) checkAudience(audiences []string) error {
|
||||||
|
if v.config.ignoreAudience {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if !utils.Contains(audiences, v.config.clientID) {
|
if !utils.Contains(audiences, v.config.clientID) {
|
||||||
return ErrAudienceMissingClientID(v.config.clientID)
|
return ErrAudienceMissingClientID(v.config.clientID)
|
||||||
}
|
}
|
||||||
|
@ -244,6 +255,9 @@ func (v *DefaultVerifier) checkAudience(audiences []string) error {
|
||||||
//4. if multiple aud strings --> check if azp
|
//4. if multiple aud strings --> check if azp
|
||||||
//5. if azp --> check azp == client_id
|
//5. if azp --> check azp == client_id
|
||||||
func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error {
|
func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error {
|
||||||
|
if v.config.ignoreAudience {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if len(audiences) > 1 {
|
if len(audiences) > 1 {
|
||||||
if authorizedParty == "" {
|
if authorizedParty == "" {
|
||||||
return ErrAzpMissing()
|
return ErrAzpMissing()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue