From 4a10ffaaa2fa5f97adaef272138a379f45b4b93b Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Mon, 13 Jul 2020 07:30:00 +0200 Subject: [PATCH] en/decoding abstraction --- pkg/op/authrequest.go | 52 +++++-- pkg/op/authrequest_test.go | 247 ++++++++++++++++++++++++++---- pkg/op/config_test.go | 2 +- pkg/op/default_op.go | 5 +- pkg/op/error.go | 4 +- pkg/op/mock/authorizer.mock.go | 10 +- pkg/op/session.go | 6 +- pkg/op/tokenrequest.go | 6 +- pkg/op/userinfo.go | 5 +- pkg/rp/mock/generate.go | 3 + pkg/rp/mock/verifier.mock.go | 50 ++++++ pkg/rp/mock/verifier.mock.impl.go | 37 +++++ pkg/utils/http.go | 9 +- 13 files changed, 367 insertions(+), 69 deletions(-) create mode 100644 pkg/rp/mock/generate.go create mode 100644 pkg/rp/mock/verifier.mock.go create mode 100644 pkg/rp/mock/verifier.mock.impl.go diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index c25f60d..19855ba 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/gorilla/mux" - "github.com/gorilla/schema" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/rp" @@ -17,33 +16,37 @@ import ( type Authorizer interface { Storage() Storage - Decoder() *schema.Decoder - Encoder() *schema.Encoder + Decoder() utils.Decoder + Encoder() utils.Encoder Signer() Signer IDTokenVerifier() rp.Verifier Crypto() Crypto Issuer() string } -type ValidationAuthorizer interface { +//AuthorizeValidator is an extension of Authorizer interface +//implementing it's own validation mechanism for the auth request +type AuthorizeValidator interface { Authorizer ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, rp.Verifier) (string, error) } +//ValidationAuthorizer is an extension of Authorizer interface +//implementing it's own validation mechanism for the auth request +// +//Deprecated: ValidationAuthorizer exists for historical compatibility. Use ValidationAuthorizer itself +type ValidationAuthorizer AuthorizeValidator + +//Authorize handles the authorization request, including +//parsing, validating, storing and finally redirecting to the login handler func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { - err := r.ParseForm() + authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) if err != nil { - AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder()) - return - } - authReq := new(oidc.AuthRequest) - err = authorizer.Decoder().Decode(authReq, r.Form) - if err != nil { - AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } validation := ValidateAuthRequest - if validater, ok := authorizer.(ValidationAuthorizer); ok { + if validater, ok := authorizer.(AuthorizeValidator); ok { validation = validater.ValidateAuthRequest } userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenVerifier()) @@ -64,6 +67,19 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { RedirectToLogin(req.GetID(), client, w, r) } +func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) { + err := r.ParseForm() + if err != nil { + return nil, ErrInvalidRequest("cannot parse form") + } + authReq := new(oidc.AuthRequest) + err = decoder.Decode(authReq, r.Form) + if err != nil { + return nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)) + } + return authReq, nil +} + func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier rp.Verifier) (string, error) { if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { return "", err @@ -121,10 +137,16 @@ func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, resp } func ValidateAuthReqResponseType(responseType oidc.ResponseType) error { - if responseType == "" { + switch responseType { + case oidc.ResponseTypeCode, + oidc.ResponseTypeIDToken, + oidc.ResponseTypeIDTokenOnly: + return nil + case "": return ErrInvalidRequest("response_type empty") + default: + return ErrInvalidRequest("response_type invalid") } - return nil } func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier rp.Verifier) (string, error) { diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go index dca72fa..cbc48c7 100644 --- a/pkg/op/authrequest_test.go +++ b/pkg/op/authrequest_test.go @@ -3,66 +3,140 @@ package op_test import ( "net/http" "net/http/httptest" - "strings" + "net/url" + "reflect" "testing" + "github.com/gorilla/schema" "github.com/stretchr/testify/require" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/op" "github.com/caos/oidc/pkg/op/mock" "github.com/caos/oidc/pkg/rp" + rp_mock "github.com/caos/oidc/pkg/rp/mock" + "github.com/caos/oidc/pkg/utils" ) -func TestAuthorize(t *testing.T) { - // testCallback := func(t *testing.T, clienID string) callbackHandler { - // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { - // // require.Equal(t, clientID, client.) - // } - // } - // testErr := func(t *testing.T, expected error) errorHandler { - // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { - // require.Equal(t, expected, err) - // } - // } +// +//func TestAuthorize(t *testing.T) { +// // testCallback := func(t *testing.T, clienID string) callbackHandler { +// // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { +// // // require.Equal(t, clientID, client.) +// // } +// // } +// // testErr := func(t *testing.T, expected error) errorHandler { +// // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { +// // require.Equal(t, expected, err) +// // } +// // } +// type args struct { +// w http.ResponseWriter +// r *http.Request +// authorizer op.Authorizer +// } +// tests := []struct { +// name string +// args args +// }{ +// { +// "parsing fails", +// args{ +// httptest.NewRecorder(), +// &http.Request{Method: "POST", Body: nil}, +// mock.NewAuthorizerExpectValid(t, true), +// // testCallback(t, ""), +// // testErr(t, ErrInvalidRequest("cannot parse form")), +// }, +// }, +// { +// "decoding fails", +// args{ +// httptest.NewRecorder(), +// func() *http.Request { +// r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo")) +// r.Header.Set("Content-Type", "application/x-www-form-urlencoded") +// return r +// }(), +// mock.NewAuthorizerExpectValid(t, true), +// // testCallback(t, ""), +// // testErr(t, ErrInvalidRequest("cannot parse auth request")), +// }, +// }, +// // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}}, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer) +// }) +// } +//} + +func TestParseAuthorizeRequest(t *testing.T) { type args struct { - w http.ResponseWriter - r *http.Request - authorizer op.Authorizer + r *http.Request + decoder utils.Decoder + } + type res struct { + want *oidc.AuthRequest + err bool } tests := []struct { name string args args + res res }{ { - "parsing fails", + "parsing form error", args{ - httptest.NewRecorder(), - &http.Request{Method: "POST", Body: nil}, - mock.NewAuthorizerExpectValid(t, true), - // testCallback(t, ""), - // testErr(t, ErrInvalidRequest("cannot parse form")), + &http.Request{URL: &url.URL{RawQuery: "invalid=%%param"}}, + schema.NewDecoder(), + }, + res{ + nil, + true, }, }, { - "decoding fails", + "decoding error", args{ - httptest.NewRecorder(), - func() *http.Request { - r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo")) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - return r + &http.Request{URL: &url.URL{RawQuery: "unknown=value"}}, + func() utils.Decoder { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(false) + return decoder }(), - mock.NewAuthorizerExpectValid(t, true), - // testCallback(t, ""), - // testErr(t, ErrInvalidRequest("cannot parse auth request")), + }, + res{ + nil, + true, + }, + }, + { + "parsing ok", + args{ + &http.Request{URL: &url.URL{RawQuery: "scope=openid"}}, + func() utils.Decoder { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(false) + return decoder + }(), + }, + res{ + &oidc.AuthRequest{Scopes: oidc.Scopes{"openid"}}, + false, }, }, - // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer) + got, err := op.ParseAuthorizeRequest(tt.args.r, tt.args.decoder) + if (err != nil) != tt.res.err { + t.Errorf("ParseAuthorizeRequest() error = %v, wantErr %v", err, tt.res.err) + } + if !reflect.DeepEqual(got, tt.res.want) { + t.Errorf("ParseAuthorizeRequest() got = %v, want %v", got, tt.res.want) + } }) } } @@ -228,6 +302,115 @@ func TestValidateAuthReqRedirectURI(t *testing.T) { } } +func TestValidateAuthReqResponseType(t *testing.T) { + type args struct { + responseType oidc.ResponseType + } + type res struct { + err bool + } + tests := []struct { + name string + args args + res res + }{ + { + "code no error", + args{"code"}, + res{false}, + }, + { + "id_token token no error", + args{"id_token token"}, + res{false}, + }, + { + "id_token no error", + args{"id_token"}, + res{false}, + }, + { + "no response_type error", + args{}, + res{true}, + }, + { + "invalid response_type error", + args{"invalid"}, + res{true}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := op.ValidateAuthReqResponseType(tt.args.responseType); (err != nil) != tt.res.err { + t.Errorf("ValidateAuthReqResponseType() error = %v, wantErr %v", err, tt.res.err) + } + }) + } +} + +func TestValidateAuthReqIDTokenHint(t *testing.T) { + type args struct { + idTokenHint string + verifier rp.Verifier + } + type res struct { + userID string + err bool + } + tests := []struct { + name string + args args + res res + }{ + { + "no id_token_hint, no id and ok", + args{ + "", + nil, + }, + res{ + "", + false, + }, + }, + { + "invalid id_token_hint, no id and error", + args{ + "invalid", + rp_mock.NewMockVerifierExpectInvalid(t), + }, + res{ + "", + true, + }, + }, + { + "no id_token_hint ok", + args{ + "valid", + rp_mock.NewMockVerifierExpectValid(t), + }, + res{ + "id", + false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.ValidateAuthReqIDTokenHint(nil, tt.args.idTokenHint, tt.args.verifier) + if (err != nil) != tt.res.err { + t.Errorf("ValidateAuthReqIDTokenHint() error = %v, wantErr %v", err, tt.res.err) + return + } + if got != tt.res.userID { + t.Errorf("ValidateAuthReqIDTokenHint() got = %v, want %v", got, tt.res.userID) + } + }) + } +} + func TestRedirectToLogin(t *testing.T) { type args struct { authReqID string diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go index 56cf2eb..8b4c755 100644 --- a/pkg/op/config_test.go +++ b/pkg/op/config_test.go @@ -56,7 +56,7 @@ func TestValidateIssuer(t *testing.T) { { "localhost with http ok", args{"http://localhost:9999"}, - true, + false, }, } for _, tt := range tests { diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index a16d4d3..7d2256a 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -13,6 +13,7 @@ import ( "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/rp" + "github.com/caos/oidc/pkg/utils" ) const ( @@ -247,11 +248,11 @@ func (p *DefaultOP) VerifySignature(ctx context.Context, jws *jose.JSONWebSignat return payload, err } -func (p *DefaultOP) Decoder() *schema.Decoder { +func (p *DefaultOP) Decoder() utils.Decoder { return p.decoder } -func (p *DefaultOP) Encoder() *schema.Encoder { +func (p *DefaultOP) Encoder() utils.Encoder { return p.encoder } diff --git a/pkg/op/error.go b/pkg/op/error.go index f3c5857..b88feca 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -4,8 +4,6 @@ import ( "fmt" "net/http" - "github.com/gorilla/schema" - "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" ) @@ -45,7 +43,7 @@ type ErrAuthRequest interface { GetState() string } -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) { +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder utils.Encoder) { if authReq == nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index dbfc2a6..5272997 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -7,8 +7,8 @@ package mock import ( op "github.com/caos/oidc/pkg/op" rp "github.com/caos/oidc/pkg/rp" + utils "github.com/caos/oidc/pkg/utils" gomock "github.com/golang/mock/gomock" - schema "github.com/gorilla/schema" reflect "reflect" ) @@ -50,10 +50,10 @@ func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call { } // Decoder mocks base method -func (m *MockAuthorizer) Decoder() *schema.Decoder { +func (m *MockAuthorizer) Decoder() utils.Decoder { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Decoder") - ret0, _ := ret[0].(*schema.Decoder) + ret0, _ := ret[0].(utils.Decoder) return ret0 } @@ -64,10 +64,10 @@ func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call { } // Encoder mocks base method -func (m *MockAuthorizer) Encoder() *schema.Encoder { +func (m *MockAuthorizer) Encoder() utils.Encoder { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Encoder") - ret0, _ := ret[0].(*schema.Encoder) + ret0, _ := ret[0].(utils.Encoder) return ret0 } diff --git a/pkg/op/session.go b/pkg/op/session.go index c274bf0..cf9f97f 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -6,11 +6,11 @@ import ( "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/rp" - "github.com/gorilla/schema" + "github.com/caos/oidc/pkg/utils" ) type SessionEnder interface { - Decoder() *schema.Decoder + Decoder() utils.Decoder Storage() Storage IDTokenVerifier() rp.Verifier DefaultLogoutRedirectURI() string @@ -39,7 +39,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { http.Redirect(w, r, session.RedirectURI, http.StatusFound) } -func ParseEndSessionRequest(r *http.Request, decoder *schema.Decoder) (*oidc.EndSessionRequest, error) { +func ParseEndSessionRequest(r *http.Request, decoder utils.Decoder) (*oidc.EndSessionRequest, error) { err := r.ParseForm() if err != nil { return nil, ErrInvalidRequest("error parsing form") diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index 5ef4b22..f3f3979 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -5,8 +5,6 @@ import ( "errors" "net/http" - "github.com/gorilla/schema" - "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" ) @@ -14,7 +12,7 @@ import ( type Exchanger interface { Issuer() string Storage() Storage - Decoder() *schema.Decoder + Decoder() utils.Decoder Signer() Signer Crypto() Crypto AuthMethodPostSupported() bool @@ -42,7 +40,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { utils.MarshalJSON(w, resp) } -func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) { +func ParseAccessTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.AccessTokenRequest, error) { err := r.ParseForm() if err != nil { return nil, ErrInvalidRequest("error parsing form") diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 69746c7..fa62a6e 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -7,11 +7,10 @@ import ( "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/utils" - "github.com/gorilla/schema" ) type UserinfoProvider interface { - Decoder() *schema.Decoder + Decoder() utils.Decoder Crypto() Crypto Storage() Storage } @@ -35,7 +34,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP utils.MarshalJSON(w, info) } -func getAccessToken(r *http.Request, decoder *schema.Decoder) (string, error) { +func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) { authHeader := r.Header.Get("authorization") if authHeader != "" { parts := strings.Split(authHeader, "Bearer ") diff --git a/pkg/rp/mock/generate.go b/pkg/rp/mock/generate.go new file mode 100644 index 0000000..71bc3be --- /dev/null +++ b/pkg/rp/mock/generate.go @@ -0,0 +1,3 @@ +package mock + +//go:generate mockgen -package mock -destination ./verifier.mock.go github.com/caos/oidc/pkg/rp Verifier diff --git a/pkg/rp/mock/verifier.mock.go b/pkg/rp/mock/verifier.mock.go new file mode 100644 index 0000000..d53f208 --- /dev/null +++ b/pkg/rp/mock/verifier.mock.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/caos/oidc/pkg/rp (interfaces: Verifier) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + oidc "github.com/caos/oidc/pkg/oidc" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockVerifier is a mock of Verifier interface +type MockVerifier struct { + ctrl *gomock.Controller + recorder *MockVerifierMockRecorder +} + +// MockVerifierMockRecorder is the mock recorder for MockVerifier +type MockVerifierMockRecorder struct { + mock *MockVerifier +} + +// NewMockVerifier creates a new mock instance +func NewMockVerifier(ctrl *gomock.Controller) *MockVerifier { + mock := &MockVerifier{ctrl: ctrl} + mock.recorder = &MockVerifierMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockVerifier) EXPECT() *MockVerifierMockRecorder { + return m.recorder +} + +// Verify mocks base method +func (m *MockVerifier) Verify(arg0 context.Context, arg1, arg2 string) (*oidc.IDTokenClaims, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Verify", arg0, arg1, arg2) + ret0, _ := ret[0].(*oidc.IDTokenClaims) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Verify indicates an expected call of Verify +func (mr *MockVerifierMockRecorder) Verify(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockVerifier)(nil).Verify), arg0, arg1, arg2) +} diff --git a/pkg/rp/mock/verifier.mock.impl.go b/pkg/rp/mock/verifier.mock.impl.go new file mode 100644 index 0000000..8a94809 --- /dev/null +++ b/pkg/rp/mock/verifier.mock.impl.go @@ -0,0 +1,37 @@ +package mock + +import ( + "errors" + "testing" + + "github.com/golang/mock/gomock" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" +) + +func NewVerifier(t *testing.T) rp.Verifier { + return NewMockVerifier(gomock.NewController(t)) +} + +func NewMockVerifierExpectInvalid(t *testing.T) rp.Verifier { + m := NewVerifier(t) + ExpectVerifyInvalid(m) + return m +} + +func ExpectVerifyInvalid(v rp.Verifier) { + mock := v.(*MockVerifier) + mock.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("invalid")) +} + +func NewMockVerifierExpectValid(t *testing.T) rp.Verifier { + m := NewVerifier(t) + ExpectVerifyValid(m) + return m +} + +func ExpectVerifyValid(v rp.Verifier) { + mock := v.(*MockVerifier) + mock.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil) +} diff --git a/pkg/utils/http.go b/pkg/utils/http.go index 6ad7083..b3ed631 100644 --- a/pkg/utils/http.go +++ b/pkg/utils/http.go @@ -18,6 +18,13 @@ var ( } ) +type Decoder interface { + Decode(dst interface{}, src map[string][]string) error +} +type Encoder interface { + Encode(src interface{}, dst map[string][]string) error +} + func FormRequest(endpoint string, request interface{}) (*http.Request, error) { form := make(map[string][]string) encoder := schema.NewEncoder() @@ -56,7 +63,7 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e return nil } -func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) { +func URLEncodeResponse(resp interface{}, encoder Encoder) (string, error) { values := make(map[string][]string) err := encoder.Encode(resp, values) if err != nil {