From 5cd0653c33985dc02c1429e4b036e11a7626d30c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Tue, 28 Feb 2023 19:58:08 +0200 Subject: [PATCH] add unit tests --- example/server/storage/storage.go | 4 + pkg/oidc/device_authorization.go | 4 +- pkg/op/device.go | 6 +- pkg/op/device_test.go | 287 ++++++++++++++++++++++++++++-- pkg/op/op.go | 2 +- 5 files changed, 278 insertions(+), 25 deletions(-) diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go index 9afa29b..b49ce1b 100644 --- a/example/server/storage/storage.go +++ b/example/server/storage/storage.go @@ -784,6 +784,10 @@ func (s *Storage) StoreDeviceAuthorization(ctx context.Context, clientID, device } func (s *Storage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + s.lock.Lock() defer s.lock.Unlock() diff --git a/pkg/oidc/device_authorization.go b/pkg/oidc/device_authorization.go index 969d528..68b8efa 100644 --- a/pkg/oidc/device_authorization.go +++ b/pkg/oidc/device_authorization.go @@ -24,6 +24,6 @@ type DeviceAuthorizationResponse struct { // https://www.rfc-editor.org/rfc/rfc8628#section-3.4, // Device Access Token Request. type DeviceAccessTokenRequest struct { - GrantType GrantType `json:"grant_type"` - DeviceCode string `json:"device_code"` + GrantType GrantType `json:"grant_type" schema:"grant_type"` + DeviceCode string `json:"device_code" schema:"device_code"` } diff --git a/pkg/op/device.go b/pkg/op/device.go index e3706ff..04c06f2 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -46,7 +46,7 @@ var ( } ) -func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { +func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { if err := DeviceAuthorization(w, r, o); err != nil { RequestError(w, r, err) @@ -198,7 +198,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang if err != nil { return err } - if !clientAuthenticated && !IsConfidentialType(client) { + if clientAuthenticated != IsConfidentialType(client) { return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials). WithDescription("confidential client requires authentication") } @@ -236,7 +236,7 @@ func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode str return nil, oidc.ErrSlowDown().WithParent(err) } if err != nil { - return nil, err + return nil, oidc.ErrAccessDenied().WithParent(err) } if state.Denied { return state, oidc.ErrAccessDenied() diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index 52c3d14..ca68759 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -1,21 +1,135 @@ -package op +package op_test import ( + "context" "crypto/rand" + "crypto/sha256" "encoding/base64" "io" mr "math/rand" + "net/http" + "net/http/httptest" + "net/url" + "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v2/example/server/storage" + "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v2/pkg/op" + "golang.org/x/text/language" ) -type errReader struct { +var testProvider op.OpenIDProvider + +const ( + testIssuer = "https://localhost:9998/" + pathLoggedOut = "/logged-out" +) + +func init() { + config := &op.Config{ + CryptoKey: sha256.Sum256([]byte("test")), + DefaultLogoutRedirectURI: pathLoggedOut, + CodeMethodS256: true, + AuthMethodPost: true, + AuthMethodPrivateKeyJWT: true, + GrantTypeRefreshToken: true, + RequestObjectSupported: true, + SupportedUILocales: []language.Tag{language.English}, + DeviceAuthorization: op.DeviceAuthorizationConfig{ + Lifetime: 5 * time.Minute, + PollInterval: 5 * time.Second, + UserFormURL: testIssuer + "device", + UserCode: op.UserCodeBase20, + }, + } + + storage.RegisterClients( + storage.NativeClient("native"), + storage.WebClient("web", "secret"), + storage.WebClient("api", "secret"), + ) + + var err error + testProvider, err = op.NewOpenIDProvider(context.TODO(), testIssuer, config, + storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(), + ) + if err != nil { + panic(err) + } } -func (errReader) Read([]byte) (int, error) { - return 0, io.ErrUnexpectedEOF +func Test_deviceAuthorizationHandler(t *testing.T) { + req := &oidc.DeviceAuthorizationRequest{ + Scopes: []string{"foo", "bar"}, + ClientID: "web", + } + values := make(url.Values) + testProvider.Encoder().Encode(req, values) + body := strings.NewReader(values.Encode()) + + r := httptest.NewRequest(http.MethodPost, "/", body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + w := httptest.NewRecorder() + + runWithRandReader(mr.New(mr.NewSource(1)), func() { + op.DeviceAuthorizationHandler(testProvider)(w, r) + }) + + result := w.Result() + + assert.Less(t, result.StatusCode, 300) + + got, _ := io.ReadAll(result.Body) + assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got)) +} + +func TestParseDeviceCodeRequest(t *testing.T) { + tests := []struct { + name string + req *oidc.DeviceAuthorizationRequest + wantErr bool + }{ + { + name: "empty request", + wantErr: true, + }, + /* decoding a SpaceDelimitedArray is broken + https://github.com/zitadel/oidc/issues/295 + { + name: "success", + req: &oidc.DeviceAuthorizationRequest{ + Scopes: oidc.SpaceDelimitedArray{"foo", "bar"}, + ClientID: "web", + }, + }, + */ + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var body io.Reader + if tt.req != nil { + values := make(url.Values) + testProvider.Encoder().Encode(tt.req, values) + body = strings.NewReader(values.Encode()) + } + + r := httptest.NewRequest(http.MethodPost, "/", body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + got, err := op.ParseDeviceCodeRequest(r, testProvider) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.req, got) + }) + } } func runWithRandReader(r io.Reader, f func()) { @@ -31,14 +145,14 @@ func runWithRandReader(r io.Reader, f func()) { func TestNewDeviceCode(t *testing.T) { t.Run("reader error", func(t *testing.T) { runWithRandReader(errReader{}, func() { - _, err := NewDeviceCode(16) + _, err := op.NewDeviceCode(16) require.Error(t, err) }) }) t.Run("different lengths, rand reader", func(t *testing.T) { for i := 1; i <= 32; i++ { - got, err := NewDeviceCode(i) + got, err := op.NewDeviceCode(i) require.NoError(t, err) assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i)) } @@ -62,7 +176,7 @@ func TestNewUserCode(t *testing.T) { { name: "reader error", args: args{ - charset: []rune(CharSetBase20), + charset: []rune(op.CharSetBase20), charAmount: 8, dashInterval: 4, }, @@ -72,7 +186,7 @@ func TestNewUserCode(t *testing.T) { { name: "base20", args: args{ - charset: []rune(CharSetBase20), + charset: []rune(op.CharSetBase20), charAmount: 8, dashInterval: 4, }, @@ -82,7 +196,7 @@ func TestNewUserCode(t *testing.T) { { name: "digits", args: args{ - charset: []rune(CharSetDigits), + charset: []rune(op.CharSetDigits), charAmount: 9, dashInterval: 3, }, @@ -92,7 +206,7 @@ func TestNewUserCode(t *testing.T) { { name: "no dashes", args: args{ - charset: []rune(CharSetDigits), + charset: []rune(op.CharSetDigits), charAmount: 9, }, reader: mr.New(mr.NewSource(1)), @@ -102,9 +216,9 @@ func TestNewUserCode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { runWithRandReader(tt.reader, func() { - got, err := NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval) + got, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval) if tt.wantErr { - require.ErrorIs(t, err, io.ErrUnexpectedEOF) + require.ErrorIs(t, err, io.ErrNoProgress) } else { require.NoError(t, err) } @@ -117,12 +231,12 @@ func TestNewUserCode(t *testing.T) { t.Run("crypto/rand", func(t *testing.T) { const testN = 100000 - for _, c := range []UserCodeConfig{UserCodeBase20, UserCodeDigits} { + for _, c := range []op.UserCodeConfig{op.UserCodeBase20, op.UserCodeDigits} { t.Run(c.CharSet, func(t *testing.T) { results := make(map[string]int) for i := 0; i < testN; i++ { - code, err := NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval) + code, err := op.NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval) require.NoError(t, err) results[code]++ } @@ -156,7 +270,7 @@ func BenchmarkNewUserCode(b *testing.B) { { name: "math rand, base20", args: args{ - charset: []rune(CharSetBase20), + charset: []rune(op.CharSetBase20), charAmount: 8, dashInterval: 4, }, @@ -165,7 +279,7 @@ func BenchmarkNewUserCode(b *testing.B) { { name: "math rand, digits", args: args{ - charset: []rune(CharSetDigits), + charset: []rune(op.CharSetDigits), charAmount: 9, dashInterval: 3, }, @@ -174,7 +288,7 @@ func BenchmarkNewUserCode(b *testing.B) { { name: "crypto rand, base20", args: args{ - charset: []rune(CharSetBase20), + charset: []rune(op.CharSetBase20), charAmount: 8, dashInterval: 4, }, @@ -183,7 +297,7 @@ func BenchmarkNewUserCode(b *testing.B) { { name: "crypto rand, digits", args: args{ - charset: []rune(CharSetDigits), + charset: []rune(op.CharSetDigits), charAmount: 9, dashInterval: 3, }, @@ -194,7 +308,7 @@ func BenchmarkNewUserCode(b *testing.B) { runWithRandReader(tt.reader, func() { b.Run(tt.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - _, err := NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval) + _, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval) require.NoError(b, err) } }) @@ -202,3 +316,138 @@ func BenchmarkNewUserCode(b *testing.B) { }) } } + +func TestDeviceAccessToken(t *testing.T) { + storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"}) + storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim") + + values := make(url.Values) + values.Set("client_id", "native") + values.Set("grant_type", string(oidc.GrantTypeDeviceCode)) + values.Set("device_code", "qwerty") + + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(values.Encode())) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + + op.DeviceAccessToken(w, r, testProvider) + + result := w.Result() + got, _ := io.ReadAll(result.Body) + t.Log(string(got)) + assert.Less(t, result.StatusCode, 300) + assert.NotEmpty(t, string(got)) +} + +func TestCheckDeviceAuthorizationState(t *testing.T) { + now := time.Now() + + storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"}) + storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"}) + storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"}) + storage.StoreDeviceAuthorization(context.Background(), "native", "expired", "expired", now.Add(-time.Minute), []string{"foo"}) + + storage.DenyDeviceAuthorization(context.Background(), "denied") + storage.CompleteDeviceAuthorization(context.Background(), "completed", "tim") + + exceededCtx, cancel := context.WithTimeout(context.Background(), -time.Second) + defer cancel() + + type args struct { + ctx context.Context + clientID string + deviceCode string + } + tests := []struct { + name string + args args + want *op.DeviceAuthorizationState + wantErr error + }{ + { + name: "pending", + args: args{ + ctx: context.Background(), + clientID: "native", + deviceCode: "pending", + }, + want: &op.DeviceAuthorizationState{ + ClientID: "native", + Scopes: []string{"foo"}, + Expires: now.Add(time.Minute), + }, + wantErr: oidc.ErrAuthorizationPending(), + }, + { + name: "slow down", + args: args{ + ctx: exceededCtx, + clientID: "native", + deviceCode: "ok", + }, + wantErr: oidc.ErrSlowDown(), + }, + { + name: "wrong client", + args: args{ + ctx: context.Background(), + clientID: "foo", + deviceCode: "ok", + }, + wantErr: oidc.ErrAccessDenied(), + }, + { + name: "denied", + args: args{ + ctx: context.Background(), + clientID: "native", + deviceCode: "denied", + }, + want: &op.DeviceAuthorizationState{ + ClientID: "native", + Scopes: []string{"foo"}, + Expires: now.Add(time.Minute), + Denied: true, + }, + wantErr: oidc.ErrAccessDenied(), + }, + { + name: "completed", + args: args{ + ctx: context.Background(), + clientID: "native", + deviceCode: "completed", + }, + want: &op.DeviceAuthorizationState{ + ClientID: "native", + Scopes: []string{"foo"}, + Expires: now.Add(time.Minute), + Subject: "tim", + Done: true, + }, + }, + { + name: "expired", + args: args{ + ctx: context.Background(), + clientID: "native", + deviceCode: "expired", + }, + want: &op.DeviceAuthorizationState{ + ClientID: "native", + Scopes: []string{"foo"}, + Expires: now.Add(-time.Minute), + }, + wantErr: oidc.ErrExpiredDeviceCode(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.CheckDeviceAuthorizationState(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, testProvider) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/op.go b/pkg/op/op.go index 91c89cc..2859722 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -97,7 +97,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o)) router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) - router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o)) + router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o)) return router }