don't obtain a Client from storage on each poll
First verify if the client is authenticated. Then the state of the device authorization. If all is good, we take the Client from Storage.
This commit is contained in:
parent
f26e155208
commit
65cd4528e4
4 changed files with 25 additions and 14 deletions
|
@ -179,21 +179,34 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
|||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
client, err := ClientFromRequest(r, exchanger)
|
||||
clientID, authenticated, err := ClientIDFromRequest(r, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
state, err := CheckDeviceAuthorizationState(ctx, req, exchanger)
|
||||
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !authenticated {
|
||||
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
|
||||
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
||||
WithDescription(fmt.Sprintf("required client auth method: %s", m))
|
||||
}
|
||||
}
|
||||
|
||||
tokenRequest := &deviceAccessTokenRequest{
|
||||
subject: state.Subject,
|
||||
audience: []string{req.ClientID},
|
||||
audience: []string{clientID},
|
||||
scopes: state.Scopes,
|
||||
}
|
||||
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
|
||||
|
@ -206,24 +219,20 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
|||
}
|
||||
|
||||
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
|
||||
req := new(struct {
|
||||
oidc.DeviceAccessTokenRequest
|
||||
})
|
||||
err := exchanger.Decoder().Decode(req, r.PostForm)
|
||||
if err != nil {
|
||||
req := new(oidc.DeviceAccessTokenRequest)
|
||||
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &req.DeviceAccessTokenRequest, err
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
||||
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
||||
storage, err := assertDeviceStorage(exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state, err := storage.GetDeviceAuthorizatonState(ctx, req.ClientID, req.DeviceCode)
|
||||
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, oidc.ErrSlowDown().WithParent(err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue