From 65cd4528e4df617dc5e653aa458a7e0fd4d90669 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Sun, 26 Feb 2023 18:23:43 +0100 Subject: [PATCH] don't obtain a Client from storage on each poll First verify if the client is authenticated. Then the state of the device authorization. If all is good, we take the Client from Storage. --- pkg/oidc/device_authorization.go | 1 - pkg/op/client.go | 3 ++- pkg/op/client_test.go | 2 ++ pkg/op/device.go | 33 ++++++++++++++++++++------------ 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/pkg/oidc/device_authorization.go b/pkg/oidc/device_authorization.go index e8862d3..8ff8cee 100644 --- a/pkg/oidc/device_authorization.go +++ b/pkg/oidc/device_authorization.go @@ -26,5 +26,4 @@ type DeviceAuthorizationResponse struct { type DeviceAccessTokenRequest struct { GrantType string `json:"grant_type"` DeviceCode string `json:"device_code"` - ClientID string `json:"client_id"` } diff --git a/pkg/op/client.go b/pkg/op/client.go index 48c6241..105d90b 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -3,7 +3,6 @@ package op import ( "context" "errors" - "fmt" "net/http" "net/url" "time" @@ -156,6 +155,7 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au return data.ClientID, false, nil } +/* // ClientFromRequest wraps ClientIDFromRequest and obtains the Client from storage. // If the client id was not authenticated, the client from storage does not have // oidc.AuthMethodNone set, an error is returned. @@ -179,3 +179,4 @@ func ClientFromRequest(r *http.Request, p ClientProvider) (Client, error) { return client, err } +*/ diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go index f42c647..5f3560f 100644 --- a/pkg/op/client_test.go +++ b/pkg/op/client_test.go @@ -252,6 +252,7 @@ func TestClientIDFromRequest(t *testing.T) { } } +/* func TestClientFromRequest(t *testing.T) { publicClient := func() op.Client { c := mock.NewMockClient(gomock.NewController(t)) @@ -390,3 +391,4 @@ func TestClientFromRequest(t *testing.T) { }) } } +*/ diff --git a/pkg/op/device.go b/pkg/op/device.go index bac80a4..ab276b8 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -179,21 +179,34 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang defer cancel() r = r.WithContext(ctx) - client, err := ClientFromRequest(r, exchanger) + clientID, authenticated, err := ClientIDFromRequest(r, exchanger) if err != nil { return err } + req, err := ParseDeviceAccessTokenRequest(r, exchanger) if err != nil { return err } - state, err := CheckDeviceAuthorizationState(ctx, req, exchanger) + state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) if err != nil { return err } + + client, err := exchanger.Storage().GetClientByClientID(ctx, clientID) + if err != nil { + return err + } + if !authenticated { + if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client? + return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials). + WithDescription(fmt.Sprintf("required client auth method: %s", m)) + } + } + tokenRequest := &deviceAccessTokenRequest{ subject: state.Subject, - audience: []string{req.ClientID}, + audience: []string{clientID}, scopes: state.Scopes, } resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client) @@ -206,24 +219,20 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang } func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) { - req := new(struct { - oidc.DeviceAccessTokenRequest - }) - err := exchanger.Decoder().Decode(req, r.PostForm) - if err != nil { + req := new(oidc.DeviceAccessTokenRequest) + if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil { return nil, err } - - return &req.DeviceAccessTokenRequest, err + return req, nil } -func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) { +func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) { storage, err := assertDeviceStorage(exchanger.Storage()) if err != nil { return nil, err } - state, err := storage.GetDeviceAuthorizatonState(ctx, req.ClientID, req.DeviceCode) + state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode) if errors.Is(err, context.DeadlineExceeded) { return nil, oidc.ErrSlowDown().WithParent(err) }