some updates after feedback

This commit is contained in:
Tim Möhlmann 2023-02-23 14:26:55 +01:00
parent 671b13b9c6
commit c12305457b
4 changed files with 150 additions and 69 deletions

View file

@ -24,8 +24,7 @@ type DeviceAuthorizationResponse struct {
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4, // https://www.rfc-editor.org/rfc/rfc8628#section-3.4,
// Device Access Token Request. // Device Access Token Request.
type DeviceAccessTokenRequest struct { type DeviceAccessTokenRequest struct {
JWTTokenRequest
GrantType string `json:"grant_type"` GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code"` DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"` // required, how?? ClientID string `json:"client_id"`
} }

View file

@ -105,7 +105,7 @@ var (
Description: "The authorization request was denied.", Description: "The authorization request was denied.",
} }
} }
ErrExpiredToken = func() *Error { ErrExpiredDeviceCode = func() *Error {
return &Error{ return &Error{
ErrorType: ExpiredToken, ErrorType: ExpiredToken,
Description: "The \"device_code\" has expired.", Description: "The \"device_code\" has expired.",

View file

@ -4,11 +4,12 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/oidc"
@ -17,6 +18,7 @@ import (
type DeviceAuthorizationConfig struct { type DeviceAuthorizationConfig struct {
Lifetime int Lifetime int
PollInterval int PollInterval int
UserFormURL string
UserCode UserCodeConfig UserCode UserCodeConfig
} }
@ -24,8 +26,6 @@ type UserCodeConfig struct {
CharSet string CharSet string
CharAmount int CharAmount int
DashInterval int DashInterval int
QueryKey string
FormHTML []byte
} }
const ( const (
@ -38,13 +38,11 @@ var (
CharSet: CharSetBase20, CharSet: CharSetBase20,
CharAmount: 8, CharAmount: 8,
DashInterval: 4, DashInterval: 4,
QueryKey: "user_code",
} }
UserCodeDigits = UserCodeConfig{ UserCodeDigits = UserCodeConfig{
CharSet: CharSetDigits, CharSet: CharSetDigits,
CharAmount: 9, CharAmount: 9,
DashInterval: 3, DashInterval: 3,
QueryKey: "user_code",
} }
) )
@ -55,10 +53,12 @@ func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *htt
} }
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
storage, ok := o.Storage().(DeviceCodeStorage) storage, err := assertDeviceStorage(o.Storage())
if !ok { if err != nil {
// unimplemented error? RequestError(w, r, err)
return
} }
req, err := ParseDeviceCodeRequest(r, o.Decoder()) req, err := ParseDeviceCodeRequest(r, o.Decoder())
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
@ -77,25 +77,20 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
err = storage.StoreDeviceAuthorizationRequest(r.Context(), req, deviceCode, userCode) err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, req.Scopes)
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context()))
response := &oidc.DeviceAuthorizationResponse{ response := &oidc.DeviceAuthorizationResponse{
DeviceCode: deviceCode, DeviceCode: deviceCode,
UserCode: userCode, UserCode: userCode,
VerificationURI: endpoint, VerificationURI: config.UserFormURL,
} }
if key := config.UserCode.QueryKey; key != "" { endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context()))
vals := make(url.Values, 1) response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode)
vals.Set(key, userCode)
response.VerificationURIComplete = strings.Join([]string{endpoint, vals.Encode()}, "?")
}
httphelper.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
} }
@ -107,7 +102,7 @@ func ParseDeviceCodeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.
devReq := new(oidc.DeviceAuthorizationRequest) devReq := new(oidc.DeviceAuthorizationRequest)
if err := decoder.Decode(devReq, r.Form); err != nil { if err := decoder.Decode(devReq, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse dev auth request").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
} }
return devReq, nil return devReq, nil
@ -151,23 +146,49 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
return buf.String(), nil return buf.String(), nil
} }
type deviceAccessTokenRequest struct {
subject string
audience []string
scopes []string
}
func (r *deviceAccessTokenRequest) GetSubject() string {
return r.subject
}
func (r *deviceAccessTokenRequest) GetAudience() []string {
return r.audience
}
func (r *deviceAccessTokenRequest) GetScopes() []string {
return r.scopes
}
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
req := new(oidc.DeviceAccessTokenRequest) req := new(oidc.DeviceAccessTokenRequest)
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil { if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return
} }
storage, ok := exchanger.Storage().(DeviceCodeStorage) // use a limited context timeout shorter as the default
if !ok { // poll interval of 5 seconds.
// unimplemented error? ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
} defer cancel()
client, err := storage.DeviceAccessPoll(r.Context(), req.DeviceCode) state, err := CheckDeviceAuthorizationState(ctx, req, exchanger)
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return
} }
resp, err := CreateDeviceTokenResponse(r.Context(), req, exchanger, client) tokenRequest := &deviceAccessTokenRequest{
subject: state.Subject,
audience: []string{req.ClientID},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, &jwtProfileClient{id: req.ClientID})
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return return
@ -175,16 +196,42 @@ func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
httphelper.MarshalJSON(w, resp) httphelper.MarshalJSON(w, resp)
} }
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) { func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) {
storage, err := assertDeviceStorage(exchanger.Storage())
if err != nil {
return nil, err
}
state, err := storage.GetDeviceAuthorizatonState(ctx, req.ClientID, req.DeviceCode)
if errors.Is(err, context.DeadlineExceeded) {
return nil, oidc.ErrSlowDown().WithParent(err)
}
if err != nil {
return nil, err
}
if state.Denied {
return state, oidc.ErrAccessDenied()
}
if state.Completed {
return state, nil
}
if time.Now().After(state.Expires) {
return state, oidc.ErrExpiredDeviceCode()
}
return state, oidc.ErrAuthorizationPending()
}
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) {
tokenType := AccessTokenTypeBearer // not sure if this is the correct type? tokenType := AccessTokenTypeBearer // not sure if this is the correct type?
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "") accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &oidc.AccessTokenResponse{ return &oidc.AccessTokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: oidc.BearerToken, TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()), ExpiresIn: uint64(validity.Seconds()),
}, nil }, nil
@ -196,37 +243,62 @@ func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc {
} }
} }
func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { type UserCodeFormData struct {
// check cookie, or what?? AccesssToken string `schema:"access_token"`
UserCode string `schema:"user_code"`
RedirectURL string `schema:"redirect_url"`
}
config := o.DeviceAuthorization().UserCode func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
userCode, err := UserCodeFromRequest(r, config.QueryKey) data, err := ParseUserCodeFormData(r, o.Decoder())
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
if userCode == "" {
w.Write(config.FormHTML)
return
}
storage, ok := o.Storage().(DeviceCodeStorage) storage, err := assertDeviceStorage(o.Storage())
if !ok { if err != nil {
// unimplemented error?
}
if err := storage.ReleaseDeviceAccessToken(r.Context(), userCode); err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
ctx := r.Context()
token, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, data.AccesssToken, o.AccessTokenVerifier(ctx))
if err != nil {
if se := storage.DenyDeviceAuthorization(ctx, data.UserCode); se != nil {
err = se
}
RequestError(w, r, err)
return
}
if err := storage.CompleteDeviceAuthorization(ctx, data.UserCode, token.Subject); err != nil {
RequestError(w, r, err)
return
}
if data.RedirectURL != "" {
http.Redirect(w, r, data.RedirectURL, http.StatusSeeOther)
}
fmt.Fprintln(w, "Authorization successfull, please return to your device") fmt.Fprintln(w, "Authorization successfull, please return to your device")
} }
func UserCodeFromRequest(r *http.Request, key string) (string, error) { func ParseUserCodeFormData(r *http.Request, decoder httphelper.Decoder) (*UserCodeFormData, error) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
return "", oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
} }
return r.Form.Get(key), nil req := new(UserCodeFormData)
if err := decoder.Decode(req, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse user code form").WithParent(err)
}
if req.AccesssToken == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("access_token missing in form")
}
if req.UserCode == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("user_code missing in form")
}
return req, nil
} }

View file

@ -154,32 +154,42 @@ type EndSessionRequest struct {
var ErrDuplicateUserCode = errors.New("user code already exists") var ErrDuplicateUserCode = errors.New("user code already exists")
type DeviceCodeStorage interface { type DeviceAuthorizationState struct {
Scopes []string
Expires time.Time
Completed bool
Subject string
Denied bool
}
type DeviceAuthorizationStorage interface {
// StoreDeviceAuthorizationRequest stores a new device authorization request in the database. // StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
// User code will be used by the user to complete the login flow and must be unique. // User code will be used by the user to complete the login flow and must be unique.
// ErrDuplicateUserCode signals the caller should try again with a new code. // ErrDuplicateUserCode signals the caller should try again with a new code.
// //
// Note that user codes are low entropy keys and when many exist in the // Note that user codes are low entropy keys and when many exist in the
// database, the change for collisions increases. Therefore implementers // database, the change for collisions increases. Therefore implementers
// of this interface must make sure that user codes of completed or expired // of this interface must make sure that user codes of expired authentication flows are purged,
// authentication flows are deleted. // after some time.
StoreDeviceAuthorizationRequest(ctx context.Context, req *oidc.DeviceAuthorizationRequest, deviceCode, userCode string) error StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, scopes []string) error
// DeviceAccessPoll is called by the device untill the authorization flow is // GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database.
// completed or expired. // The method is polled untill the the authorization is eighter Completed, Expired or Denied.
// GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error)
// The following errors are defined for the Device Authorization workflow,
// that can be returned by this method:
// - oidc.ErrAuthorizationPending should be returned on each poll, while the flow is not completed by the user.
// - oidc.ErrSlowDown signals to the device that the polling interval is to be increased by 5 seconds.
// - oidc.ErrAccessDenied when the authorization request is denied.
// - oidc.ErrExpiredToken when the device code has expired.
//
// A token should be returned once the authorization flow is completed
// by the user.
DeviceAccessPoll(ctx context.Context, deviceCode string) (Client, error)
// ReleaseDeviceAccessToken releases DeviceAccessPoll to return the Access Token, // CompleteDeviceAuthorization marks a device authorization entry as Completed,
// destined for a user code. // identified by userCode. The Subject is added to the state, so that
ReleaseDeviceAccessToken(ctx context.Context, userCode string) error // GetDeviceAuthorizatonState can use it to create a new Access Token.
CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
// DenyDeviceAuthorization marks a device authorization entry as Denied.
DenyDeviceAuthorization(ctx context.Context, userCode string) error
}
func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) {
storage, ok := s.(DeviceAuthorizationStorage)
if !ok {
return nil, oidc.ErrUnsupportedGrantType().WithDescription("device_code grant not supported")
}
return storage, nil
} }