diff --git a/example/server/storage/client.go b/example/server/storage/client.go index b8b9960..b28d9d4 100644 --- a/example/server/storage/client.go +++ b/example/server/storage/client.go @@ -193,6 +193,24 @@ func WebClient(id, secret string, redirectURIs ...string) *Client { } } +// DeviceClient creates a device client with Basic authentication. +func DeviceClient(id, secret string) *Client { + return &Client{ + id: id, + secret: secret, + redirectURIs: nil, + applicationType: op.ApplicationTypeWeb, + authMethod: oidc.AuthMethodBasic, + loginURL: defaultLoginURL, + responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, + grantTypes: []oidc.GrantType{oidc.GrantTypeDeviceCode}, + accessTokenType: op.AccessTokenTypeBearer, + devMode: false, + idTokenUserinfoClaimsAssertion: false, + clockSkew: 0, + } +} + type hasRedirectGlobs struct { *Client } diff --git a/pkg/op/device.go b/pkg/op/device.go index 397cede..f584c31 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -122,6 +122,13 @@ func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuth if err != nil { return nil, err } + client, err := o.Storage().GetClientByClientID(r.Context(), clientID) + if err != nil { + return nil, err + } + if !ValidateGrantType(client, oidc.GrantTypeDeviceCode) { + return nil, oidc.ErrUnauthorizedClient().WithDescription("client missing grant type " + string(oidc.GrantTypeCode)) + } req := new(oidc.DeviceAuthorizationRequest) if err := o.Decoder().Decode(req, r.Form); err != nil { diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index cf94c3f..4b3c98c 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -51,7 +51,7 @@ func Test_deviceAuthorizationHandler(t *testing.T) { req := &oidc.DeviceAuthorizationRequest{ Scopes: []string{"foo", "bar"}, - ClientID: "web", + ClientID: "device", } values := make(url.Values) testProvider.Encoder().Encode(req, values) @@ -88,11 +88,27 @@ func TestParseDeviceCodeRequest(t *testing.T) { wantErr: true, }, { - name: "success", + name: "missing grant type", req: &oidc.DeviceAuthorizationRequest{ Scopes: oidc.SpaceDelimitedArray{"foo", "bar"}, ClientID: "web", }, + wantErr: true, + }, + { + name: "client not found", + req: &oidc.DeviceAuthorizationRequest{ + Scopes: oidc.SpaceDelimitedArray{"foo", "bar"}, + ClientID: "foobar", + }, + wantErr: true, + }, + { + name: "success", + req: &oidc.DeviceAuthorizationRequest{ + Scopes: oidc.SpaceDelimitedArray{"foo", "bar"}, + ClientID: "device", + }, }, } for _, tt := range tests { @@ -110,8 +126,7 @@ func TestParseDeviceCodeRequest(t *testing.T) { got, err := op.ParseDeviceCodeRequest(r, testProvider) if tt.wantErr { require.Error(t, err) - } else { - require.NoError(t, err) + return } assert.Equal(t, tt.req, got) }) diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index 3e6377f..b637e03 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -49,6 +49,7 @@ func init() { storage.RegisterClients( storage.NativeClient("native"), storage.WebClient("web", "secret", "https://example.com"), + storage.DeviceClient("device", "secret"), storage.WebClient("api", "secret"), ) @@ -336,7 +337,7 @@ func TestRoutes(t *testing.T) { name: "device authorization", method: http.MethodGet, path: testProvider.DeviceAuthorizationEndpoint().Relative(), - basicAuth: &basicAuth{"web", "secret"}, + basicAuth: &basicAuth{"device", "secret"}, values: map[string]string{ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), },