fix: enforce device authorization grant type (#400)

This commit is contained in:
Tim Möhlmann 2023-05-26 11:52:35 +03:00 committed by GitHub
parent 09bdd1dca2
commit a4dbe2a973
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 5 deletions

View file

@ -193,6 +193,24 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
} }
} }
// DeviceClient creates a device client with Basic authentication.
func DeviceClient(id, secret string) *Client {
return &Client{
id: id,
secret: secret,
redirectURIs: nil,
applicationType: op.ApplicationTypeWeb,
authMethod: oidc.AuthMethodBasic,
loginURL: defaultLoginURL,
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
grantTypes: []oidc.GrantType{oidc.GrantTypeDeviceCode},
accessTokenType: op.AccessTokenTypeBearer,
devMode: false,
idTokenUserinfoClaimsAssertion: false,
clockSkew: 0,
}
}
type hasRedirectGlobs struct { type hasRedirectGlobs struct {
*Client *Client
} }

View file

@ -122,6 +122,13 @@ func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuth
if err != nil { if err != nil {
return nil, err return nil, err
} }
client, err := o.Storage().GetClientByClientID(r.Context(), clientID)
if err != nil {
return nil, err
}
if !ValidateGrantType(client, oidc.GrantTypeDeviceCode) {
return nil, oidc.ErrUnauthorizedClient().WithDescription("client missing grant type " + string(oidc.GrantTypeCode))
}
req := new(oidc.DeviceAuthorizationRequest) req := new(oidc.DeviceAuthorizationRequest)
if err := o.Decoder().Decode(req, r.Form); err != nil { if err := o.Decoder().Decode(req, r.Form); err != nil {

View file

@ -51,7 +51,7 @@ func Test_deviceAuthorizationHandler(t *testing.T) {
req := &oidc.DeviceAuthorizationRequest{ req := &oidc.DeviceAuthorizationRequest{
Scopes: []string{"foo", "bar"}, Scopes: []string{"foo", "bar"},
ClientID: "web", ClientID: "device",
} }
values := make(url.Values) values := make(url.Values)
testProvider.Encoder().Encode(req, values) testProvider.Encoder().Encode(req, values)
@ -88,11 +88,27 @@ func TestParseDeviceCodeRequest(t *testing.T) {
wantErr: true, wantErr: true,
}, },
{ {
name: "success", name: "missing grant type",
req: &oidc.DeviceAuthorizationRequest{ req: &oidc.DeviceAuthorizationRequest{
Scopes: oidc.SpaceDelimitedArray{"foo", "bar"}, Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
ClientID: "web", ClientID: "web",
}, },
wantErr: true,
},
{
name: "client not found",
req: &oidc.DeviceAuthorizationRequest{
Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
ClientID: "foobar",
},
wantErr: true,
},
{
name: "success",
req: &oidc.DeviceAuthorizationRequest{
Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
ClientID: "device",
},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -110,8 +126,7 @@ func TestParseDeviceCodeRequest(t *testing.T) {
got, err := op.ParseDeviceCodeRequest(r, testProvider) got, err := op.ParseDeviceCodeRequest(r, testProvider)
if tt.wantErr { if tt.wantErr {
require.Error(t, err) require.Error(t, err)
} else { return
require.NoError(t, err)
} }
assert.Equal(t, tt.req, got) assert.Equal(t, tt.req, got)
}) })

View file

@ -49,6 +49,7 @@ func init() {
storage.RegisterClients( storage.RegisterClients(
storage.NativeClient("native"), storage.NativeClient("native"),
storage.WebClient("web", "secret", "https://example.com"), storage.WebClient("web", "secret", "https://example.com"),
storage.DeviceClient("device", "secret"),
storage.WebClient("api", "secret"), storage.WebClient("api", "secret"),
) )
@ -336,7 +337,7 @@ func TestRoutes(t *testing.T) {
name: "device authorization", name: "device authorization",
method: http.MethodGet, method: http.MethodGet,
path: testProvider.DeviceAuthorizationEndpoint().Relative(), path: testProvider.DeviceAuthorizationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"}, basicAuth: &basicAuth{"device", "secret"},
values: map[string]string{ values: map[string]string{
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
}, },