add context

This commit is contained in:
Livio Amstutz 2019-12-18 16:05:21 +01:00
parent 0731a62833
commit 462b5c83cd
12 changed files with 104 additions and 98 deletions

View file

@ -1,6 +1,7 @@
package mock
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
@ -107,7 +108,7 @@ var (
t bool
)
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) {
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
if authReq.CodeChallenge != "" {
a.CodeChallenge = &oidc.CodeChallenge{
@ -118,26 +119,26 @@ func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthReque
t = false
return a, nil
}
func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) {
func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) {
return a, nil
}
func (s *AuthStorage) DeleteAuthRequest(string) error {
func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
t = true
return nil
}
func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequest, error) {
if id != "id" || t {
return nil, errors.New("not found")
}
return a, nil
}
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
func (s *AuthStorage) GetSigningKey(_ context.Context) (*jose.SigningKey, error) {
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
}
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) {
return s.key, nil
}
func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) {
pubkey := s.key.Public()
return &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
@ -146,7 +147,7 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
}, nil
}
func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) {
func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) {
if id == "none" {
return nil, errors.New("not found")
}
@ -165,11 +166,11 @@ func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) {
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
}
func (s *AuthStorage) AuthorizeClientIDSecret(id string, _ string) error {
func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
return nil
}
func (s *AuthStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) {
func (s *AuthStorage) GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error) {
return &oidc.Userinfo{
Subject: a.GetSubject(),
Address: &oidc.UserinfoAddress{

View file

@ -21,7 +21,7 @@ func main() {
Port: "9998",
}
storage := mock.NewAuthStorage()
handler, err := op.NewDefaultOP(config, storage, op.WithCustomTokenEndpoint("test"))
handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint("test"))
if err != nil {
log.Fatal(err)
}

2
go.mod
View file

@ -14,7 +14,7 @@ require (
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 // indirect
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933
golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect
golang.org/x/text v0.3.2

View file

@ -2,7 +2,6 @@ package oidc
import (
"encoding/json"
"strings"
"time"
"github.com/caos/oidc/pkg/utils"
@ -33,9 +32,9 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
return err
}
audience := i.Audiences
if len(audience) == 1 {
audience = strings.Split(audience[0], " ")
}
// if len(audience) == 1 {
// audience = strings.Split(audience[0], " ")
// }
t.Issuer = i.Issuer
t.Subject = i.Subject
t.Audiences = audience

View file

@ -1,6 +1,8 @@
package op
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
@ -24,7 +26,7 @@ type Authorizer interface {
type ValidationAuthorizer interface {
Authorizer
ValidateAuthRequest(*oidc.AuthRequest, Storage) error
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error
}
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
@ -45,18 +47,18 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
if validater, ok := authorizer.(ValidationAuthorizer); ok {
validation = validater.ValidateAuthRequest
}
if err := validation(authReq, authorizer.Storage()); err != nil {
if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
req, err := authorizer.Storage().CreateAuthRequest(authReq)
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, err, authorizer.Encoder())
return
@ -64,11 +66,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
RedirectToLogin(req.GetID(), client, w, r)
}
func ValidateAuthRequest(authReq *oidc.AuthRequest, storage Storage) error {
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error {
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
return err
}
if err := ValidateAuthReqRedirectURI(authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
return err
}
if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil {
@ -93,11 +95,11 @@ func ValidateAuthReqScopes(scopes []string) error {
return nil
}
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
if uri == "" {
return ErrInvalidRequestRedirectURI("redirect_uri must not be empty")
}
client, err := storage.GetClientByClientID(client_id)
client, err := storage.GetClientByClientID(ctx, client_id)
if err != nil {
return ErrServerError(err.Error())
}
@ -142,11 +144,15 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
params := mux.Vars(r)
id := params["id"]
authReq, err := authorizer.Storage().AuthRequestByID(id)
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder())
return
}
if !authReq.Done() {
AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder())
return
}
AuthResponse(authReq, authorizer, w, r)
}

View file

@ -1,6 +1,7 @@
package op
import (
"context"
"net/http"
"time"
@ -101,7 +102,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
}
}
func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
err := ValidateIssuer(config.Issuer)
if err != nil {
return nil, err
@ -113,7 +114,7 @@ func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (Ope
endpoints: DefaultEndpoints,
}
p.signer, err = NewDefaultSigner(storage)
p.signer, err = NewDefaultSigner(ctx, storage)
if err != nil {
return nil, err
}

View file

@ -11,7 +11,7 @@ type KeyProvider interface {
}
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
keySet, err := k.Storage().GetKeySet()
keySet, err := k.Storage().GetKeySet(r.Context())
if err != nil {
}

View file

@ -5,6 +5,7 @@
package mock
import (
context "context"
oidc "github.com/caos/oidc/pkg/oidc"
op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock"
@ -36,119 +37,119 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
}
// AuthRequestByID mocks base method
func (m *MockStorage) AuthRequestByID(arg0 string) (op.AuthRequest, error) {
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthRequestByID", arg0)
ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1)
ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AuthRequestByID indicates an expected call of AuthRequestByID
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1)
}
// AuthorizeClientIDSecret mocks base method
func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) error {
func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1)
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2)
}
// CreateAuthRequest mocks base method
func (m *MockStorage) CreateAuthRequest(arg0 *oidc.AuthRequest) (op.AuthRequest, error) {
func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0)
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1)
ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateAuthRequest indicates an expected call of CreateAuthRequest
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1)
}
// DeleteAuthRequest mocks base method
func (m *MockStorage) DeleteAuthRequest(arg0 string) error {
func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0)
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1)
}
// GetClientByClientID mocks base method
func (m *MockStorage) GetClientByClientID(arg0 string) (op.Client, error) {
func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClientByClientID", arg0)
ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1)
ret0, _ := ret[0].(op.Client)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetClientByClientID indicates an expected call of GetClientByClientID
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
}
// GetKeySet mocks base method
func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) {
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet")
ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKeySet indicates an expected call of GetKeySet
func (mr *MockStorageMockRecorder) GetKeySet() *gomock.Call {
func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
}
// GetSigningKey mocks base method
func (m *MockStorage) GetSigningKey() (*go_jose_v2.SigningKey, error) {
func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSigningKey")
ret := m.ctrl.Call(m, "GetSigningKey", arg0)
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSigningKey indicates an expected call of GetSigningKey
func (mr *MockStorageMockRecorder) GetSigningKey() *gomock.Call {
func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0)
}
// GetUserinfoFromScopes mocks base method
func (m *MockStorage) GetUserinfoFromScopes(arg0 []string) (*oidc.Userinfo, error) {
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 []string) (*oidc.Userinfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0)
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1)
ret0, _ := ret[0].(*oidc.Userinfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1)
}

View file

@ -3,6 +3,7 @@ package op
import (
"encoding/json"
"golang.org/x/net/context"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
@ -19,18 +20,18 @@ type idTokenSigner struct {
algorithm jose.SignatureAlgorithm
}
func NewDefaultSigner(storage AuthStorage) (Signer, error) {
func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) {
s := &idTokenSigner{
storage: storage,
}
if err := s.initialize(); err != nil {
if err := s.initialize(ctx); err != nil {
return nil, err
}
return s, nil
}
func (s *idTokenSigner) initialize() error {
key, err := s.storage.GetSigningKey()
func (s *idTokenSigner) initialize(ctx context.Context) error {
key, err := s.storage.GetSigningKey(ctx)
if err != nil {
return err
}

View file

@ -1,6 +1,7 @@
package op
import (
"context"
"time"
"gopkg.in/square/go-jose.v2"
@ -9,18 +10,18 @@ import (
)
type AuthStorage interface {
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
AuthRequestByID(string) (AuthRequest, error)
DeleteAuthRequest(string) error
CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error)
AuthRequestByID(context.Context, string) (AuthRequest, error)
DeleteAuthRequest(context.Context, string) error
GetSigningKey() (*jose.SigningKey, error)
GetKeySet() (*jose.JSONWebKeySet, error)
GetSigningKey(context.Context) (*jose.SigningKey, error)
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
}
type OPStorage interface {
GetClientByClientID(string) (Client, error)
AuthorizeClientIDSecret(string, string) error
GetUserinfoFromScopes([]string) (*oidc.Userinfo, error)
GetClientByClientID(context.Context, string) (Client, error)
AuthorizeClientIDSecret(context.Context, string, string) error
GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error)
}
type Storage interface {
@ -43,4 +44,5 @@ type AuthRequest interface {
GetScopes() []string
GetState() string
GetSubject() string
Done() bool
}

View file

@ -1,6 +1,7 @@
package op
import (
"context"
"errors"
"net/http"
"time"
@ -31,13 +32,13 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
return
}
authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger)
authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = exchanger.Storage().DeleteAuthRequest(authReq.GetID())
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
if err != nil {
ExchangeRequestError(w, r, err)
return
@ -81,8 +82,8 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac
return tokenReq, nil
}
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
authReq, client, err := AuthorizeClient(tokenReq, exchanger)
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
if err != nil {
return nil, err
}
@ -95,44 +96,38 @@ func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exc
return authReq, nil
}
func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
if err != nil {
return nil, nil, err
}
switch client.GetAuthMethod() {
case AuthMethodNone:
authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger.Storage())
if client.GetAuthMethod() == AuthMethodNone {
authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger.Storage())
return authReq, client, err
case AuthMethodPost:
if !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("basic not supported")
}
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
case AuthMethodBasic:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
default:
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
}
if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("basic not supported")
}
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
if err != nil {
return nil, nil, err
}
authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
if err != nil {
return nil, nil, ErrInvalidRequest("invalid code")
}
return authReq, client, nil
}
func AuthorizeClientIDSecret(clientID, clientSecret string, storage OPStorage) error {
return storage.AuthorizeClientIDSecret(clientID, clientSecret)
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
}
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required")
}
authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage)
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, nil, storage)
if err != nil {
return nil, ErrInvalidRequest("invalid code")
}
@ -142,12 +137,12 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
return authReq, nil
}
func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
id, err := crypto.Decrypt(code)
if err != nil {
return nil, err
}
return storage.AuthRequestByID(id)
return storage.AuthRequestByID(ctx, id)
}
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {

View file

@ -15,7 +15,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
if err != nil {
return
}
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(scopes)
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(r.Context(), scopes)
if err != nil {
utils.MarshalJSON(w, err)
return