From caedc72d451d2e11885b93ebe147609fb68c3b82 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Fri, 29 Nov 2019 15:03:13 +0100 Subject: [PATCH] tests --- pkg/op/config.go | 28 ++++++++++ pkg/op/config_test.go | 67 ++++++++++++++++++++++ pkg/op/discovery_test.go | 60 ++++++++++++++++++++ pkg/op/endpoint_test.go | 95 ++++++++++++++++++++++++++++++++ pkg/op/mock/storage.mock.go | 4 +- pkg/op/mock/storage.mock.impl.go | 34 ++++++++++++ pkg/op/op.go | 25 --------- pkg/op/signer.go | 2 +- pkg/op/signer_test.go | 95 ++++++++++++++++++++++++++++++++ pkg/op/storage.go | 2 +- pkg/op/tokenrequest.go | 23 -------- 11 files changed, 383 insertions(+), 52 deletions(-) create mode 100644 pkg/op/config_test.go create mode 100644 pkg/op/discovery_test.go create mode 100644 pkg/op/endpoint_test.go create mode 100644 pkg/op/signer_test.go diff --git a/pkg/op/config.go b/pkg/op/config.go index d4ef8b4..151265e 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -1,5 +1,11 @@ package op +import ( + "errors" + "net/url" + "strings" +) + type Configuration interface { Issuer() string AuthorizationEndpoint() Endpoint @@ -7,3 +13,25 @@ type Configuration interface { UserinfoEndpoint() Endpoint Port() string } + +func ValidateIssuer(issuer string) error { + if issuer == "" { + return errors.New("missing issuer") + } + u, err := url.Parse(issuer) + if err != nil { + return errors.New("invalid url for issuer") + } + if u.Host == "" { + return errors.New("host for issuer missing") + } + if u.Scheme != "https" { + if !(u.Scheme == "http" && (u.Host == "localhost" || u.Host == "127.0.0.1" || u.Host == "::1" || strings.HasPrefix(u.Host, "localhost:"))) { //TODO: ? + return errors.New("scheme for issuer must be `https`") + } + } + if u.Fragment != "" || len(u.Query()) > 0 { + return errors.New("no fragments or query allowed for issuer") + } + return nil +} diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go new file mode 100644 index 0000000..b5f508b --- /dev/null +++ b/pkg/op/config_test.go @@ -0,0 +1,67 @@ +package op + +import "testing" + +func TestValidateIssuer(t *testing.T) { + type args struct { + issuer string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "missing issuer fails", + args{""}, + true, + }, + { + "invalid url for issuer fails", + args{":issuer"}, + true, + }, + { + "invalid url for issuer fails", + args{":issuer"}, + true, + }, + { + "host for issuer missing fails", + args{"https:///issuer"}, + true, + }, + { + "host for not https fails", + args{"http://issuer.com"}, + true, + }, + { + "host with fragment fails", + args{"https://issuer.com/#issuer"}, + true, + }, + { + "host with query fails", + args{"https://issuer.com?issuer=me"}, + true, + }, + { + "host with https ok", + args{"https://issuer.com"}, + false, + }, + { + "localhost with http ok", + args{"http://localhost:9999"}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { + t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go new file mode 100644 index 0000000..0df04a5 --- /dev/null +++ b/pkg/op/discovery_test.go @@ -0,0 +1,60 @@ +package op_test + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/op" +) + +func TestDiscover(t *testing.T) { + type args struct { + w http.ResponseWriter + config *oidc.DiscoveryConfiguration + } + tests := []struct { + name string + args args + }{ + { + "OK", + args{ + httptest.NewRecorder(), + &oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op.Discover(tt.args.w, tt.args.config) + rec := tt.args.w.(*httptest.ResponseRecorder) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String()) + }) + } +} + +func TestCreateDiscoveryConfig(t *testing.T) { + type args struct { + c op.Configuration + } + tests := []struct { + name string + args args + want *oidc.DiscoveryConfiguration + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.CreateDiscoveryConfig(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/op/endpoint_test.go b/pkg/op/endpoint_test.go new file mode 100644 index 0000000..227bf9d --- /dev/null +++ b/pkg/op/endpoint_test.go @@ -0,0 +1,95 @@ +package op_test + +import ( + "testing" + + "github.com/caos/oidc/pkg/op" +) + +func TestEndpoint_Relative(t *testing.T) { + tests := []struct { + name string + e op.Endpoint + want string + }{ + { + "without starting /", + op.Endpoint("test"), + "/test", + }, + { + "with starting /", + op.Endpoint("/test"), + "/test", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.e.Relative(); got != tt.want { + t.Errorf("Endpoint.Relative() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEndpoint_Absolute(t *testing.T) { + type args struct { + host string + } + tests := []struct { + name string + e op.Endpoint + args args + want string + }{ + { + "no /", + op.Endpoint("test"), + args{"https://host"}, + "https://host/test", + }, + { + "endpoint without /", + op.Endpoint("test"), + args{"https://host/"}, + "https://host/test", + }, + { + "host without /", + op.Endpoint("/test"), + args{"https://host"}, + "https://host/test", + }, + { + "both /", + op.Endpoint("/test"), + args{"https://host/"}, + "https://host/test", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.e.Absolute(tt.args.host); got != tt.want { + t.Errorf("Endpoint.Absolute() = %v, want %v", got, tt.want) + } + }) + } +} + +//TODO: impl test +func TestEndpoint_Validate(t *testing.T) { + tests := []struct { + name string + e op.Endpoint + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.e.Validate(); (err != nil) != tt.wantErr { + t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 7ec490c..3c7bc1d 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -140,10 +140,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock } // GetSigningKey mocks base method -func (m *MockStorage) GetSigningKey() (go_jose_v2.SigningKey, error) { +func (m *MockStorage) GetSigningKey() (*go_jose_v2.SigningKey, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetSigningKey") - ret0, _ := ret[0].(go_jose_v2.SigningKey) + ret0, _ := ret[0].(*go_jose_v2.SigningKey) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 6bac735..fd9cd76 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -4,6 +4,8 @@ import ( "errors" "testing" + "gopkg.in/square/go-jose.v2" + "github.com/golang/mock/gomock" "github.com/caos/oidc/pkg/op" @@ -33,6 +35,23 @@ func NewMockStorageAny(t *testing.T) op.Storage { return m } +func NewMockStorageSigningKeyError(t *testing.T) op.Storage { + m := NewStorage(t) + ExpectSigningKeyError(m) + return m +} + +func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage { + m := NewStorage(t) + ExpectSigningKeyInvalid(m) + return m +} +func NewMockStorageSigningKey(t *testing.T) op.Storage { + m := NewStorage(t) + ExpectSigningKey(m) + return m +} + func ExpectInvalidClientID(s op.Storage) { mockS := s.(*MockStorage) mockS.EXPECT().GetClientByClientID(gomock.Any()).Return(nil, errors.New("client not found")) @@ -55,6 +74,21 @@ func ExpectValidClientID(s op.Storage) { }) } +func ExpectSigningKeyError(s op.Storage) { + mockS := s.(*MockStorage) + mockS.EXPECT().GetSigningKey().Return(nil, errors.New("error")) +} + +func ExpectSigningKeyInvalid(s op.Storage) { + mockS := s.(*MockStorage) + mockS.EXPECT().GetSigningKey().Return(&jose.SigningKey{}, nil) +} + +func ExpectSigningKey(s op.Storage) { + mockS := s.(*MockStorage) + mockS.EXPECT().GetSigningKey().Return(&jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, nil) +} + type ConfClient struct { appType op.ApplicationType } diff --git a/pkg/op/op.go b/pkg/op/op.go index 8d4c36f..e3f5f70 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -2,10 +2,7 @@ package op import ( "context" - "errors" "net/http" - "net/url" - "strings" "github.com/gorilla/mux" @@ -25,28 +22,6 @@ type OpenIDProvider interface { HttpHandler() *http.Server } -func ValidateIssuer(issuer string) error { - if issuer == "" { - return errors.New("missing issuer") - } - u, err := url.Parse(issuer) - if err != nil { - return errors.New("invalid url for issuer") - } - if u.Host == "" { - return errors.New("host for issuer missing") - } - if u.Scheme != "https" { - if !(u.Scheme == "http" && (u.Host == "localhost" || u.Host == "127.0.0.1" || u.Host == "::1" || strings.HasPrefix(u.Host, "localhost:"))) { //TODO: ? - return errors.New("scheme for issuer must be `https`") - } - } - if u.Fragment != "" || len(u.Query()) > 0 { - return errors.New("no fragments or query allowed for issuer") - } - return nil -} - func CreateRouter(o OpenIDProvider) *mux.Router { router := mux.NewRouter() router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery) diff --git a/pkg/op/signer.go b/pkg/op/signer.go index f94412b..fd36652 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -32,7 +32,7 @@ func (s *idTokenSigner) initialize() error { if err != nil { return err } - s.signer, err = jose.NewSigner(key, &jose.SignerOptions{}) + s.signer, err = jose.NewSigner(*key, &jose.SignerOptions{}) if err != nil { return err } diff --git a/pkg/op/signer_test.go b/pkg/op/signer_test.go new file mode 100644 index 0000000..21aab0d --- /dev/null +++ b/pkg/op/signer_test.go @@ -0,0 +1,95 @@ +package op + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" +) + +// func TestNewDefaultSigner(t *testing.T) { +// type args struct { +// storage Storage +// } +// tests := []struct { +// name string +// args args +// want Signer +// wantErr bool +// }{ +// { +// "err initialize storage fails", +// args{mock.NewMockStorageSigningKeyError(t)}, +// nil, +// true, +// }, +// { +// "err initialize storage fails", +// args{mock.NewMockStorageSigningKeyInvalid(t)}, +// nil, +// true, +// }, +// { +// "initialize ok", +// args{mock.NewMockStorageSigningKey(t)}, +// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)}, +// false, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// got, err := op.NewDefaultSigner(tt.args.storage) +// if (err != nil) != tt.wantErr { +// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr) +// return +// } +// if !reflect.DeepEqual(got, tt.want) { +// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want) +// } +// }) +// } +// } + +func Test_idTokenSigner_Sign(t *testing.T) { + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{}) + require.NoError(t, err) + + type fields struct { + signer jose.Signer + storage Storage + } + type args struct { + payload []byte + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + { + "ok", + fields{signer, nil}, + args{[]byte("test")}, + "eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww", + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &idTokenSigner{ + signer: tt.fields.signer, + storage: tt.fields.storage, + } + got, err := s.Sign(tt.args.payload) + if (err != nil) != tt.wantErr { + t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 105306e..7db58d7 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -14,7 +14,7 @@ type Storage interface { AuthorizeClientIDSecret(string, string) (Client, error) AuthorizeClientIDCodeVerifier(string, string) (Client, error) DeleteAuthRequestAndCode(string, string) error - GetSigningKey() (jose.SigningKey, error) + GetSigningKey() (*jose.SigningKey, error) } type AuthRequest interface { diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index c4e1c3f..0f62a0f 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -14,17 +14,6 @@ import ( "github.com/caos/oidc/pkg/oidc" ) -// func ParseTokenRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) { -// reqType := r.FormValue("grant_type") -// if reqType == "" { -// return nil, errors.New("grant_type missing") //TODO: impl -// } -// if reqType == string(oidc.GrantTypeCode) { -// return ParseAccessTokenRequest(w, r) -// } -// return ParseTokenExchangeRequest(w, r) -// } - type Exchanger interface { Storage() Storage Decoder() *schema.Decoder @@ -111,18 +100,6 @@ func CreateIDToken(issuer string, authReq AuthRequest, sub string, exp, authTime return signer.SignIDToken(claims) } -type Signe struct { - signer jose.Signer -} - -func (s *Signe) Sign(payload []byte) (string, error) { - result, err := s.signer.Sign(payload) - if err != nil { - return "", err - } - return result.CompactSerialize() -} - func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, storage Storage) (Client, error) { if tokenReq.ClientID == "" { clientID, clientSecret, ok := r.BasicAuth()