implement RFC 8628: Device authorization grant

WIP

Related #264
This commit is contained in:
Tim Möhlmann 2023-02-22 20:11:42 +01:00
parent 8e298791d7
commit 671b13b9c6
15 changed files with 693 additions and 16 deletions

View file

@ -27,6 +27,8 @@ type Configuration interface {
RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint
DeviceAuthorizationEndpoint() Endpoint
UserCodeFormEndpoint() Endpoint
AuthMethodPostSupported() bool
CodeMethodS256Supported() bool
@ -36,6 +38,7 @@ type Configuration interface {
GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool
GrantTypeDeviceCodeSupported() bool
IntrospectionAuthMethodPrivateKeyJWTSupported() bool
IntrospectionEndpointSigningAlgorithmsSupported() []string
RevocationAuthMethodPrivateKeyJWTSupported() bool
@ -44,6 +47,7 @@ type Configuration interface {
RequestObjectSigningAlgorithmsSupported() []string
SupportedUILocales() []language.Tag
DeviceAuthorization() DeviceAuthorizationConfig
}
type IssuerFromRequest func(r *http.Request) string

232
pkg/op/device.go Normal file
View file

@ -0,0 +1,232 @@
package op
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"math/big"
"net/http"
"net/url"
"strings"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type DeviceAuthorizationConfig struct {
Lifetime int
PollInterval int
UserCode UserCodeConfig
}
type UserCodeConfig struct {
CharSet string
CharAmount int
DashInterval int
QueryKey string
FormHTML []byte
}
const (
CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
CharSetDigits = "0123456789"
)
var (
UserCodeBase20 = UserCodeConfig{
CharSet: CharSetBase20,
CharAmount: 8,
DashInterval: 4,
QueryKey: "user_code",
}
UserCodeDigits = UserCodeConfig{
CharSet: CharSetDigits,
CharAmount: 9,
DashInterval: 3,
QueryKey: "user_code",
}
)
func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
DeviceAuthorization(w, r, o)
}
}
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
storage, ok := o.Storage().(DeviceCodeStorage)
if !ok {
// unimplemented error?
}
req, err := ParseDeviceCodeRequest(r, o.Decoder())
if err != nil {
RequestError(w, r, err)
return
}
config := o.DeviceAuthorization()
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
if err != nil {
RequestError(w, r, err)
return
}
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount)
if err != nil {
RequestError(w, r, err)
return
}
err = storage.StoreDeviceAuthorizationRequest(r.Context(), req, deviceCode, userCode)
if err != nil {
RequestError(w, r, err)
return
}
endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context()))
response := &oidc.DeviceAuthorizationResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: endpoint,
}
if key := config.UserCode.QueryKey; key != "" {
vals := make(url.Values, 1)
vals.Set(key, userCode)
response.VerificationURIComplete = strings.Join([]string{endpoint, vals.Encode()}, "?")
}
httphelper.MarshalJSON(w, response)
}
func ParseDeviceCodeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.DeviceAuthorizationRequest, error) {
if err := r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
}
devReq := new(oidc.DeviceAuthorizationRequest)
if err := decoder.Decode(devReq, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse dev auth request").WithParent(err)
}
return devReq, nil
}
// 16 bytes gives 128 bit of entropy.
// results in a 22 character base64 encoded string.
const RecommendedDeviceCodeBytes = 16
func NewDeviceCode(nBytes int) (string, error) {
bytes := make([]byte, nBytes)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("%w getting entropy for device code", err)
}
return base64.RawURLEncoding.EncodeToString(bytes), nil
}
func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
var buf strings.Builder
if dashInterval > 0 {
buf.Grow(charAmount + charAmount/dashInterval - 1)
} else {
buf.Grow(charAmount)
}
max := big.NewInt(int64(len(charSet)))
for i := 0; i < charAmount; i++ {
if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
buf.WriteByte('-')
}
bi, err := rand.Int(rand.Reader, max)
if err != nil {
return "", fmt.Errorf("%w getting entropy for user code", err)
}
buf.WriteRune(charSet[int(bi.Int64())])
}
return buf.String(), nil
}
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
req := new(oidc.DeviceAccessTokenRequest)
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
RequestError(w, r, err)
}
storage, ok := exchanger.Storage().(DeviceCodeStorage)
if !ok {
// unimplemented error?
}
client, err := storage.DeviceAccessPoll(r.Context(), req.DeviceCode)
if err != nil {
RequestError(w, r, err)
}
resp, err := CreateDeviceTokenResponse(r.Context(), req, exchanger, client)
if err != nil {
RequestError(w, r, err)
return
}
httphelper.MarshalJSON(w, resp)
}
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) {
tokenType := AccessTokenTypeBearer // not sure if this is the correct type?
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
if err != nil {
return nil, err
}
return &oidc.AccessTokenResponse{
AccessToken: accessToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
}, nil
}
func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
UserCodeForm(w, r, o)
}
}
func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
// check cookie, or what??
config := o.DeviceAuthorization().UserCode
userCode, err := UserCodeFromRequest(r, config.QueryKey)
if err != nil {
RequestError(w, r, err)
return
}
if userCode == "" {
w.Write(config.FormHTML)
return
}
storage, ok := o.Storage().(DeviceCodeStorage)
if !ok {
// unimplemented error?
}
if err := storage.ReleaseDeviceAccessToken(r.Context(), userCode); err != nil {
RequestError(w, r, err)
return
}
fmt.Fprintln(w, "Authorization successfull, please return to your device")
}
func UserCodeFromRequest(r *http.Request, key string) (string, error) {
if err := r.ParseForm(); err != nil {
return "", oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
}
return r.Form.Get(key), nil
}

204
pkg/op/device_test.go Normal file
View file

@ -0,0 +1,204 @@
package op
import (
"crypto/rand"
"encoding/base64"
"io"
mr "math/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type errReader struct {
}
func (errReader) Read([]byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}
func runWithRandReader(r io.Reader, f func()) {
originalReader := rand.Reader
rand.Reader = r
defer func() {
rand.Reader = originalReader
}()
f()
}
func TestNewDeviceCode(t *testing.T) {
t.Run("reader error", func(t *testing.T) {
runWithRandReader(errReader{}, func() {
_, err := NewDeviceCode(16)
require.Error(t, err)
})
})
t.Run("dirrent lengths, rand reader", func(t *testing.T) {
for i := 1; i <= 32; i++ {
got, err := NewDeviceCode(i)
require.NoError(t, err)
assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
}
})
}
func TestNewUserCode(t *testing.T) {
type args struct {
charset []rune
charAmount int
dashInterval int
}
tests := []struct {
name string
args args
reader io.Reader
want string
wantErr bool
}{
{
name: "reader error",
args: args{
charset: []rune(CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: errReader{},
wantErr: true,
},
{
name: "base20",
args: args{
charset: []rune(CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: mr.New(mr.NewSource(1)),
want: "XKCD-HTTD",
},
{
name: "digits",
args: args{
charset: []rune(CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: mr.New(mr.NewSource(1)),
want: "271-256-225",
},
{
name: "no dashes",
args: args{
charset: []rune(CharSetDigits),
charAmount: 9,
},
reader: mr.New(mr.NewSource(1)),
want: "271256225",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
runWithRandReader(tt.reader, func() {
got, err := NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
if tt.wantErr {
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
})
}
t.Run("crypto/rand", func(t *testing.T) {
const testN = 100000
for _, c := range []UserCodeConfig{UserCodeBase20, UserCodeDigits} {
t.Run(c.CharSet, func(t *testing.T) {
results := make(map[string]int)
for i := 0; i < testN; i++ {
code, err := NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval)
require.NoError(t, err)
results[code]++
}
t.Log(results)
var duplicates int
for code, count := range results {
assert.Less(t, count, 3, code)
if count == 2 {
duplicates++
}
}
})
}
})
}
func BenchmarkNewUserCode(b *testing.B) {
type args struct {
charset []rune
charAmount int
dashInterval int
}
tests := []struct {
name string
args args
reader io.Reader
}{
{
name: "math rand, base20",
args: args{
charset: []rune(CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: mr.New(mr.NewSource(1)),
},
{
name: "math rand, digits",
args: args{
charset: []rune(CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: mr.New(mr.NewSource(1)),
},
{
name: "crypto rand, base20",
args: args{
charset: []rune(CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: rand.Reader,
},
{
name: "crypto rand, digits",
args: args{
charset: []rune(CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: rand.Reader,
},
}
for _, tt := range tests {
runWithRandReader(tt.reader, func() {
b.Run(tt.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
require.NoError(b, err)
}
})
})
}
}

View file

@ -44,6 +44,7 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
JwksURI: config.KeysEndpoint().Absolute(issuer),
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
@ -92,6 +93,9 @@ func GrantTypes(c Configuration) []oidc.GrantType {
if c.GrantTypeJWTAuthorizationSupported() {
grantTypes = append(grantTypes, oidc.GrantTypeBearer)
}
if c.GrantTypeDeviceCodeSupported() {
grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode)
}
return grantTypes
}

View file

@ -92,6 +92,34 @@ func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported))
}
// DeviceAuthorization mocks base method.
func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorization")
ret0, _ := ret[0].(op.DeviceAuthorizationConfig)
return ret0
}
// DeviceAuthorization indicates an expected call of DeviceAuthorization.
func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization))
}
// DeviceAuthorizationEndpoint mocks base method.
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint.
func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint))
}
// EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
m.ctrl.T.Helper()
@ -120,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported))
}
// GrantTypeDeviceCodeSupported mocks base method.
func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported.
func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported))
}
// GrantTypeJWTAuthorizationSupported mocks base method.
func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool {
m.ctrl.T.Helper()
@ -358,6 +400,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
}
// UserCodeFormEndpoint mocks base method.
func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserCodeFormEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint.
func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint))
}
// UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
m.ctrl.T.Helper()

View file

@ -27,17 +27,21 @@ const (
defaultRevocationEndpoint = "revoke"
defaultEndSessionEndpoint = "end_session"
defaultKeysEndpoint = "keys"
defaultDeviceAuthzEndpoint = "/device_authorization"
defaultUserCodeFormEndpoint = "/device"
)
var (
DefaultEndpoints = &endpoints{
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint),
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
Revocation: NewEndpoint(defaultRevocationEndpoint),
EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint),
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint),
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
Revocation: NewEndpoint(defaultRevocationEndpoint),
EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint),
DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
UserCodeForm: NewEndpoint(defaultUserCodeFormEndpoint),
}
defaultCORSOptions = cors.Options{
@ -95,6 +99,8 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o))
router.HandleFunc(o.UserCodeFormEndpoint().Relative(), userCodeFormHandler(o))
return router
}
@ -121,14 +127,16 @@ type Config struct {
}
type endpoints struct {
Authorization Endpoint
Token Endpoint
Introspection Endpoint
Userinfo Endpoint
Revocation Endpoint
EndSession Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
Authorization Endpoint
Token Endpoint
Introspection Endpoint
Userinfo Endpoint
Revocation Endpoint
EndSession Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
DeviceAuthorization Endpoint
UserCodeForm Endpoint
}
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -242,6 +250,14 @@ func (o *Provider) EndSessionEndpoint() Endpoint {
return o.endpoints.EndSession
}
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
return o.endpoints.DeviceAuthorization
}
func (o *Provider) UserCodeFormEndpoint() Endpoint {
return o.endpoints.UserCodeForm
}
func (o *Provider) KeysEndpoint() Endpoint {
return o.endpoints.JwksURI
}
@ -275,6 +291,10 @@ func (o *Provider) GrantTypeJWTAuthorizationSupported() bool {
return true
}
func (o *Provider) GrantTypeDeviceCodeSupported() bool {
return true
}
func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
return true
}
@ -308,6 +328,10 @@ func (o *Provider) SupportedUILocales() []language.Tag {
return o.config.SupportedUILocales
}
func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig {
return DeviceAuthorizationConfig{}
}
func (o *Provider) Storage() Storage {
return o.storage
}

View file

@ -151,3 +151,35 @@ type EndSessionRequest struct {
ClientID string
RedirectURI string
}
var ErrDuplicateUserCode = errors.New("user code already exists")
type DeviceCodeStorage interface {
// StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
// User code will be used by the user to complete the login flow and must be unique.
// ErrDuplicateUserCode signals the caller should try again with a new code.
//
// Note that user codes are low entropy keys and when many exist in the
// database, the change for collisions increases. Therefore implementers
// of this interface must make sure that user codes of completed or expired
// authentication flows are deleted.
StoreDeviceAuthorizationRequest(ctx context.Context, req *oidc.DeviceAuthorizationRequest, deviceCode, userCode string) error
// DeviceAccessPoll is called by the device untill the authorization flow is
// completed or expired.
//
// The following errors are defined for the Device Authorization workflow,
// that can be returned by this method:
// - oidc.ErrAuthorizationPending should be returned on each poll, while the flow is not completed by the user.
// - oidc.ErrSlowDown signals to the device that the polling interval is to be increased by 5 seconds.
// - oidc.ErrAccessDenied when the authorization request is denied.
// - oidc.ErrExpiredToken when the device code has expired.
//
// A token should be returned once the authorization flow is completed
// by the user.
DeviceAccessPoll(ctx context.Context, deviceCode string) (Client, error)
// ReleaseDeviceAccessToken releases DeviceAccessPoll to return the Access Token,
// destined for a user code.
ReleaseDeviceAccessToken(ctx context.Context, userCode string) error
}

View file

@ -19,6 +19,7 @@ type Exchanger interface {
GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool
GrantTypeDeviceCodeSupported() bool
AccessTokenVerifier(context.Context) AccessTokenVerifier
IDTokenHintVerifier(context.Context) IDTokenHintVerifier
}
@ -56,6 +57,11 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ClientCredentialsExchange(w, r, exchanger)
return
}
case string(oidc.GrantTypeDeviceCode):
if exchanger.GrantTypeDeviceCodeSupported() {
DeviceAccessToken(w, r, exchanger)
return
}
case "":
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
return