don't obtain a Client from storage on each poll
First verify if the client is authenticated. Then the state of the device authorization. If all is good, we take the Client from Storage.
This commit is contained in:
parent
f26e155208
commit
65cd4528e4
4 changed files with 25 additions and 14 deletions
|
@ -26,5 +26,4 @@ type DeviceAuthorizationResponse struct {
|
||||||
type DeviceAccessTokenRequest struct {
|
type DeviceAccessTokenRequest struct {
|
||||||
GrantType string `json:"grant_type"`
|
GrantType string `json:"grant_type"`
|
||||||
DeviceCode string `json:"device_code"`
|
DeviceCode string `json:"device_code"`
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package op
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
@ -156,6 +155,7 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
|
||||||
return data.ClientID, false, nil
|
return data.ClientID, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
// ClientFromRequest wraps ClientIDFromRequest and obtains the Client from storage.
|
// ClientFromRequest wraps ClientIDFromRequest and obtains the Client from storage.
|
||||||
// If the client id was not authenticated, the client from storage does not have
|
// If the client id was not authenticated, the client from storage does not have
|
||||||
// oidc.AuthMethodNone set, an error is returned.
|
// oidc.AuthMethodNone set, an error is returned.
|
||||||
|
@ -179,3 +179,4 @@ func ClientFromRequest(r *http.Request, p ClientProvider) (Client, error) {
|
||||||
|
|
||||||
return client, err
|
return client, err
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -252,6 +252,7 @@ func TestClientIDFromRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
func TestClientFromRequest(t *testing.T) {
|
func TestClientFromRequest(t *testing.T) {
|
||||||
publicClient := func() op.Client {
|
publicClient := func() op.Client {
|
||||||
c := mock.NewMockClient(gomock.NewController(t))
|
c := mock.NewMockClient(gomock.NewController(t))
|
||||||
|
@ -390,3 +391,4 @@ func TestClientFromRequest(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -179,21 +179,34 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
||||||
defer cancel()
|
defer cancel()
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
client, err := ClientFromRequest(r, exchanger)
|
clientID, authenticated, err := ClientIDFromRequest(r, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
|
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
state, err := CheckDeviceAuthorizationState(ctx, req, exchanger)
|
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !authenticated {
|
||||||
|
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
|
||||||
|
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
||||||
|
WithDescription(fmt.Sprintf("required client auth method: %s", m))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tokenRequest := &deviceAccessTokenRequest{
|
tokenRequest := &deviceAccessTokenRequest{
|
||||||
subject: state.Subject,
|
subject: state.Subject,
|
||||||
audience: []string{req.ClientID},
|
audience: []string{clientID},
|
||||||
scopes: state.Scopes,
|
scopes: state.Scopes,
|
||||||
}
|
}
|
||||||
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
|
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
|
||||||
|
@ -206,24 +219,20 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
|
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
|
||||||
req := new(struct {
|
req := new(oidc.DeviceAccessTokenRequest)
|
||||||
oidc.DeviceAccessTokenRequest
|
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
|
||||||
})
|
|
||||||
err := exchanger.Decoder().Decode(req, r.PostForm)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return req, nil
|
||||||
return &req.DeviceAccessTokenRequest, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
||||||
storage, err := assertDeviceStorage(exchanger.Storage())
|
storage, err := assertDeviceStorage(exchanger.Storage())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
state, err := storage.GetDeviceAuthorizatonState(ctx, req.ClientID, req.DeviceCode)
|
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
|
||||||
if errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
return nil, oidc.ErrSlowDown().WithParent(err)
|
return nil, oidc.ErrSlowDown().WithParent(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue