diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index e3c8f33..a425c3d 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -1,6 +1,7 @@ package mock import ( + "context" "crypto/rand" "crypto/rsa" "errors" @@ -107,7 +108,7 @@ var ( t bool ) -func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) { +func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) { a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI} if authReq.CodeChallenge != "" { a.CodeChallenge = &oidc.CodeChallenge{ @@ -118,26 +119,26 @@ func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthReque t = false return a, nil } -func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) { +func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) { return a, nil } -func (s *AuthStorage) DeleteAuthRequest(string) error { +func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error { t = true return nil } -func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) { +func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequest, error) { if id != "id" || t { return nil, errors.New("not found") } return a, nil } -func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) { +func (s *AuthStorage) GetSigningKey(_ context.Context) (*jose.SigningKey, error) { return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil } -func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) { +func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) { return s.key, nil } -func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) { +func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) { pubkey := s.key.Public() return &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ @@ -146,7 +147,7 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) { }, nil } -func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) { +func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) { if id == "none" { return nil, errors.New("not found") } @@ -165,11 +166,11 @@ func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) { return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil } -func (s *AuthStorage) AuthorizeClientIDSecret(id string, _ string) error { +func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error { return nil } -func (s *AuthStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) { +func (s *AuthStorage) GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error) { return &oidc.Userinfo{ Subject: a.GetSubject(), Address: &oidc.UserinfoAddress{ diff --git a/example/server/default/default.go b/example/server/default/default.go index 3ad6feb..0b0bb8e 100644 --- a/example/server/default/default.go +++ b/example/server/default/default.go @@ -21,7 +21,7 @@ func main() { Port: "9998", } storage := mock.NewAuthStorage() - handler, err := op.NewDefaultOP(config, storage, op.WithCustomTokenEndpoint("test")) + handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint("test")) if err != nil { log.Fatal(err) } diff --git a/go.mod b/go.mod index b35882a..da7059a 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.4.0 golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect - golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 // indirect + golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect golang.org/x/text v0.3.2 diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 6ebae5c..b041c34 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -2,7 +2,6 @@ package oidc import ( "encoding/json" - "strings" "time" "github.com/caos/oidc/pkg/utils" @@ -33,9 +32,9 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error { return err } audience := i.Audiences - if len(audience) == 1 { - audience = strings.Split(audience[0], " ") - } + // if len(audience) == 1 { + // audience = strings.Split(audience[0], " ") + // } t.Issuer = i.Issuer t.Subject = i.Subject t.Audiences = audience diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index bdfa585..fd41e84 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -1,6 +1,8 @@ package op import ( + "context" + "errors" "fmt" "net/http" "strings" @@ -24,7 +26,7 @@ type Authorizer interface { type ValidationAuthorizer interface { Authorizer - ValidateAuthRequest(*oidc.AuthRequest, Storage) error + ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error } func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { @@ -45,18 +47,18 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { if validater, ok := authorizer.(ValidationAuthorizer); ok { validation = validater.ValidateAuthRequest } - if err := validation(authReq, authorizer.Storage()); err != nil { + if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } - req, err := authorizer.Storage().CreateAuthRequest(authReq) + req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq) if err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } - client, err := authorizer.Storage().GetClientByClientID(req.GetClientID()) + client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID()) if err != nil { AuthRequestError(w, r, req, err, authorizer.Encoder()) return @@ -64,11 +66,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { RedirectToLogin(req.GetID(), client, w, r) } -func ValidateAuthRequest(authReq *oidc.AuthRequest, storage Storage) error { +func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error { if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { return err } - if err := ValidateAuthReqRedirectURI(authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil { + if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil { return err } if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil { @@ -93,11 +95,11 @@ func ValidateAuthReqScopes(scopes []string) error { return nil } -func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error { +func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error { if uri == "" { return ErrInvalidRequestRedirectURI("redirect_uri must not be empty") } - client, err := storage.GetClientByClientID(client_id) + client, err := storage.GetClientByClientID(ctx, client_id) if err != nil { return ErrServerError(err.Error()) } @@ -142,11 +144,15 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author params := mux.Vars(r) id := params["id"] - authReq, err := authorizer.Storage().AuthRequestByID(id) + authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) if err != nil { AuthRequestError(w, r, nil, err, authorizer.Encoder()) return } + if !authReq.Done() { + AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder()) + return + } AuthResponse(authReq, authorizer, w, r) } diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index baccc2d..5acae08 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -1,6 +1,7 @@ package op import ( + "context" "net/http" "time" @@ -101,7 +102,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts { } } -func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) { +func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) { err := ValidateIssuer(config.Issuer) if err != nil { return nil, err @@ -113,7 +114,7 @@ func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (Ope endpoints: DefaultEndpoints, } - p.signer, err = NewDefaultSigner(storage) + p.signer, err = NewDefaultSigner(ctx, storage) if err != nil { return nil, err } diff --git a/pkg/op/keys.go b/pkg/op/keys.go index 018b040..8e2052b 100644 --- a/pkg/op/keys.go +++ b/pkg/op/keys.go @@ -11,7 +11,7 @@ type KeyProvider interface { } func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { - keySet, err := k.Storage().GetKeySet() + keySet, err := k.Storage().GetKeySet(r.Context()) if err != nil { } diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index ee85922..35eaa39 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -5,6 +5,7 @@ package mock import ( + context "context" oidc "github.com/caos/oidc/pkg/oidc" op "github.com/caos/oidc/pkg/op" gomock "github.com/golang/mock/gomock" @@ -36,119 +37,119 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder { } // AuthRequestByID mocks base method -func (m *MockStorage) AuthRequestByID(arg0 string) (op.AuthRequest, error) { +func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AuthRequestByID", arg0) + ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1) ret0, _ := ret[0].(op.AuthRequest) ret1, _ := ret[1].(error) return ret0, ret1 } // AuthRequestByID indicates an expected call of AuthRequestByID -func (mr *MockStorageMockRecorder) AuthRequestByID(arg0 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1) } // AuthorizeClientIDSecret mocks base method -func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) error { +func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1) + ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret -func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2) } // CreateAuthRequest mocks base method -func (m *MockStorage) CreateAuthRequest(arg0 *oidc.AuthRequest) (op.AuthRequest, error) { +func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateAuthRequest", arg0) + ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1) ret0, _ := ret[0].(op.AuthRequest) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateAuthRequest indicates an expected call of CreateAuthRequest -func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1) } // DeleteAuthRequest mocks base method -func (m *MockStorage) DeleteAuthRequest(arg0 string) error { +func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0) + ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // DeleteAuthRequest indicates an expected call of DeleteAuthRequest -func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1) } // GetClientByClientID mocks base method -func (m *MockStorage) GetClientByClientID(arg0 string) (op.Client, error) { +func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetClientByClientID", arg0) + ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1) ret0, _ := ret[0].(op.Client) ret1, _ := ret[1].(error) return ret0, ret1 } // GetClientByClientID indicates an expected call of GetClientByClientID -func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1) } // GetKeySet mocks base method -func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) { +func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetKeySet") + ret := m.ctrl.Call(m, "GetKeySet", arg0) ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet) ret1, _ := ret[1].(error) return ret0, ret1 } // GetKeySet indicates an expected call of GetKeySet -func (mr *MockStorageMockRecorder) GetKeySet() *gomock.Call { +func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0) } // GetSigningKey mocks base method -func (m *MockStorage) GetSigningKey() (*go_jose_v2.SigningKey, error) { +func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSigningKey") + ret := m.ctrl.Call(m, "GetSigningKey", arg0) ret0, _ := ret[0].(*go_jose_v2.SigningKey) ret1, _ := ret[1].(error) return ret0, ret1 } // GetSigningKey indicates an expected call of GetSigningKey -func (mr *MockStorageMockRecorder) GetSigningKey() *gomock.Call { +func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0) } // GetUserinfoFromScopes mocks base method -func (m *MockStorage) GetUserinfoFromScopes(arg0 []string) (*oidc.Userinfo, error) { +func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 []string) (*oidc.Userinfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0) + ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1) ret0, _ := ret[0].(*oidc.Userinfo) ret1, _ := ret[1].(error) return ret0, ret1 } // GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes -func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1) } diff --git a/pkg/op/signer.go b/pkg/op/signer.go index 9692211..280fd63 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -3,6 +3,7 @@ package op import ( "encoding/json" + "golang.org/x/net/context" "gopkg.in/square/go-jose.v2" "github.com/caos/oidc/pkg/oidc" @@ -19,18 +20,18 @@ type idTokenSigner struct { algorithm jose.SignatureAlgorithm } -func NewDefaultSigner(storage AuthStorage) (Signer, error) { +func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) { s := &idTokenSigner{ storage: storage, } - if err := s.initialize(); err != nil { + if err := s.initialize(ctx); err != nil { return nil, err } return s, nil } -func (s *idTokenSigner) initialize() error { - key, err := s.storage.GetSigningKey() +func (s *idTokenSigner) initialize(ctx context.Context) error { + key, err := s.storage.GetSigningKey(ctx) if err != nil { return err } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 5c23d6a..81c9c58 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -1,6 +1,7 @@ package op import ( + "context" "time" "gopkg.in/square/go-jose.v2" @@ -9,18 +10,18 @@ import ( ) type AuthStorage interface { - CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error) - AuthRequestByID(string) (AuthRequest, error) - DeleteAuthRequest(string) error + CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error) + AuthRequestByID(context.Context, string) (AuthRequest, error) + DeleteAuthRequest(context.Context, string) error - GetSigningKey() (*jose.SigningKey, error) - GetKeySet() (*jose.JSONWebKeySet, error) + GetSigningKey(context.Context) (*jose.SigningKey, error) + GetKeySet(context.Context) (*jose.JSONWebKeySet, error) } type OPStorage interface { - GetClientByClientID(string) (Client, error) - AuthorizeClientIDSecret(string, string) error - GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) + GetClientByClientID(context.Context, string) (Client, error) + AuthorizeClientIDSecret(context.Context, string, string) error + GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error) } type Storage interface { @@ -43,4 +44,5 @@ type AuthRequest interface { GetScopes() []string GetState() string GetSubject() string + Done() bool } diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index 935589f..8a0bc4f 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -1,6 +1,7 @@ package op import ( + "context" "errors" "net/http" "time" @@ -31,13 +32,13 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { return } - authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger) + authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { ExchangeRequestError(w, r, err) return } - err = exchanger.Storage().DeleteAuthRequest(authReq.GetID()) + err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID()) if err != nil { ExchangeRequestError(w, r, err) return @@ -81,8 +82,8 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac return tokenReq, nil } -func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) { - authReq, client, err := AuthorizeClient(tokenReq, exchanger) +func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) { + authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger) if err != nil { return nil, err } @@ -95,44 +96,38 @@ func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exc return authReq, nil } -func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) { - client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID) +func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) { + client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID) if err != nil { return nil, nil, err } - switch client.GetAuthMethod() { - case AuthMethodNone: - authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger.Storage()) + if client.GetAuthMethod() == AuthMethodNone { + authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger.Storage()) return authReq, client, err - case AuthMethodPost: - if !exchanger.AuthMethodPostSupported() { - return nil, nil, errors.New("basic not supported") - } - err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()) - case AuthMethodBasic: - err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()) - default: - err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()) } + if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() { + return nil, nil, errors.New("basic not supported") + } + err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()) if err != nil { return nil, nil, err } - authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) + authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) if err != nil { return nil, nil, ErrInvalidRequest("invalid code") } return authReq, client, nil } -func AuthorizeClientIDSecret(clientID, clientSecret string, storage OPStorage) error { - return storage.AuthorizeClientIDSecret(clientID, clientSecret) +func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error { + return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret) } -func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) { +func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) { if tokenReq.CodeVerifier == "" { return nil, ErrInvalidRequest("code_challenge required") } - authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage) + authReq, err := AuthRequestByCode(ctx, tokenReq.Code, nil, storage) if err != nil { return nil, ErrInvalidRequest("invalid code") } @@ -142,12 +137,12 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora return authReq, nil } -func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) { +func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) { id, err := crypto.Decrypt(code) if err != nil { return nil, err } - return storage.AuthRequestByID(id) + return storage.AuthRequestByID(ctx, id) } func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index ad81f69..ac47e68 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -15,7 +15,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP if err != nil { return } - info, err := userinfoProvider.Storage().GetUserinfoFromScopes(scopes) + info, err := userinfoProvider.Storage().GetUserinfoFromScopes(r.Context(), scopes) if err != nil { utils.MarshalJSON(w, err) return