First verify if the client is authenticated. Then the state of the device authorization. If all is good, we take the Client from Storage.
334 lines
8.5 KiB
Go
334 lines
8.5 KiB
Go
package op
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
|
)
|
|
|
|
type DeviceAuthorizationConfig struct {
|
|
Lifetime int
|
|
PollInterval int
|
|
UserFormURL string
|
|
UserCode UserCodeConfig
|
|
}
|
|
|
|
type UserCodeConfig struct {
|
|
CharSet string
|
|
CharAmount int
|
|
DashInterval int
|
|
}
|
|
|
|
const (
|
|
CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
|
|
CharSetDigits = "0123456789"
|
|
)
|
|
|
|
var (
|
|
UserCodeBase20 = UserCodeConfig{
|
|
CharSet: CharSetBase20,
|
|
CharAmount: 8,
|
|
DashInterval: 4,
|
|
}
|
|
UserCodeDigits = UserCodeConfig{
|
|
CharSet: CharSetDigits,
|
|
CharAmount: 9,
|
|
DashInterval: 3,
|
|
}
|
|
)
|
|
|
|
func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if err := DeviceAuthorization(w, r, o); err != nil {
|
|
RequestError(w, r, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
|
|
storage, err := assertDeviceStorage(o.Storage())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req, err := ParseDeviceCodeRequest(r, o)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
config := o.DeviceAuthorization()
|
|
|
|
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
expires := time.Now().Add(time.Duration(config.Lifetime) * time.Second)
|
|
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
response := &oidc.DeviceAuthorizationResponse{
|
|
DeviceCode: deviceCode,
|
|
UserCode: userCode,
|
|
VerificationURI: config.UserFormURL,
|
|
}
|
|
|
|
endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context()))
|
|
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode)
|
|
|
|
httphelper.MarshalJSON(w, response)
|
|
return nil
|
|
}
|
|
|
|
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
|
|
clientID, _, err := ClientIDFromRequest(r, o)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req := new(oidc.DeviceAuthorizationRequest)
|
|
if err := o.Decoder().Decode(req, r.Form); err != nil {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
|
|
}
|
|
req.ClientID = clientID
|
|
|
|
return req, nil
|
|
}
|
|
|
|
// 16 bytes gives 128 bit of entropy.
|
|
// results in a 22 character base64 encoded string.
|
|
const RecommendedDeviceCodeBytes = 16
|
|
|
|
func NewDeviceCode(nBytes int) (string, error) {
|
|
bytes := make([]byte, nBytes)
|
|
if _, err := rand.Read(bytes); err != nil {
|
|
return "", fmt.Errorf("%w getting entropy for device code", err)
|
|
}
|
|
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
|
}
|
|
|
|
func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
|
|
var buf strings.Builder
|
|
if dashInterval > 0 {
|
|
buf.Grow(charAmount + charAmount/dashInterval - 1)
|
|
} else {
|
|
buf.Grow(charAmount)
|
|
}
|
|
|
|
max := big.NewInt(int64(len(charSet)))
|
|
|
|
for i := 0; i < charAmount; i++ {
|
|
if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
|
|
buf.WriteByte('-')
|
|
}
|
|
|
|
bi, err := rand.Int(rand.Reader, max)
|
|
if err != nil {
|
|
return "", fmt.Errorf("%w getting entropy for user code", err)
|
|
}
|
|
|
|
buf.WriteRune(charSet[int(bi.Int64())])
|
|
}
|
|
|
|
return buf.String(), nil
|
|
}
|
|
|
|
type deviceAccessTokenRequest struct {
|
|
subject string
|
|
audience []string
|
|
scopes []string
|
|
}
|
|
|
|
func (r *deviceAccessTokenRequest) GetSubject() string {
|
|
return r.subject
|
|
}
|
|
|
|
func (r *deviceAccessTokenRequest) GetAudience() []string {
|
|
return r.audience
|
|
}
|
|
|
|
func (r *deviceAccessTokenRequest) GetScopes() []string {
|
|
return r.scopes
|
|
}
|
|
|
|
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|
if err := deviceAccessToken(w, r, exchanger); err != nil {
|
|
RequestError(w, r, err)
|
|
}
|
|
}
|
|
|
|
func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
|
|
// use a limited context timeout shorter as the default
|
|
// poll interval of 5 seconds.
|
|
ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
|
|
defer cancel()
|
|
r = r.WithContext(ctx)
|
|
|
|
clientID, authenticated, err := ClientIDFromRequest(r, exchanger)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !authenticated {
|
|
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
|
|
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
|
WithDescription(fmt.Sprintf("required client auth method: %s", m))
|
|
}
|
|
}
|
|
|
|
tokenRequest := &deviceAccessTokenRequest{
|
|
subject: state.Subject,
|
|
audience: []string{clientID},
|
|
scopes: state.Scopes,
|
|
}
|
|
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
httphelper.MarshalJSON(w, resp)
|
|
return nil
|
|
}
|
|
|
|
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
|
|
req := new(oidc.DeviceAccessTokenRequest)
|
|
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
|
|
return nil, err
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
|
storage, err := assertDeviceStorage(exchanger.Storage())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, oidc.ErrSlowDown().WithParent(err)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if state.Denied {
|
|
return state, oidc.ErrAccessDenied()
|
|
}
|
|
if state.Completed {
|
|
return state, nil
|
|
}
|
|
if time.Now().After(state.Expires) {
|
|
return state, oidc.ErrExpiredDeviceCode()
|
|
}
|
|
return state, oidc.ErrAuthorizationPending()
|
|
}
|
|
|
|
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) {
|
|
tokenType := AccessTokenTypeBearer // not sure if this is the correct type?
|
|
|
|
accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &oidc.AccessTokenResponse{
|
|
AccessToken: accessToken,
|
|
RefreshToken: refreshToken,
|
|
TokenType: oidc.BearerToken,
|
|
ExpiresIn: uint64(validity.Seconds()),
|
|
}, nil
|
|
}
|
|
|
|
func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
UserCodeForm(w, r, o)
|
|
}
|
|
}
|
|
|
|
type UserCodeFormData struct {
|
|
AccesssToken string `schema:"access_token"`
|
|
UserCode string `schema:"user_code"`
|
|
RedirectURL string `schema:"redirect_url"`
|
|
}
|
|
|
|
func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
|
|
data, err := ParseUserCodeFormData(r, o.Decoder())
|
|
if err != nil {
|
|
RequestError(w, r, err)
|
|
return
|
|
}
|
|
|
|
storage, err := assertDeviceStorage(o.Storage())
|
|
if err != nil {
|
|
RequestError(w, r, err)
|
|
return
|
|
}
|
|
|
|
ctx := r.Context()
|
|
token, err := VerifyAccessToken(ctx, data.AccesssToken, o.AccessTokenVerifier(ctx))
|
|
if err != nil {
|
|
if se := storage.DenyDeviceAuthorization(ctx, data.UserCode); se != nil {
|
|
err = se
|
|
}
|
|
RequestError(w, r, err)
|
|
return
|
|
}
|
|
|
|
if err := storage.CompleteDeviceAuthorization(ctx, data.UserCode, token.GetSubject()); err != nil {
|
|
RequestError(w, r, err)
|
|
return
|
|
}
|
|
|
|
if data.RedirectURL != "" {
|
|
http.Redirect(w, r, data.RedirectURL, http.StatusSeeOther)
|
|
}
|
|
|
|
fmt.Fprintln(w, "Authorization successfull, please return to your device")
|
|
}
|
|
|
|
func ParseUserCodeFormData(r *http.Request, decoder httphelper.Decoder) (*UserCodeFormData, error) {
|
|
if err := r.ParseForm(); err != nil {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
|
|
}
|
|
|
|
req := new(UserCodeFormData)
|
|
if err := decoder.Decode(req, r.Form); err != nil {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse user code form").WithParent(err)
|
|
}
|
|
if req.AccesssToken == "" {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("access_token missing in form")
|
|
}
|
|
if req.UserCode == "" {
|
|
return nil, oidc.ErrInvalidRequest().WithDescription("user_code missing in form")
|
|
}
|
|
|
|
return req, nil
|
|
}
|