add context

This commit is contained in:
Livio Amstutz 2019-12-18 16:05:21 +01:00
parent 0731a62833
commit 462b5c83cd
12 changed files with 104 additions and 98 deletions

View file

@ -1,6 +1,7 @@
package mock package mock
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"errors" "errors"
@ -107,7 +108,7 @@ var (
t bool t bool
) )
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) { func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) {
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI} a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
if authReq.CodeChallenge != "" { if authReq.CodeChallenge != "" {
a.CodeChallenge = &oidc.CodeChallenge{ a.CodeChallenge = &oidc.CodeChallenge{
@ -118,26 +119,26 @@ func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthReque
t = false t = false
return a, nil return a, nil
} }
func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) { func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) {
return a, nil return a, nil
} }
func (s *AuthStorage) DeleteAuthRequest(string) error { func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
t = true t = true
return nil return nil
} }
func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) { func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequest, error) {
if id != "id" || t { if id != "id" || t {
return nil, errors.New("not found") return nil, errors.New("not found")
} }
return a, nil return a, nil
} }
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) { func (s *AuthStorage) GetSigningKey(_ context.Context) (*jose.SigningKey, error) {
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
} }
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) { func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) {
return s.key, nil return s.key, nil
} }
func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) { func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) {
pubkey := s.key.Public() pubkey := s.key.Public()
return &jose.JSONWebKeySet{ return &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{ Keys: []jose.JSONWebKey{
@ -146,7 +147,7 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
}, nil }, nil
} }
func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) { func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) {
if id == "none" { if id == "none" {
return nil, errors.New("not found") return nil, errors.New("not found")
} }
@ -165,11 +166,11 @@ func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) {
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
} }
func (s *AuthStorage) AuthorizeClientIDSecret(id string, _ string) error { func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
return nil return nil
} }
func (s *AuthStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) { func (s *AuthStorage) GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error) {
return &oidc.Userinfo{ return &oidc.Userinfo{
Subject: a.GetSubject(), Subject: a.GetSubject(),
Address: &oidc.UserinfoAddress{ Address: &oidc.UserinfoAddress{

View file

@ -21,7 +21,7 @@ func main() {
Port: "9998", Port: "9998",
} }
storage := mock.NewAuthStorage() storage := mock.NewAuthStorage()
handler, err := op.NewDefaultOP(config, storage, op.WithCustomTokenEndpoint("test")) handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint("test"))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

2
go.mod
View file

@ -14,7 +14,7 @@ require (
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0 github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 // indirect golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933
golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect
golang.org/x/text v0.3.2 golang.org/x/text v0.3.2

View file

@ -2,7 +2,6 @@ package oidc
import ( import (
"encoding/json" "encoding/json"
"strings"
"time" "time"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
@ -33,9 +32,9 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
return err return err
} }
audience := i.Audiences audience := i.Audiences
if len(audience) == 1 { // if len(audience) == 1 {
audience = strings.Split(audience[0], " ") // audience = strings.Split(audience[0], " ")
} // }
t.Issuer = i.Issuer t.Issuer = i.Issuer
t.Subject = i.Subject t.Subject = i.Subject
t.Audiences = audience t.Audiences = audience

View file

@ -1,6 +1,8 @@
package op package op
import ( import (
"context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@ -24,7 +26,7 @@ type Authorizer interface {
type ValidationAuthorizer interface { type ValidationAuthorizer interface {
Authorizer Authorizer
ValidateAuthRequest(*oidc.AuthRequest, Storage) error ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error
} }
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
@ -45,18 +47,18 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
if validater, ok := authorizer.(ValidationAuthorizer); ok { if validater, ok := authorizer.(ValidationAuthorizer); ok {
validation = validater.ValidateAuthRequest validation = validater.ValidateAuthRequest
} }
if err := validation(authReq, authorizer.Storage()); err != nil { if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
req, err := authorizer.Storage().CreateAuthRequest(authReq) req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID()) client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
if err != nil { if err != nil {
AuthRequestError(w, r, req, err, authorizer.Encoder()) AuthRequestError(w, r, req, err, authorizer.Encoder())
return return
@ -64,11 +66,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
RedirectToLogin(req.GetID(), client, w, r) RedirectToLogin(req.GetID(), client, w, r)
} }
func ValidateAuthRequest(authReq *oidc.AuthRequest, storage Storage) error { func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error {
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
return err return err
} }
if err := ValidateAuthReqRedirectURI(authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil { if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
return err return err
} }
if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil { if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil {
@ -93,11 +95,11 @@ func ValidateAuthReqScopes(scopes []string) error {
return nil return nil
} }
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error { func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
if uri == "" { if uri == "" {
return ErrInvalidRequestRedirectURI("redirect_uri must not be empty") return ErrInvalidRequestRedirectURI("redirect_uri must not be empty")
} }
client, err := storage.GetClientByClientID(client_id) client, err := storage.GetClientByClientID(ctx, client_id)
if err != nil { if err != nil {
return ErrServerError(err.Error()) return ErrServerError(err.Error())
} }
@ -142,11 +144,15 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
params := mux.Vars(r) params := mux.Vars(r)
id := params["id"] id := params["id"]
authReq, err := authorizer.Storage().AuthRequestByID(id) authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
if err != nil { if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder()) AuthRequestError(w, r, nil, err, authorizer.Encoder())
return return
} }
if !authReq.Done() {
AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder())
return
}
AuthResponse(authReq, authorizer, w, r) AuthResponse(authReq, authorizer, w, r)
} }

View file

@ -1,6 +1,7 @@
package op package op
import ( import (
"context"
"net/http" "net/http"
"time" "time"
@ -101,7 +102,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
} }
} }
func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) { func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
err := ValidateIssuer(config.Issuer) err := ValidateIssuer(config.Issuer)
if err != nil { if err != nil {
return nil, err return nil, err
@ -113,7 +114,7 @@ func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (Ope
endpoints: DefaultEndpoints, endpoints: DefaultEndpoints,
} }
p.signer, err = NewDefaultSigner(storage) p.signer, err = NewDefaultSigner(ctx, storage)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -11,7 +11,7 @@ type KeyProvider interface {
} }
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
keySet, err := k.Storage().GetKeySet() keySet, err := k.Storage().GetKeySet(r.Context())
if err != nil { if err != nil {
} }

View file

@ -5,6 +5,7 @@
package mock package mock
import ( import (
context "context"
oidc "github.com/caos/oidc/pkg/oidc" oidc "github.com/caos/oidc/pkg/oidc"
op "github.com/caos/oidc/pkg/op" op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@ -36,119 +37,119 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
} }
// AuthRequestByID mocks base method // AuthRequestByID mocks base method
func (m *MockStorage) AuthRequestByID(arg0 string) (op.AuthRequest, error) { func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthRequestByID", arg0) ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1)
ret0, _ := ret[0].(op.AuthRequest) ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// AuthRequestByID indicates an expected call of AuthRequestByID // AuthRequestByID indicates an expected call of AuthRequestByID
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0 interface{}) *gomock.Call { func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1)
} }
// AuthorizeClientIDSecret mocks base method // AuthorizeClientIDSecret mocks base method
func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) error { func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1) ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret // AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1 interface{}) *gomock.Call { func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2)
} }
// CreateAuthRequest mocks base method // CreateAuthRequest mocks base method
func (m *MockStorage) CreateAuthRequest(arg0 *oidc.AuthRequest) (op.AuthRequest, error) { func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0) ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1)
ret0, _ := ret[0].(op.AuthRequest) ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// CreateAuthRequest indicates an expected call of CreateAuthRequest // CreateAuthRequest indicates an expected call of CreateAuthRequest
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0 interface{}) *gomock.Call { func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1)
} }
// DeleteAuthRequest mocks base method // DeleteAuthRequest mocks base method
func (m *MockStorage) DeleteAuthRequest(arg0 string) error { func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0) ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest // DeleteAuthRequest indicates an expected call of DeleteAuthRequest
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0 interface{}) *gomock.Call { func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1)
} }
// GetClientByClientID mocks base method // GetClientByClientID mocks base method
func (m *MockStorage) GetClientByClientID(arg0 string) (op.Client, error) { func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClientByClientID", arg0) ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1)
ret0, _ := ret[0].(op.Client) ret0, _ := ret[0].(op.Client)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetClientByClientID indicates an expected call of GetClientByClientID // GetClientByClientID indicates an expected call of GetClientByClientID
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock.Call { func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
} }
// GetKeySet mocks base method // GetKeySet mocks base method
func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) { func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet") ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet) ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetKeySet indicates an expected call of GetKeySet // GetKeySet indicates an expected call of GetKeySet
func (mr *MockStorageMockRecorder) GetKeySet() *gomock.Call { func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
} }
// GetSigningKey mocks base method // GetSigningKey mocks base method
func (m *MockStorage) GetSigningKey() (*go_jose_v2.SigningKey, error) { func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSigningKey") ret := m.ctrl.Call(m, "GetSigningKey", arg0)
ret0, _ := ret[0].(*go_jose_v2.SigningKey) ret0, _ := ret[0].(*go_jose_v2.SigningKey)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetSigningKey indicates an expected call of GetSigningKey // GetSigningKey indicates an expected call of GetSigningKey
func (mr *MockStorageMockRecorder) GetSigningKey() *gomock.Call { func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0)
} }
// GetUserinfoFromScopes mocks base method // GetUserinfoFromScopes mocks base method
func (m *MockStorage) GetUserinfoFromScopes(arg0 []string) (*oidc.Userinfo, error) { func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 []string) (*oidc.Userinfo, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0) ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1)
ret0, _ := ret[0].(*oidc.Userinfo) ret0, _ := ret[0].(*oidc.Userinfo)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes // GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0 interface{}) *gomock.Call { func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1)
} }

View file

@ -3,6 +3,7 @@ package op
import ( import (
"encoding/json" "encoding/json"
"golang.org/x/net/context"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
@ -19,18 +20,18 @@ type idTokenSigner struct {
algorithm jose.SignatureAlgorithm algorithm jose.SignatureAlgorithm
} }
func NewDefaultSigner(storage AuthStorage) (Signer, error) { func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) {
s := &idTokenSigner{ s := &idTokenSigner{
storage: storage, storage: storage,
} }
if err := s.initialize(); err != nil { if err := s.initialize(ctx); err != nil {
return nil, err return nil, err
} }
return s, nil return s, nil
} }
func (s *idTokenSigner) initialize() error { func (s *idTokenSigner) initialize(ctx context.Context) error {
key, err := s.storage.GetSigningKey() key, err := s.storage.GetSigningKey(ctx)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,6 +1,7 @@
package op package op
import ( import (
"context"
"time" "time"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
@ -9,18 +10,18 @@ import (
) )
type AuthStorage interface { type AuthStorage interface {
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error) CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error)
AuthRequestByID(string) (AuthRequest, error) AuthRequestByID(context.Context, string) (AuthRequest, error)
DeleteAuthRequest(string) error DeleteAuthRequest(context.Context, string) error
GetSigningKey() (*jose.SigningKey, error) GetSigningKey(context.Context) (*jose.SigningKey, error)
GetKeySet() (*jose.JSONWebKeySet, error) GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
} }
type OPStorage interface { type OPStorage interface {
GetClientByClientID(string) (Client, error) GetClientByClientID(context.Context, string) (Client, error)
AuthorizeClientIDSecret(string, string) error AuthorizeClientIDSecret(context.Context, string, string) error
GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error)
} }
type Storage interface { type Storage interface {
@ -43,4 +44,5 @@ type AuthRequest interface {
GetScopes() []string GetScopes() []string
GetState() string GetState() string
GetSubject() string GetSubject() string
Done() bool
} }

View file

@ -1,6 +1,7 @@
package op package op
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"time" "time"
@ -31,13 +32,13 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
return return
} }
authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger) authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
if err != nil { if err != nil {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
} }
err = exchanger.Storage().DeleteAuthRequest(authReq.GetID()) err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
if err != nil { if err != nil {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
@ -81,8 +82,8 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac
return tokenReq, nil return tokenReq, nil
} }
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) { func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
authReq, client, err := AuthorizeClient(tokenReq, exchanger) authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -95,44 +96,38 @@ func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exc
return authReq, nil return authReq, nil
} }
func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) { func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID) client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
switch client.GetAuthMethod() { if client.GetAuthMethod() == AuthMethodNone {
case AuthMethodNone: authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger.Storage())
authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger.Storage())
return authReq, client, err return authReq, client, err
case AuthMethodPost: }
if !exchanger.AuthMethodPostSupported() { if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("basic not supported") return nil, nil, errors.New("basic not supported")
} }
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()) err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
case AuthMethodBasic:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
default:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
}
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
if err != nil { if err != nil {
return nil, nil, ErrInvalidRequest("invalid code") return nil, nil, ErrInvalidRequest("invalid code")
} }
return authReq, client, nil return authReq, client, nil
} }
func AuthorizeClientIDSecret(clientID, clientSecret string, storage OPStorage) error { func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
return storage.AuthorizeClientIDSecret(clientID, clientSecret) return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
} }
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) { func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
if tokenReq.CodeVerifier == "" { if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required") return nil, ErrInvalidRequest("code_challenge required")
} }
authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage) authReq, err := AuthRequestByCode(ctx, tokenReq.Code, nil, storage)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("invalid code") return nil, ErrInvalidRequest("invalid code")
} }
@ -142,12 +137,12 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
return authReq, nil return authReq, nil
} }
func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) { func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
id, err := crypto.Decrypt(code) id, err := crypto.Decrypt(code)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return storage.AuthRequestByID(id) return storage.AuthRequestByID(ctx, id)
} }
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {

View file

@ -15,7 +15,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
if err != nil { if err != nil {
return return
} }
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(scopes) info, err := userinfoProvider.Storage().GetUserinfoFromScopes(r.Context(), scopes)
if err != nil { if err != nil {
utils.MarshalJSON(w, err) utils.MarshalJSON(w, err)
return return