don't obtain a Client from storage on each poll
First verify if the client is authenticated. Then the state of the device authorization. If all is good, we take the Client from Storage.
This commit is contained in:
parent
f26e155208
commit
65cd4528e4
4 changed files with 25 additions and 14 deletions
|
@ -3,7 +3,6 @@ package op
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
@ -156,6 +155,7 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
|
|||
return data.ClientID, false, nil
|
||||
}
|
||||
|
||||
/*
|
||||
// ClientFromRequest wraps ClientIDFromRequest and obtains the Client from storage.
|
||||
// If the client id was not authenticated, the client from storage does not have
|
||||
// oidc.AuthMethodNone set, an error is returned.
|
||||
|
@ -179,3 +179,4 @@ func ClientFromRequest(r *http.Request, p ClientProvider) (Client, error) {
|
|||
|
||||
return client, err
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -252,6 +252,7 @@ func TestClientIDFromRequest(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func TestClientFromRequest(t *testing.T) {
|
||||
publicClient := func() op.Client {
|
||||
c := mock.NewMockClient(gomock.NewController(t))
|
||||
|
@ -390,3 +391,4 @@ func TestClientFromRequest(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -179,21 +179,34 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
|||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
client, err := ClientFromRequest(r, exchanger)
|
||||
clientID, authenticated, err := ClientIDFromRequest(r, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
state, err := CheckDeviceAuthorizationState(ctx, req, exchanger)
|
||||
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !authenticated {
|
||||
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
|
||||
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
||||
WithDescription(fmt.Sprintf("required client auth method: %s", m))
|
||||
}
|
||||
}
|
||||
|
||||
tokenRequest := &deviceAccessTokenRequest{
|
||||
subject: state.Subject,
|
||||
audience: []string{req.ClientID},
|
||||
audience: []string{clientID},
|
||||
scopes: state.Scopes,
|
||||
}
|
||||
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
|
||||
|
@ -206,24 +219,20 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
|||
}
|
||||
|
||||
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
|
||||
req := new(struct {
|
||||
oidc.DeviceAccessTokenRequest
|
||||
})
|
||||
err := exchanger.Decoder().Decode(req, r.PostForm)
|
||||
if err != nil {
|
||||
req := new(oidc.DeviceAccessTokenRequest)
|
||||
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &req.DeviceAccessTokenRequest, err
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
||||
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
||||
storage, err := assertDeviceStorage(exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state, err := storage.GetDeviceAuthorizatonState(ctx, req.ClientID, req.DeviceCode)
|
||||
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, oidc.ErrSlowDown().WithParent(err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue