add context
This commit is contained in:
parent
0731a62833
commit
462b5c83cd
12 changed files with 104 additions and 98 deletions
|
@ -1,6 +1,7 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
|
@ -107,7 +108,7 @@ var (
|
|||
t bool
|
||||
)
|
||||
|
||||
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||
func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
|
||||
if authReq.CodeChallenge != "" {
|
||||
a.CodeChallenge = &oidc.CodeChallenge{
|
||||
|
@ -118,26 +119,26 @@ func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthReque
|
|||
t = false
|
||||
return a, nil
|
||||
}
|
||||
func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) {
|
||||
func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) {
|
||||
return a, nil
|
||||
}
|
||||
func (s *AuthStorage) DeleteAuthRequest(string) error {
|
||||
func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
|
||||
t = true
|
||||
return nil
|
||||
}
|
||||
func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
|
||||
func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequest, error) {
|
||||
if id != "id" || t {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
|
||||
func (s *AuthStorage) GetSigningKey(_ context.Context) (*jose.SigningKey, error) {
|
||||
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
|
||||
}
|
||||
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
|
||||
func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) {
|
||||
return s.key, nil
|
||||
}
|
||||
func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
|
||||
func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) {
|
||||
pubkey := s.key.Public()
|
||||
return &jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
|
@ -146,7 +147,7 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) {
|
||||
func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) {
|
||||
if id == "none" {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
@ -165,11 +166,11 @@ func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) {
|
|||
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
|
||||
}
|
||||
|
||||
func (s *AuthStorage) AuthorizeClientIDSecret(id string, _ string) error {
|
||||
func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) {
|
||||
func (s *AuthStorage) GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error) {
|
||||
return &oidc.Userinfo{
|
||||
Subject: a.GetSubject(),
|
||||
Address: &oidc.UserinfoAddress{
|
||||
|
|
|
@ -21,7 +21,7 @@ func main() {
|
|||
Port: "9998",
|
||||
}
|
||||
storage := mock.NewAuthStorage()
|
||||
handler, err := op.NewDefaultOP(config, storage, op.WithCustomTokenEndpoint("test"))
|
||||
handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint("test"))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
2
go.mod
2
go.mod
|
@ -14,7 +14,7 @@ require (
|
|||
github.com/sirupsen/logrus v1.4.2
|
||||
github.com/stretchr/testify v1.4.0
|
||||
golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect
|
||||
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 // indirect
|
||||
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933
|
||||
golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c
|
||||
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect
|
||||
golang.org/x/text v0.3.2
|
||||
|
|
|
@ -2,7 +2,6 @@ package oidc
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
|
@ -33,9 +32,9 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
|
|||
return err
|
||||
}
|
||||
audience := i.Audiences
|
||||
if len(audience) == 1 {
|
||||
audience = strings.Split(audience[0], " ")
|
||||
}
|
||||
// if len(audience) == 1 {
|
||||
// audience = strings.Split(audience[0], " ")
|
||||
// }
|
||||
t.Issuer = i.Issuer
|
||||
t.Subject = i.Subject
|
||||
t.Audiences = audience
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
@ -24,7 +26,7 @@ type Authorizer interface {
|
|||
|
||||
type ValidationAuthorizer interface {
|
||||
Authorizer
|
||||
ValidateAuthRequest(*oidc.AuthRequest, Storage) error
|
||||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error
|
||||
}
|
||||
|
||||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
|
@ -45,18 +47,18 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
if validater, ok := authorizer.(ValidationAuthorizer); ok {
|
||||
validation = validater.ValidateAuthRequest
|
||||
}
|
||||
if err := validation(authReq, authorizer.Storage()); err != nil {
|
||||
if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
|
||||
req, err := authorizer.Storage().CreateAuthRequest(authReq)
|
||||
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
|
||||
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
||||
return
|
||||
|
@ -64,11 +66,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
RedirectToLogin(req.GetID(), client, w, r)
|
||||
}
|
||||
|
||||
func ValidateAuthRequest(authReq *oidc.AuthRequest, storage Storage) error {
|
||||
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error {
|
||||
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateAuthReqRedirectURI(authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
|
||||
if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil {
|
||||
|
@ -93,11 +95,11 @@ func ValidateAuthReqScopes(scopes []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
|
||||
func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
|
||||
if uri == "" {
|
||||
return ErrInvalidRequestRedirectURI("redirect_uri must not be empty")
|
||||
}
|
||||
client, err := storage.GetClientByClientID(client_id)
|
||||
client, err := storage.GetClientByClientID(ctx, client_id)
|
||||
if err != nil {
|
||||
return ErrServerError(err.Error())
|
||||
}
|
||||
|
@ -142,11 +144,15 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
|||
params := mux.Vars(r)
|
||||
id := params["id"]
|
||||
|
||||
authReq, err := authorizer.Storage().AuthRequestByID(id)
|
||||
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
if !authReq.Done() {
|
||||
AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
AuthResponse(authReq, authorizer, w, r)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -101,7 +102,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
|
|||
}
|
||||
}
|
||||
|
||||
func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
|
||||
func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
|
||||
err := ValidateIssuer(config.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -113,7 +114,7 @@ func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (Ope
|
|||
endpoints: DefaultEndpoints,
|
||||
}
|
||||
|
||||
p.signer, err = NewDefaultSigner(storage)
|
||||
p.signer, err = NewDefaultSigner(ctx, storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ type KeyProvider interface {
|
|||
}
|
||||
|
||||
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
|
||||
keySet, err := k.Storage().GetKeySet()
|
||||
keySet, err := k.Storage().GetKeySet(r.Context())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
oidc "github.com/caos/oidc/pkg/oidc"
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
@ -36,119 +37,119 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
|||
}
|
||||
|
||||
// AuthRequestByID mocks base method
|
||||
func (m *MockStorage) AuthRequestByID(arg0 string) (op.AuthRequest, error) {
|
||||
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthRequestByID", arg0)
|
||||
ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AuthRequestByID indicates an expected call of AuthRequestByID
|
||||
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1)
|
||||
}
|
||||
|
||||
// AuthorizeClientIDSecret mocks base method
|
||||
func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) error {
|
||||
func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
|
||||
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1 interface{}) *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// CreateAuthRequest mocks base method
|
||||
func (m *MockStorage) CreateAuthRequest(arg0 *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||
func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0)
|
||||
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateAuthRequest indicates an expected call of CreateAuthRequest
|
||||
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// DeleteAuthRequest mocks base method
|
||||
func (m *MockStorage) DeleteAuthRequest(arg0 string) error {
|
||||
func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0)
|
||||
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest
|
||||
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetClientByClientID mocks base method
|
||||
func (m *MockStorage) GetClientByClientID(arg0 string) (op.Client, error) {
|
||||
func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClientByClientID", arg0)
|
||||
ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.Client)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetClientByClientID indicates an expected call of GetClientByClientID
|
||||
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetKeySet mocks base method
|
||||
func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) {
|
||||
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKeySet")
|
||||
ret := m.ctrl.Call(m, "GetKeySet", arg0)
|
||||
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKeySet indicates an expected call of GetKeySet
|
||||
func (mr *MockStorageMockRecorder) GetKeySet() *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
|
||||
}
|
||||
|
||||
// GetSigningKey mocks base method
|
||||
func (m *MockStorage) GetSigningKey() (*go_jose_v2.SigningKey, error) {
|
||||
func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSigningKey")
|
||||
ret := m.ctrl.Call(m, "GetSigningKey", arg0)
|
||||
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSigningKey indicates an expected call of GetSigningKey
|
||||
func (mr *MockStorageMockRecorder) GetSigningKey() *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0)
|
||||
}
|
||||
|
||||
// GetUserinfoFromScopes mocks base method
|
||||
func (m *MockStorage) GetUserinfoFromScopes(arg0 []string) (*oidc.Userinfo, error) {
|
||||
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 []string) (*oidc.Userinfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0)
|
||||
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1)
|
||||
ret0, _ := ret[0].(*oidc.Userinfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
|
||||
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package op
|
|||
import (
|
||||
"encoding/json"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
|
@ -19,18 +20,18 @@ type idTokenSigner struct {
|
|||
algorithm jose.SignatureAlgorithm
|
||||
}
|
||||
|
||||
func NewDefaultSigner(storage AuthStorage) (Signer, error) {
|
||||
func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) {
|
||||
s := &idTokenSigner{
|
||||
storage: storage,
|
||||
}
|
||||
if err := s.initialize(); err != nil {
|
||||
if err := s.initialize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) initialize() error {
|
||||
key, err := s.storage.GetSigningKey()
|
||||
func (s *idTokenSigner) initialize(ctx context.Context) error {
|
||||
key, err := s.storage.GetSigningKey(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
@ -9,18 +10,18 @@ import (
|
|||
)
|
||||
|
||||
type AuthStorage interface {
|
||||
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
|
||||
AuthRequestByID(string) (AuthRequest, error)
|
||||
DeleteAuthRequest(string) error
|
||||
CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error)
|
||||
AuthRequestByID(context.Context, string) (AuthRequest, error)
|
||||
DeleteAuthRequest(context.Context, string) error
|
||||
|
||||
GetSigningKey() (*jose.SigningKey, error)
|
||||
GetKeySet() (*jose.JSONWebKeySet, error)
|
||||
GetSigningKey(context.Context) (*jose.SigningKey, error)
|
||||
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
|
||||
}
|
||||
|
||||
type OPStorage interface {
|
||||
GetClientByClientID(string) (Client, error)
|
||||
AuthorizeClientIDSecret(string, string) error
|
||||
GetUserinfoFromScopes([]string) (*oidc.Userinfo, error)
|
||||
GetClientByClientID(context.Context, string) (Client, error)
|
||||
AuthorizeClientIDSecret(context.Context, string, string) error
|
||||
GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error)
|
||||
}
|
||||
|
||||
type Storage interface {
|
||||
|
@ -43,4 +44,5 @@ type AuthRequest interface {
|
|||
GetScopes() []string
|
||||
GetState() string
|
||||
GetSubject() string
|
||||
Done() bool
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
@ -31,13 +32,13 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
return
|
||||
}
|
||||
|
||||
authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger)
|
||||
authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = exchanger.Storage().DeleteAuthRequest(authReq.GetID())
|
||||
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
|
@ -81,8 +82,8 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac
|
|||
return tokenReq, nil
|
||||
}
|
||||
|
||||
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||
authReq, client, err := AuthorizeClient(tokenReq, exchanger)
|
||||
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -95,44 +96,38 @@ func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exc
|
|||
return authReq, nil
|
||||
}
|
||||
|
||||
func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
|
||||
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
switch client.GetAuthMethod() {
|
||||
case AuthMethodNone:
|
||||
authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger.Storage())
|
||||
if client.GetAuthMethod() == AuthMethodNone {
|
||||
authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger.Storage())
|
||||
return authReq, client, err
|
||||
case AuthMethodPost:
|
||||
if !exchanger.AuthMethodPostSupported() {
|
||||
}
|
||||
if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
|
||||
return nil, nil, errors.New("basic not supported")
|
||||
}
|
||||
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
case AuthMethodBasic:
|
||||
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
default:
|
||||
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
}
|
||||
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
return authReq, client, nil
|
||||
}
|
||||
|
||||
func AuthorizeClientIDSecret(clientID, clientSecret string, storage OPStorage) error {
|
||||
return storage.AuthorizeClientIDSecret(clientID, clientSecret)
|
||||
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
|
||||
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
|
||||
}
|
||||
|
||||
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
|
||||
func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
|
||||
if tokenReq.CodeVerifier == "" {
|
||||
return nil, ErrInvalidRequest("code_challenge required")
|
||||
}
|
||||
authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage)
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, nil, storage)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
|
@ -142,12 +137,12 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
|
|||
return authReq, nil
|
||||
}
|
||||
|
||||
func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
||||
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
||||
id, err := crypto.Decrypt(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return storage.AuthRequestByID(id)
|
||||
return storage.AuthRequestByID(ctx, id)
|
||||
}
|
||||
|
||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
|
|
|
@ -15,7 +15,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(scopes)
|
||||
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(r.Context(), scopes)
|
||||
if err != nil {
|
||||
utils.MarshalJSON(w, err)
|
||||
return
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue