zitadel-oidc/pkg/op/device.go
Tim Möhlmann 65cd4528e4 don't obtain a Client from storage on each poll
First verify if the client is authenticated.
Then the state of the device authorization.
If all is good, we take the Client from Storage.
2023-02-26 18:23:43 +01:00

334 lines
8.5 KiB
Go

package op
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"time"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
// DeviceAuthorizationConfig holds the settings read by the device
// authorization endpoint.
type DeviceAuthorizationConfig struct {
// Lifetime is the number of seconds after which a newly issued
// device authorization expires.
Lifetime int
// PollInterval is not read by the code in this file.
// NOTE(review): presumably the polling interval in seconds advertised
// to devices; confirm against the code that consumes it.
PollInterval int
// UserFormURL is returned to the device as the verification URI.
UserFormURL string
// UserCode describes how user codes are generated.
UserCode UserCodeConfig
}
// UserCodeConfig describes the shape of a generated user code,
// in terms of the parameters accepted by NewUserCode.
type UserCodeConfig struct {
// CharSet is the set of characters a user code is drawn from.
CharSet string
// CharAmount is the number of characters in a user code,
// not counting dashes.
CharAmount int
// DashInterval is the number of characters between two dashes,
// in the sense of NewUserCode's dashInterval parameter.
// A value of 0 disables dashes.
DashInterval int
}
// Predefined character sets for user codes.
const (
// CharSetBase20 consists of 20 upper-case consonants and contains no vowels.
CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
// CharSetDigits consists of the 10 decimal digits.
CharSetDigits = "0123456789"
)
// Predefined user code configurations.
var (
// UserCodeBase20 configures 8 characters from CharSetBase20
// with a dash interval of 4.
UserCodeBase20 = UserCodeConfig{
CharSet: CharSetBase20,
CharAmount: 8,
DashInterval: 4,
}
// UserCodeDigits configures 9 characters from CharSetDigits
// with a dash interval of 3.
UserCodeDigits = UserCodeConfig{
CharSet: CharSetDigits,
CharAmount: 9,
DashInterval: 3,
}
)
// deviceAuthorizationHandler returns an HTTP handler that runs
// DeviceAuthorization against o and passes any returned error
// to RequestError.
func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		err := DeviceAuthorization(w, r, o)
		if err == nil {
			return
		}
		RequestError(w, r, err)
	}
}
// DeviceAuthorization handles a device authorization request.
//
// It parses the request, generates a device code and a user code,
// stores them through the device storage of o and writes the
// device authorization response as JSON to w.
//
// An error is returned when the storage does not support device
// authorization, when the request cannot be parsed, when no entropy
// can be obtained for the codes, or when storing fails. Nothing is
// written to w in those cases.
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
	storage, err := assertDeviceStorage(o.Storage())
	if err != nil {
		return err
	}

	req, err := ParseDeviceCodeRequest(r, o)
	if err != nil {
		return err
	}

	config := o.DeviceAuthorization()

	deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
	if err != nil {
		return err
	}
	// The third argument is the dash interval, not the character amount:
	// passing CharAmount here would silently ignore the configured
	// DashInterval and never insert a dash.
	userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
	if err != nil {
		return err
	}

	// Lifetime is configured in seconds.
	expires := time.Now().Add(time.Duration(config.Lifetime) * time.Second)
	err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
	if err != nil {
		return err
	}

	response := &oidc.DeviceAuthorizationResponse{
		DeviceCode:      deviceCode,
		UserCode:        userCode,
		VerificationURI: config.UserFormURL,
	}

	// The complete verification URI points at the user code form endpoint
	// of the current issuer, with the user code prefilled as a query parameter.
	endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context()))
	response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode)

	httphelper.MarshalJSON(w, response)
	return nil
}
// ParseDeviceCodeRequest resolves the client ID of the request and decodes
// the request form into a device authorization request.
//
// The client ID obtained from ClientIDFromRequest overrides whatever the
// form decoding placed in the ClientID field. A decoding failure is
// reported as an invalid request error with the decoder error as parent.
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
	clientID, _, err := ClientIDFromRequest(r, o)
	if err != nil {
		return nil, err
	}

	// NOTE(review): r.Form is read without an explicit ParseForm call here;
	// presumably ClientIDFromRequest has parsed the form already. Confirm.
	var parsed oidc.DeviceAuthorizationRequest
	if decodeErr := o.Decoder().Decode(&parsed, r.Form); decodeErr != nil {
		return nil, oidc.ErrInvalidRequest().
			WithDescription("cannot parse device authentication request").
			WithParent(decodeErr)
	}

	parsed.ClientID = clientID
	return &parsed, nil
}
// RecommendedDeviceCodeBytes is the number of random bytes
// DeviceAuthorization passes to NewDeviceCode.
// 16 bytes give 128 bits of entropy and
// result in a 22 character base64 encoded string.
const RecommendedDeviceCodeBytes = 16
// NewDeviceCode reads nBytes from the cryptographic random source and
// returns them encoded as unpadded, URL-safe base64.
// An error is returned when the random source cannot be read.
func NewDeviceCode(nBytes int) (string, error) {
	entropy := make([]byte, nBytes)

	_, err := rand.Read(entropy)
	if err != nil {
		return "", fmt.Errorf("%w getting entropy for device code", err)
	}

	encoded := base64.RawURLEncoding.EncodeToString(entropy)
	return encoded, nil
}
// NewUserCode returns a random user code of charAmount characters, each
// drawn from charSet using the cryptographic random source.
//
// When dashInterval is not 0, a dash is written before every character
// whose index is a non-zero multiple of dashInterval, so the code never
// starts or ends with a dash. A dashInterval of 0 disables dashes.
//
// A charAmount of 0 or less yields an empty code. An error is returned
// when charSet is empty or when the random source cannot be read.
func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
	if charAmount <= 0 {
		// Nothing to generate; also keeps Grow from receiving a negative size.
		return "", nil
	}
	if len(charSet) == 0 {
		// rand.Int panics on a non-positive upper bound.
		return "", errors.New("empty character set for user code")
	}

	// One slot per character plus one per dash. Dashes only appear
	// between characters, hence (charAmount-1)/dashInterval of them.
	size := charAmount
	if dashInterval > 0 {
		size += (charAmount - 1) / dashInterval
	}
	var buf strings.Builder
	buf.Grow(size)

	limit := big.NewInt(int64(len(charSet)))
	for i := 0; i < charAmount; i++ {
		if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
			buf.WriteByte('-')
		}

		bi, err := rand.Int(rand.Reader, limit)
		if err != nil {
			return "", fmt.Errorf("%w getting entropy for user code", err)
		}
		buf.WriteRune(charSet[int(bi.Int64())])
	}

	return buf.String(), nil
}
// deviceAccessTokenRequest carries the subject, audience and scopes
// for which deviceAccessToken requests a token.
type deviceAccessTokenRequest struct {
	subject          string
	audience, scopes []string
}

// GetSubject returns the subject the token is issued for.
func (d *deviceAccessTokenRequest) GetSubject() string { return d.subject }

// GetAudience returns the audience of the token.
func (d *deviceAccessTokenRequest) GetAudience() []string { return d.audience }

// GetScopes returns the scopes of the token.
func (d *deviceAccessTokenRequest) GetScopes() []string { return d.scopes }
// DeviceAccessToken handles a device access token request and passes
// any resulting error to RequestError.
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
	err := deviceAccessToken(w, r, exchanger)
	if err == nil {
		return
	}
	RequestError(w, r, err)
}
// deviceAccessToken processes a single poll of the device access token
// endpoint.
//
// The steps run in this order: resolve the client ID, parse the request,
// check the state of the device authorization, and only then load the
// client from storage. An unauthenticated request is rejected when the
// client's auth method is anything other than oidc.AuthMethodNone.
// On success the token response is written as JSON to w.
func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
	// Keep the timeout below the default poll interval of 5 seconds.
	ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
	defer cancel()
	r = r.WithContext(ctx)

	clientID, authenticated, err := ClientIDFromRequest(r, exchanger)
	if err != nil {
		return err
	}

	tokenReq, err := ParseDeviceAccessTokenRequest(r, exchanger)
	if err != nil {
		return err
	}

	// Any state other than completed comes back as an error here,
	// so the client lookup below is skipped for those polls.
	state, err := CheckDeviceAuthorizationState(ctx, clientID, tokenReq.DeviceCode, exchanger)
	if err != nil {
		return err
	}

	client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
	if err != nil {
		return err
	}

	if !authenticated {
		// NOTE(review): AuthMethodNone presumably denotes a public client — confirm.
		method := client.AuthMethod()
		if method != oidc.AuthMethodNone {
			description := fmt.Sprintf("required client auth method: %s", method)
			return oidc.ErrInvalidClient().
				WithParent(ErrNoClientCredentials).
				WithDescription(description)
		}
	}

	// r carries ctx, so ctx is the same context r.Context() would return.
	resp, err := CreateDeviceTokenResponse(ctx, &deviceAccessTokenRequest{
		subject:  state.Subject,
		audience: []string{clientID},
		scopes:   state.Scopes,
	}, exchanger, client)
	if err != nil {
		return err
	}

	httphelper.MarshalJSON(w, resp)
	return nil
}
// ParseDeviceAccessTokenRequest decodes the POST form of r into a device
// access token request using the decoder of exchanger.
// The decoder error is returned unchanged.
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
	var parsed oidc.DeviceAccessTokenRequest

	err := exchanger.Decoder().Decode(&parsed, r.PostForm)
	if err != nil {
		return nil, err
	}
	return &parsed, nil
}
// CheckDeviceAuthorizationState loads the state of the device authorization
// identified by clientID and deviceCode and translates it into a result.
//
// A nil error is returned only when the authorization is completed.
// Otherwise the error is one of:
//   - a slow down error, when the storage lookup exceeded the context deadline;
//   - the storage error, for any other lookup failure;
//   - access denied, when the authorization was denied;
//   - expired device code, when the authorization expired without completion;
//   - authorization pending, in all remaining cases.
//
// Denial is checked before completion, and completion before expiry.
// The state is returned alongside the error except for lookup failures.
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
	storage, err := assertDeviceStorage(exchanger.Storage())
	if err != nil {
		return nil, err
	}

	state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return nil, oidc.ErrSlowDown().WithParent(err)
	case err != nil:
		return nil, err
	}

	switch {
	case state.Denied:
		return state, oidc.ErrAccessDenied()
	case state.Completed:
		return state, nil
	case time.Now().After(state.Expires):
		return state, oidc.ErrExpiredDeviceCode()
	default:
		return state, oidc.ErrAuthorizationPending()
	}
}
// CreateDeviceTokenResponse creates a bearer access token for tokenRequest
// and wraps it, together with the refresh token and the validity in
// seconds, in an access token response.
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) {
	// NOTE(review): the original author was not sure bearer is the correct token type.
	accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, client, "")
	if err != nil {
		return nil, err
	}

	resp := &oidc.AccessTokenResponse{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		TokenType:    oidc.BearerToken,
		ExpiresIn:    uint64(validity.Seconds()),
	}
	return resp, nil
}
// userCodeFormHandler returns an HTTP handler that serves UserCodeForm
// for the provider o.
func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc {
	serve := func(w http.ResponseWriter, r *http.Request) {
		UserCodeForm(w, r, o)
	}
	return http.HandlerFunc(serve)
}
// UserCodeFormData is the form submitted to the user code form endpoint.
type UserCodeFormData struct {
// AccesssToken is decoded from the "access_token" form field.
// The spelling with three s is part of the exported name.
AccesssToken string `schema:"access_token"`
// UserCode is decoded from the "user_code" form field.
UserCode string `schema:"user_code"`
// RedirectURL is decoded from the optional "redirect_url" form field.
RedirectURL string `schema:"redirect_url"`
}
// UserCodeForm handles the submission of the user code form.
//
// The access token from the form is verified. When verification fails,
// the device authorization belonging to the user code is denied and the
// error is reported; an error from the denial itself takes precedence
// over the verification error. When verification succeeds, the device
// authorization is completed for the subject of the token.
//
// After completion the user is redirected to the submitted redirect URL
// if there is one; otherwise a plain success message is written.
// All errors are reported through RequestError.
func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
	data, err := ParseUserCodeFormData(r, o.Decoder())
	if err != nil {
		RequestError(w, r, err)
		return
	}

	storage, err := assertDeviceStorage(o.Storage())
	if err != nil {
		RequestError(w, r, err)
		return
	}

	ctx := r.Context()
	token, err := VerifyAccessToken(ctx, data.AccesssToken, o.AccessTokenVerifier(ctx))
	if err != nil {
		if se := storage.DenyDeviceAuthorization(ctx, data.UserCode); se != nil {
			err = se
		}
		RequestError(w, r, err)
		return
	}

	if err := storage.CompleteDeviceAuthorization(ctx, data.UserCode, token.GetSubject()); err != nil {
		RequestError(w, r, err)
		return
	}

	if data.RedirectURL != "" {
		// NOTE(review): RedirectURL comes straight from the submitted form and
		// is not validated here — confirm this cannot be abused as an open redirect.
		http.Redirect(w, r, data.RedirectURL, http.StatusSeeOther)
		// The redirect is the complete response; do not append the
		// success message to it.
		return
	}

	fmt.Fprintln(w, "Authorization successfull, please return to your device")
}
// ParseUserCodeFormData parses the form of r and decodes it into
// UserCodeFormData using decoder.
//
// An invalid request error is returned when the form cannot be parsed or
// decoded, or when the access token or the user code is missing.
// The access token is checked before the user code.
func ParseUserCodeFormData(r *http.Request, decoder httphelper.Decoder) (*UserCodeFormData, error) {
	if parseErr := r.ParseForm(); parseErr != nil {
		return nil, oidc.ErrInvalidRequest().
			WithDescription("cannot parse form").
			WithParent(parseErr)
	}

	form := new(UserCodeFormData)
	if decodeErr := decoder.Decode(form, r.Form); decodeErr != nil {
		return nil, oidc.ErrInvalidRequest().
			WithDescription("cannot parse user code form").
			WithParent(decodeErr)
	}

	switch {
	case form.AccesssToken == "":
		return nil, oidc.ErrInvalidRequest().WithDescription("access_token missing in form")
	case form.UserCode == "":
		return nil, oidc.ErrInvalidRequest().WithDescription("user_code missing in form")
	}
	return form, nil
}