add context
This commit is contained in:
parent
0731a62833
commit
462b5c83cd
12 changed files with 104 additions and 98 deletions
|
@ -1,6 +1,7 @@
|
||||||
package mock
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -107,7 +108,7 @@ var (
|
||||||
t bool
|
t bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||||
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
|
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
|
||||||
if authReq.CodeChallenge != "" {
|
if authReq.CodeChallenge != "" {
|
||||||
a.CodeChallenge = &oidc.CodeChallenge{
|
a.CodeChallenge = &oidc.CodeChallenge{
|
||||||
|
@ -118,26 +119,26 @@ func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthReque
|
||||||
t = false
|
t = false
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) {
|
func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) {
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) DeleteAuthRequest(string) error {
|
func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
|
||||||
t = true
|
t = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
|
func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequest, error) {
|
||||||
if id != "id" || t {
|
if id != "id" || t {
|
||||||
return nil, errors.New("not found")
|
return nil, errors.New("not found")
|
||||||
}
|
}
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
|
func (s *AuthStorage) GetSigningKey(_ context.Context) (*jose.SigningKey, error) {
|
||||||
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
|
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
|
func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) {
|
||||||
return s.key, nil
|
return s.key, nil
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
|
func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) {
|
||||||
pubkey := s.key.Public()
|
pubkey := s.key.Public()
|
||||||
return &jose.JSONWebKeySet{
|
return &jose.JSONWebKeySet{
|
||||||
Keys: []jose.JSONWebKey{
|
Keys: []jose.JSONWebKey{
|
||||||
|
@ -146,7 +147,7 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) {
|
func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) {
|
||||||
if id == "none" {
|
if id == "none" {
|
||||||
return nil, errors.New("not found")
|
return nil, errors.New("not found")
|
||||||
}
|
}
|
||||||
|
@ -165,11 +166,11 @@ func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) {
|
||||||
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
|
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthStorage) AuthorizeClientIDSecret(id string, _ string) error {
|
func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) {
|
func (s *AuthStorage) GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error) {
|
||||||
return &oidc.Userinfo{
|
return &oidc.Userinfo{
|
||||||
Subject: a.GetSubject(),
|
Subject: a.GetSubject(),
|
||||||
Address: &oidc.UserinfoAddress{
|
Address: &oidc.UserinfoAddress{
|
||||||
|
|
|
@ -21,7 +21,7 @@ func main() {
|
||||||
Port: "9998",
|
Port: "9998",
|
||||||
}
|
}
|
||||||
storage := mock.NewAuthStorage()
|
storage := mock.NewAuthStorage()
|
||||||
handler, err := op.NewDefaultOP(config, storage, op.WithCustomTokenEndpoint("test"))
|
handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint("test"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -14,7 +14,7 @@ require (
|
||||||
github.com/sirupsen/logrus v1.4.2
|
github.com/sirupsen/logrus v1.4.2
|
||||||
github.com/stretchr/testify v1.4.0
|
github.com/stretchr/testify v1.4.0
|
||||||
golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect
|
golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect
|
||||||
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 // indirect
|
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933
|
||||||
golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c
|
golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c
|
||||||
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect
|
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect
|
||||||
golang.org/x/text v0.3.2
|
golang.org/x/text v0.3.2
|
||||||
|
|
|
@ -2,7 +2,6 @@ package oidc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/utils"
|
"github.com/caos/oidc/pkg/utils"
|
||||||
|
@ -33,9 +32,9 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
audience := i.Audiences
|
audience := i.Audiences
|
||||||
if len(audience) == 1 {
|
// if len(audience) == 1 {
|
||||||
audience = strings.Split(audience[0], " ")
|
// audience = strings.Split(audience[0], " ")
|
||||||
}
|
// }
|
||||||
t.Issuer = i.Issuer
|
t.Issuer = i.Issuer
|
||||||
t.Subject = i.Subject
|
t.Subject = i.Subject
|
||||||
t.Audiences = audience
|
t.Audiences = audience
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -24,7 +26,7 @@ type Authorizer interface {
|
||||||
|
|
||||||
type ValidationAuthorizer interface {
|
type ValidationAuthorizer interface {
|
||||||
Authorizer
|
Authorizer
|
||||||
ValidateAuthRequest(*oidc.AuthRequest, Storage) error
|
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||||
|
@ -45,18 +47,18 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||||
if validater, ok := authorizer.(ValidationAuthorizer); ok {
|
if validater, ok := authorizer.(ValidationAuthorizer); ok {
|
||||||
validation = validater.ValidateAuthRequest
|
validation = validater.ValidateAuthRequest
|
||||||
}
|
}
|
||||||
if err := validation(authReq, authorizer.Storage()); err != nil {
|
if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := authorizer.Storage().CreateAuthRequest(authReq)
|
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
|
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
|
@ -64,11 +66,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||||
RedirectToLogin(req.GetID(), client, w, r)
|
RedirectToLogin(req.GetID(), client, w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateAuthRequest(authReq *oidc.AuthRequest, storage Storage) error {
|
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error {
|
||||||
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
|
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ValidateAuthReqRedirectURI(authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
|
if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil {
|
if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil {
|
||||||
|
@ -93,11 +95,11 @@ func ValidateAuthReqScopes(scopes []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
|
func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
|
||||||
if uri == "" {
|
if uri == "" {
|
||||||
return ErrInvalidRequestRedirectURI("redirect_uri must not be empty")
|
return ErrInvalidRequestRedirectURI("redirect_uri must not be empty")
|
||||||
}
|
}
|
||||||
client, err := storage.GetClientByClientID(client_id)
|
client, err := storage.GetClientByClientID(ctx, client_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrServerError(err.Error())
|
return ErrServerError(err.Error())
|
||||||
}
|
}
|
||||||
|
@ -142,11 +144,15 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
||||||
params := mux.Vars(r)
|
params := mux.Vars(r)
|
||||||
id := params["id"]
|
id := params["id"]
|
||||||
|
|
||||||
authReq, err := authorizer.Storage().AuthRequestByID(id)
|
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if !authReq.Done() {
|
||||||
|
AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder())
|
||||||
|
return
|
||||||
|
}
|
||||||
AuthResponse(authReq, authorizer, w, r)
|
AuthResponse(authReq, authorizer, w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -101,7 +102,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
|
func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
|
||||||
err := ValidateIssuer(config.Issuer)
|
err := ValidateIssuer(config.Issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -113,7 +114,7 @@ func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (Ope
|
||||||
endpoints: DefaultEndpoints,
|
endpoints: DefaultEndpoints,
|
||||||
}
|
}
|
||||||
|
|
||||||
p.signer, err = NewDefaultSigner(storage)
|
p.signer, err = NewDefaultSigner(ctx, storage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ type KeyProvider interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
|
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
|
||||||
keySet, err := k.Storage().GetKeySet()
|
keySet, err := k.Storage().GetKeySet(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
package mock
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
context "context"
|
||||||
oidc "github.com/caos/oidc/pkg/oidc"
|
oidc "github.com/caos/oidc/pkg/oidc"
|
||||||
op "github.com/caos/oidc/pkg/op"
|
op "github.com/caos/oidc/pkg/op"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
@ -36,119 +37,119 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthRequestByID mocks base method
|
// AuthRequestByID mocks base method
|
||||||
func (m *MockStorage) AuthRequestByID(arg0 string) (op.AuthRequest, error) {
|
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AuthRequestByID", arg0)
|
ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1)
|
||||||
ret0, _ := ret[0].(op.AuthRequest)
|
ret0, _ := ret[0].(op.AuthRequest)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthRequestByID indicates an expected call of AuthRequestByID
|
// AuthRequestByID indicates an expected call of AuthRequestByID
|
||||||
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0 interface{}) *gomock.Call {
|
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeClientIDSecret mocks base method
|
// AuthorizeClientIDSecret mocks base method
|
||||||
func (m *MockStorage) AuthorizeClientIDSecret(arg0, arg1 string) error {
|
func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1)
|
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
|
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
|
||||||
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1 interface{}) *gomock.Call {
|
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAuthRequest mocks base method
|
// CreateAuthRequest mocks base method
|
||||||
func (m *MockStorage) CreateAuthRequest(arg0 *oidc.AuthRequest) (op.AuthRequest, error) {
|
func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0)
|
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1)
|
||||||
ret0, _ := ret[0].(op.AuthRequest)
|
ret0, _ := ret[0].(op.AuthRequest)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAuthRequest indicates an expected call of CreateAuthRequest
|
// CreateAuthRequest indicates an expected call of CreateAuthRequest
|
||||||
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0 interface{}) *gomock.Call {
|
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAuthRequest mocks base method
|
// DeleteAuthRequest mocks base method
|
||||||
func (m *MockStorage) DeleteAuthRequest(arg0 string) error {
|
func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0)
|
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest
|
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest
|
||||||
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0 interface{}) *gomock.Call {
|
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientByClientID mocks base method
|
// GetClientByClientID mocks base method
|
||||||
func (m *MockStorage) GetClientByClientID(arg0 string) (op.Client, error) {
|
func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetClientByClientID", arg0)
|
ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1)
|
||||||
ret0, _ := ret[0].(op.Client)
|
ret0, _ := ret[0].(op.Client)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientByClientID indicates an expected call of GetClientByClientID
|
// GetClientByClientID indicates an expected call of GetClientByClientID
|
||||||
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock.Call {
|
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKeySet mocks base method
|
// GetKeySet mocks base method
|
||||||
func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) {
|
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetKeySet")
|
ret := m.ctrl.Call(m, "GetKeySet", arg0)
|
||||||
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
|
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKeySet indicates an expected call of GetKeySet
|
// GetKeySet indicates an expected call of GetKeySet
|
||||||
func (mr *MockStorageMockRecorder) GetKeySet() *gomock.Call {
|
func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSigningKey mocks base method
|
// GetSigningKey mocks base method
|
||||||
func (m *MockStorage) GetSigningKey() (*go_jose_v2.SigningKey, error) {
|
func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetSigningKey")
|
ret := m.ctrl.Call(m, "GetSigningKey", arg0)
|
||||||
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
|
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSigningKey indicates an expected call of GetSigningKey
|
// GetSigningKey indicates an expected call of GetSigningKey
|
||||||
func (mr *MockStorageMockRecorder) GetSigningKey() *gomock.Call {
|
func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserinfoFromScopes mocks base method
|
// GetUserinfoFromScopes mocks base method
|
||||||
func (m *MockStorage) GetUserinfoFromScopes(arg0 []string) (*oidc.Userinfo, error) {
|
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 []string) (*oidc.Userinfo, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0)
|
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1)
|
||||||
ret0, _ := ret[0].(*oidc.Userinfo)
|
ret0, _ := ret[0].(*oidc.Userinfo)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
|
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
|
||||||
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0 interface{}) *gomock.Call {
|
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package op
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
@ -19,18 +20,18 @@ type idTokenSigner struct {
|
||||||
algorithm jose.SignatureAlgorithm
|
algorithm jose.SignatureAlgorithm
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultSigner(storage AuthStorage) (Signer, error) {
|
func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) {
|
||||||
s := &idTokenSigner{
|
s := &idTokenSigner{
|
||||||
storage: storage,
|
storage: storage,
|
||||||
}
|
}
|
||||||
if err := s.initialize(); err != nil {
|
if err := s.initialize(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *idTokenSigner) initialize() error {
|
func (s *idTokenSigner) initialize(ctx context.Context) error {
|
||||||
key, err := s.storage.GetSigningKey()
|
key, err := s.storage.GetSigningKey(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
@ -9,18 +10,18 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type AuthStorage interface {
|
type AuthStorage interface {
|
||||||
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
|
CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error)
|
||||||
AuthRequestByID(string) (AuthRequest, error)
|
AuthRequestByID(context.Context, string) (AuthRequest, error)
|
||||||
DeleteAuthRequest(string) error
|
DeleteAuthRequest(context.Context, string) error
|
||||||
|
|
||||||
GetSigningKey() (*jose.SigningKey, error)
|
GetSigningKey(context.Context) (*jose.SigningKey, error)
|
||||||
GetKeySet() (*jose.JSONWebKeySet, error)
|
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OPStorage interface {
|
type OPStorage interface {
|
||||||
GetClientByClientID(string) (Client, error)
|
GetClientByClientID(context.Context, string) (Client, error)
|
||||||
AuthorizeClientIDSecret(string, string) error
|
AuthorizeClientIDSecret(context.Context, string, string) error
|
||||||
GetUserinfoFromScopes([]string) (*oidc.Userinfo, error)
|
GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Storage interface {
|
type Storage interface {
|
||||||
|
@ -43,4 +44,5 @@ type AuthRequest interface {
|
||||||
GetScopes() []string
|
GetScopes() []string
|
||||||
GetState() string
|
GetState() string
|
||||||
GetSubject() string
|
GetSubject() string
|
||||||
|
Done() bool
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -31,13 +32,13 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
authReq, err := ValidateAccessTokenRequest(tokenReq, exchanger)
|
authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = exchanger.Storage().DeleteAuthRequest(authReq.GetID())
|
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
|
@ -81,8 +82,8 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac
|
||||||
return tokenReq, nil
|
return tokenReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||||
authReq, client, err := AuthorizeClient(tokenReq, exchanger)
|
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -95,44 +96,38 @@ func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, exchanger Exc
|
||||||
return authReq, nil
|
return authReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||||
client, err := exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
|
client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
switch client.GetAuthMethod() {
|
if client.GetAuthMethod() == AuthMethodNone {
|
||||||
case AuthMethodNone:
|
authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger.Storage())
|
||||||
authReq, err := AuthorizeCodeChallenge(tokenReq, exchanger.Storage())
|
|
||||||
return authReq, client, err
|
return authReq, client, err
|
||||||
case AuthMethodPost:
|
}
|
||||||
if !exchanger.AuthMethodPostSupported() {
|
if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
|
||||||
return nil, nil, errors.New("basic not supported")
|
return nil, nil, errors.New("basic not supported")
|
||||||
}
|
}
|
||||||
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||||
case AuthMethodBasic:
|
|
||||||
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
|
||||||
default:
|
|
||||||
err = AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, ErrInvalidRequest("invalid code")
|
return nil, nil, ErrInvalidRequest("invalid code")
|
||||||
}
|
}
|
||||||
return authReq, client, nil
|
return authReq, client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthorizeClientIDSecret(clientID, clientSecret string, storage OPStorage) error {
|
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
|
||||||
return storage.AuthorizeClientIDSecret(clientID, clientSecret)
|
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
|
func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, storage AuthStorage) (AuthRequest, error) {
|
||||||
if tokenReq.CodeVerifier == "" {
|
if tokenReq.CodeVerifier == "" {
|
||||||
return nil, ErrInvalidRequest("code_challenge required")
|
return nil, ErrInvalidRequest("code_challenge required")
|
||||||
}
|
}
|
||||||
authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage)
|
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, nil, storage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidRequest("invalid code")
|
return nil, ErrInvalidRequest("invalid code")
|
||||||
}
|
}
|
||||||
|
@ -142,12 +137,12 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
|
||||||
return authReq, nil
|
return authReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
||||||
id, err := crypto.Decrypt(code)
|
id, err := crypto.Decrypt(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return storage.AuthRequestByID(id)
|
return storage.AuthRequestByID(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
|
|
|
@ -15,7 +15,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(scopes)
|
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(r.Context(), scopes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.MarshalJSON(w, err)
|
utils.MarshalJSON(w, err)
|
||||||
return
|
return
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue