From 6e71c17f1d68b8e7f2070ea108eb681ad1e070c2 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Mon, 24 Aug 2020 07:52:22 +0200 Subject: [PATCH] pass origin into GetUserinfoFromToken --- example/internal/mock/storage.go | 2 +- pkg/op/mock/storage.mock.go | 8 ++++---- pkg/op/op.go | 2 +- pkg/op/storage.go | 2 +- pkg/op/userinfo.go | 6 ++++-- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 7e216de..faa62f0 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -202,7 +202,7 @@ func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ st return nil } -func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _ string) (*oidc.Userinfo, error) { +func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _, _ string) (*oidc.Userinfo, error) { return s.GetUserinfoFromScopes(ctx, "", []string{}) } func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) { diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index ac8ba27..9432616 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -184,18 +184,18 @@ func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interf } // GetUserinfoFromToken mocks base method -func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1 string) (*oidc.Userinfo, error) { +func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1, arg2 string) (*oidc.Userinfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1) + ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1, arg2) ret0, _ := ret[0].(*oidc.Userinfo) ret1, _ := ret[1].(error) return ret0, ret1 } // GetUserinfoFromToken indicates an expected call of GetUserinfoFromToken -func (mr *MockStorageMockRecorder) GetUserinfoFromToken(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) GetUserinfoFromToken(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromToken), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromToken), arg0, arg1, arg2) } // Health mocks base method diff --git a/pkg/op/op.go b/pkg/op/op.go index a6561dd..1812a10 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -44,7 +44,7 @@ func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router { h = DefaultInterceptor } router := mux.NewRouter() - router.Use(handlers.CORS(handlers.AllowedOriginValidator(allowAllOrigins))) + router.Use(handlers.CORS(handlers.AllowedOriginValidator(allowAllOrigins), handlers.AllowedHeaders([]string{"content-type"}))) router.HandleFunc(healthzEndpoint, Healthz) router.HandleFunc(readinessEndpoint, o.HandleReady) router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery) diff --git a/pkg/op/storage.go b/pkg/op/storage.go index e3ef5ff..17023c1 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -29,7 +29,7 @@ type OPStorage interface { GetClientByClientID(context.Context, string) (Client, error) AuthorizeClientIDSecret(context.Context, string, string) error GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error) - GetUserinfoFromToken(context.Context, string) (*oidc.Userinfo, error) + GetUserinfoFromToken(context.Context, string, string) (*oidc.Userinfo, error) } type Storage interface { diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 69746c7..8f55b15 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -5,9 +5,10 @@ import ( "net/http" "strings" + "github.com/gorilla/schema" + "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" - "github.com/gorilla/schema" ) type UserinfoProvider interface { @@ -27,8 +28,9 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP http.Error(w, "access token missing", http.StatusUnauthorized) return } - info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID) + info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID, r.Header.Get("origin")) if err != nil { + w.WriteHeader(http.StatusForbidden) utils.MarshalJSON(w, err) return }