implement RFC 8628: Device authorization grant

WIP

Related #264
This commit is contained in:
Tim Möhlmann 2023-02-22 20:11:42 +01:00
parent 8e298791d7
commit 671b13b9c6
15 changed files with 693 additions and 16 deletions

View file

@ -186,3 +186,20 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti
IssuedAt: oidc.Time(iat), IssuedAt: oidc.Time(iat),
}, signer) }, signer)
} }
type DeviceAuthorizationCaller interface {
GetDeviceCodeEndpoint() string
HttpClient() *http.Client
}
func CallDeviceAuthorizationEndpoint(request interface{}, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) {
req, err := httphelper.FormRequest(caller.GetDeviceCodeEndpoint(), request, Encoder, nil)
if err != nil {
return nil, err
}
resp := new(oidc.DeviceAuthorizationResponse)
if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
return nil, err
}
return resp, nil
}

20
pkg/client/rp/device.go Normal file
View file

@ -0,0 +1,20 @@
package rp
import (
"github.com/zitadel/oidc/v2/pkg/client"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
func DeviceAuthorization(clientID string, scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) {
req := &oidc.DeviceAuthorizationRequest{
Scopes: scopes,
ClientID: clientID,
}
return client.CallDeviceAuthorizationEndpoint(req, rp)
}
/*
func DeviceAccessToken() (*oauth2.Token, error) {
req := &oidc.DeviceAccessTokenRequest{}
}
*/

View file

@ -59,6 +59,8 @@ type RelyingParty interface {
// UserinfoEndpoint returns the userinfo // UserinfoEndpoint returns the userinfo
UserinfoEndpoint() string UserinfoEndpoint() string
GetDeviceCodeEndpoint() string
// IDTokenVerifier returns the verifier interface used for oidc id_token verification // IDTokenVerifier returns the verifier interface used for oidc id_token verification
IDTokenVerifier() IDTokenVerifier IDTokenVerifier() IDTokenVerifier
// ErrorHandler returns the handler used for callback errors // ErrorHandler returns the handler used for callback errors
@ -121,6 +123,10 @@ func (rp *relyingParty) UserinfoEndpoint() string {
return rp.endpoints.UserinfoURL return rp.endpoints.UserinfoURL
} }
func (rp *relyingParty) GetDeviceCodeEndpoint() string {
return rp.endpoints.DeviceCodeURL
}
func (rp *relyingParty) GetEndSessionEndpoint() string { func (rp *relyingParty) GetEndSessionEndpoint() string {
return rp.endpoints.EndSessionURL return rp.endpoints.EndSessionURL
} }
@ -500,6 +506,7 @@ type Endpoints struct {
JKWsURL string JKWsURL string
EndSessionURL string EndSessionURL string
RevokeURL string RevokeURL string
DeviceCodeURL string
} }
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
@ -514,6 +521,7 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
JKWsURL: discoveryConfig.JwksURI, JKWsURL: discoveryConfig.JwksURI,
EndSessionURL: discoveryConfig.EndSessionEndpoint, EndSessionURL: discoveryConfig.EndSessionEndpoint,
RevokeURL: discoveryConfig.RevocationEndpoint, RevokeURL: discoveryConfig.RevocationEndpoint,
DeviceCodeURL: discoveryConfig.DeviceAuthorizationEndpoint,
} }
} }

View file

@ -0,0 +1,31 @@
package oidc
// DeviceAuthorizationRequest implements
// https://www.rfc-editor.org/rfc/rfc8628#section-3.1,
// 3.1 Device Authorization Request.
type DeviceAuthorizationRequest struct {
Scopes SpaceDelimitedArray `schema:"scope"`
ClientID string `schema:"client_id"`
}
// DeviceAuthorizationResponse implements
// https://www.rfc-editor.org/rfc/rfc8628#section-3.2
// 3.2. Device Authorization Response.
type DeviceAuthorizationResponse struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval,omitempty"`
}
// DeviceAccessTokenRequest implements
// https://www.rfc-editor.org/rfc/rfc8628#section-3.4,
// Device Access Token Request.
type DeviceAccessTokenRequest struct {
JWTTokenRequest
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"` // required, how??
}

View file

@ -30,6 +30,8 @@ type DiscoveryConfiguration struct {
// EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP. // EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP.
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"`
// CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client. // CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client.
CheckSessionIframe string `json:"check_session_iframe,omitempty"` CheckSessionIframe string `json:"check_session_iframe,omitempty"`

View file

@ -18,6 +18,14 @@ const (
InteractionRequired errorType = "interaction_required" InteractionRequired errorType = "interaction_required"
LoginRequired errorType = "login_required" LoginRequired errorType = "login_required"
RequestNotSupported errorType = "request_not_supported" RequestNotSupported errorType = "request_not_supported"
// Additional error codes as defined in
// https://www.rfc-editor.org/rfc/rfc8628#section-3.5
// Device Access Token Response
AuthorizationPending errorType = "authorization_pending"
SlowDown errorType = "slow_down"
AccessDenied errorType = "access_denied"
ExpiredToken errorType = "expired_token"
) )
var ( var (
@ -77,6 +85,32 @@ var (
ErrorType: RequestNotSupported, ErrorType: RequestNotSupported,
} }
} }
// Device Access Token errors:
ErrAuthorizationPending = func() *Error {
return &Error{
ErrorType: AuthorizationPending,
Description: "The client SHOULD repeat the access token request to the token endpoint, after interval from device authorization response.",
}
}
ErrSlowDown = func() *Error {
return &Error{
ErrorType: SlowDown,
Description: "Polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests.",
}
}
ErrAccessDenied = func() *Error {
return &Error{
ErrorType: AccessDenied,
Description: "The authorization request was denied.",
}
}
ErrExpiredToken = func() *Error {
return &Error{
ErrorType: ExpiredToken,
Description: "The \"device_code\" has expired.",
}
}
) )
type Error struct { type Error struct {

View file

@ -27,6 +27,9 @@ const (
// GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code // GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code
GrantTypeImplicit GrantType = "implicit" GrantTypeImplicit GrantType = "implicit"
// GrantTypeDeviceCode
GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code"
// ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer` // ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`
// used for the OAuth JWT Profile Client Authentication // used for the OAuth JWT Profile Client Authentication
ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
@ -35,7 +38,7 @@ const (
var AllGrantTypes = []GrantType{ var AllGrantTypes = []GrantType{
GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials, GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials,
GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit, GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit,
ClientAssertionTypeJWTAssertion, GrantTypeDeviceCode, ClientAssertionTypeJWTAssertion,
} }
type GrantType string type GrantType string

View file

@ -27,6 +27,8 @@ type Configuration interface {
RevocationEndpoint() Endpoint RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint KeysEndpoint() Endpoint
DeviceAuthorizationEndpoint() Endpoint
UserCodeFormEndpoint() Endpoint
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
CodeMethodS256Supported() bool CodeMethodS256Supported() bool
@ -36,6 +38,7 @@ type Configuration interface {
GrantTypeTokenExchangeSupported() bool GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool GrantTypeClientCredentialsSupported() bool
GrantTypeDeviceCodeSupported() bool
IntrospectionAuthMethodPrivateKeyJWTSupported() bool IntrospectionAuthMethodPrivateKeyJWTSupported() bool
IntrospectionEndpointSigningAlgorithmsSupported() []string IntrospectionEndpointSigningAlgorithmsSupported() []string
RevocationAuthMethodPrivateKeyJWTSupported() bool RevocationAuthMethodPrivateKeyJWTSupported() bool
@ -44,6 +47,7 @@ type Configuration interface {
RequestObjectSigningAlgorithmsSupported() []string RequestObjectSigningAlgorithmsSupported() []string
SupportedUILocales() []language.Tag SupportedUILocales() []language.Tag
DeviceAuthorization() DeviceAuthorizationConfig
} }
type IssuerFromRequest func(r *http.Request) string type IssuerFromRequest func(r *http.Request) string

232
pkg/op/device.go Normal file
View file

@ -0,0 +1,232 @@
package op
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"math/big"
"net/http"
"net/url"
"strings"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
)
type DeviceAuthorizationConfig struct {
Lifetime int
PollInterval int
UserCode UserCodeConfig
}
type UserCodeConfig struct {
CharSet string
CharAmount int
DashInterval int
QueryKey string
FormHTML []byte
}
const (
CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
CharSetDigits = "0123456789"
)
var (
UserCodeBase20 = UserCodeConfig{
CharSet: CharSetBase20,
CharAmount: 8,
DashInterval: 4,
QueryKey: "user_code",
}
UserCodeDigits = UserCodeConfig{
CharSet: CharSetDigits,
CharAmount: 9,
DashInterval: 3,
QueryKey: "user_code",
}
)
func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
DeviceAuthorization(w, r, o)
}
}
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
storage, ok := o.Storage().(DeviceCodeStorage)
if !ok {
// unimplemented error?
}
req, err := ParseDeviceCodeRequest(r, o.Decoder())
if err != nil {
RequestError(w, r, err)
return
}
config := o.DeviceAuthorization()
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
if err != nil {
RequestError(w, r, err)
return
}
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount)
if err != nil {
RequestError(w, r, err)
return
}
err = storage.StoreDeviceAuthorizationRequest(r.Context(), req, deviceCode, userCode)
if err != nil {
RequestError(w, r, err)
return
}
endpoint := o.UserCodeFormEndpoint().Absolute(IssuerFromContext(r.Context()))
response := &oidc.DeviceAuthorizationResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: endpoint,
}
if key := config.UserCode.QueryKey; key != "" {
vals := make(url.Values, 1)
vals.Set(key, userCode)
response.VerificationURIComplete = strings.Join([]string{endpoint, vals.Encode()}, "?")
}
httphelper.MarshalJSON(w, response)
}
func ParseDeviceCodeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.DeviceAuthorizationRequest, error) {
if err := r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
}
devReq := new(oidc.DeviceAuthorizationRequest)
if err := decoder.Decode(devReq, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse dev auth request").WithParent(err)
}
return devReq, nil
}
// 16 bytes gives 128 bit of entropy.
// results in a 22 character base64 encoded string.
const RecommendedDeviceCodeBytes = 16
func NewDeviceCode(nBytes int) (string, error) {
bytes := make([]byte, nBytes)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("%w getting entropy for device code", err)
}
return base64.RawURLEncoding.EncodeToString(bytes), nil
}
func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
var buf strings.Builder
if dashInterval > 0 {
buf.Grow(charAmount + charAmount/dashInterval - 1)
} else {
buf.Grow(charAmount)
}
max := big.NewInt(int64(len(charSet)))
for i := 0; i < charAmount; i++ {
if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
buf.WriteByte('-')
}
bi, err := rand.Int(rand.Reader, max)
if err != nil {
return "", fmt.Errorf("%w getting entropy for user code", err)
}
buf.WriteRune(charSet[int(bi.Int64())])
}
return buf.String(), nil
}
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
req := new(oidc.DeviceAccessTokenRequest)
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
RequestError(w, r, err)
}
storage, ok := exchanger.Storage().(DeviceCodeStorage)
if !ok {
// unimplemented error?
}
client, err := storage.DeviceAccessPoll(r.Context(), req.DeviceCode)
if err != nil {
RequestError(w, r, err)
}
resp, err := CreateDeviceTokenResponse(r.Context(), req, exchanger, client)
if err != nil {
RequestError(w, r, err)
return
}
httphelper.MarshalJSON(w, resp)
}
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) {
tokenType := AccessTokenTypeBearer // not sure if this is the correct type?
accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
if err != nil {
return nil, err
}
return &oidc.AccessTokenResponse{
AccessToken: accessToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
}, nil
}
func userCodeFormHandler(o OpenIDProvider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
UserCodeForm(w, r, o)
}
}
func UserCodeForm(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
// check cookie, or what??
config := o.DeviceAuthorization().UserCode
userCode, err := UserCodeFromRequest(r, config.QueryKey)
if err != nil {
RequestError(w, r, err)
return
}
if userCode == "" {
w.Write(config.FormHTML)
return
}
storage, ok := o.Storage().(DeviceCodeStorage)
if !ok {
// unimplemented error?
}
if err := storage.ReleaseDeviceAccessToken(r.Context(), userCode); err != nil {
RequestError(w, r, err)
return
}
fmt.Fprintln(w, "Authorization successfull, please return to your device")
}
func UserCodeFromRequest(r *http.Request, key string) (string, error) {
if err := r.ParseForm(); err != nil {
return "", oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
}
return r.Form.Get(key), nil
}

204
pkg/op/device_test.go Normal file
View file

@ -0,0 +1,204 @@
package op
import (
"crypto/rand"
"encoding/base64"
"io"
mr "math/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type errReader struct {
}
func (errReader) Read([]byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}
func runWithRandReader(r io.Reader, f func()) {
originalReader := rand.Reader
rand.Reader = r
defer func() {
rand.Reader = originalReader
}()
f()
}
func TestNewDeviceCode(t *testing.T) {
t.Run("reader error", func(t *testing.T) {
runWithRandReader(errReader{}, func() {
_, err := NewDeviceCode(16)
require.Error(t, err)
})
})
t.Run("dirrent lengths, rand reader", func(t *testing.T) {
for i := 1; i <= 32; i++ {
got, err := NewDeviceCode(i)
require.NoError(t, err)
assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
}
})
}
func TestNewUserCode(t *testing.T) {
type args struct {
charset []rune
charAmount int
dashInterval int
}
tests := []struct {
name string
args args
reader io.Reader
want string
wantErr bool
}{
{
name: "reader error",
args: args{
charset: []rune(CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: errReader{},
wantErr: true,
},
{
name: "base20",
args: args{
charset: []rune(CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: mr.New(mr.NewSource(1)),
want: "XKCD-HTTD",
},
{
name: "digits",
args: args{
charset: []rune(CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: mr.New(mr.NewSource(1)),
want: "271-256-225",
},
{
name: "no dashes",
args: args{
charset: []rune(CharSetDigits),
charAmount: 9,
},
reader: mr.New(mr.NewSource(1)),
want: "271256225",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
runWithRandReader(tt.reader, func() {
got, err := NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
if tt.wantErr {
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
})
}
t.Run("crypto/rand", func(t *testing.T) {
const testN = 100000
for _, c := range []UserCodeConfig{UserCodeBase20, UserCodeDigits} {
t.Run(c.CharSet, func(t *testing.T) {
results := make(map[string]int)
for i := 0; i < testN; i++ {
code, err := NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval)
require.NoError(t, err)
results[code]++
}
t.Log(results)
var duplicates int
for code, count := range results {
assert.Less(t, count, 3, code)
if count == 2 {
duplicates++
}
}
})
}
})
}
func BenchmarkNewUserCode(b *testing.B) {
type args struct {
charset []rune
charAmount int
dashInterval int
}
tests := []struct {
name string
args args
reader io.Reader
}{
{
name: "math rand, base20",
args: args{
charset: []rune(CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: mr.New(mr.NewSource(1)),
},
{
name: "math rand, digits",
args: args{
charset: []rune(CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: mr.New(mr.NewSource(1)),
},
{
name: "crypto rand, base20",
args: args{
charset: []rune(CharSetBase20),
charAmount: 8,
dashInterval: 4,
},
reader: rand.Reader,
},
{
name: "crypto rand, digits",
args: args{
charset: []rune(CharSetDigits),
charAmount: 9,
dashInterval: 3,
},
reader: rand.Reader,
},
}
for _, tt := range tests {
runWithRandReader(tt.reader, func() {
b.Run(tt.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
require.NoError(b, err)
}
})
})
}
}

View file

@ -44,6 +44,7 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer), RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer), EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
JwksURI: config.KeysEndpoint().Absolute(issuer), JwksURI: config.KeysEndpoint().Absolute(issuer),
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
ScopesSupported: Scopes(config), ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config), ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config), GrantTypesSupported: GrantTypes(config),
@ -92,6 +93,9 @@ func GrantTypes(c Configuration) []oidc.GrantType {
if c.GrantTypeJWTAuthorizationSupported() { if c.GrantTypeJWTAuthorizationSupported() {
grantTypes = append(grantTypes, oidc.GrantTypeBearer) grantTypes = append(grantTypes, oidc.GrantTypeBearer)
} }
if c.GrantTypeDeviceCodeSupported() {
grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode)
}
return grantTypes return grantTypes
} }

View file

@ -92,6 +92,34 @@ func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported))
} }
// DeviceAuthorization mocks base method.
func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorization")
ret0, _ := ret[0].(op.DeviceAuthorizationConfig)
return ret0
}
// DeviceAuthorization indicates an expected call of DeviceAuthorization.
func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization))
}
// DeviceAuthorizationEndpoint mocks base method.
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint.
func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint))
}
// EndSessionEndpoint mocks base method. // EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -120,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported))
} }
// GrantTypeDeviceCodeSupported mocks base method.
func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported.
func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported))
}
// GrantTypeJWTAuthorizationSupported mocks base method. // GrantTypeJWTAuthorizationSupported mocks base method.
func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool { func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -358,6 +400,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
} }
// UserCodeFormEndpoint mocks base method.
func (m *MockConfiguration) UserCodeFormEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserCodeFormEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// UserCodeFormEndpoint indicates an expected call of UserCodeFormEndpoint.
func (mr *MockConfigurationMockRecorder) UserCodeFormEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserCodeFormEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserCodeFormEndpoint))
}
// UserinfoEndpoint mocks base method. // UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -27,17 +27,21 @@ const (
defaultRevocationEndpoint = "revoke" defaultRevocationEndpoint = "revoke"
defaultEndSessionEndpoint = "end_session" defaultEndSessionEndpoint = "end_session"
defaultKeysEndpoint = "keys" defaultKeysEndpoint = "keys"
defaultDeviceAuthzEndpoint = "/device_authorization"
defaultUserCodeFormEndpoint = "/device"
) )
var ( var (
DefaultEndpoints = &endpoints{ DefaultEndpoints = &endpoints{
Authorization: NewEndpoint(defaultAuthorizationEndpoint), Authorization: NewEndpoint(defaultAuthorizationEndpoint),
Token: NewEndpoint(defaultTokenEndpoint), Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint),
Userinfo: NewEndpoint(defaultUserinfoEndpoint), Userinfo: NewEndpoint(defaultUserinfoEndpoint),
Revocation: NewEndpoint(defaultRevocationEndpoint), Revocation: NewEndpoint(defaultRevocationEndpoint),
EndSession: NewEndpoint(defaultEndSessionEndpoint), EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint), JwksURI: NewEndpoint(defaultKeysEndpoint),
DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
UserCodeForm: NewEndpoint(defaultUserCodeFormEndpoint),
} }
defaultCORSOptions = cors.Options{ defaultCORSOptions = cors.Options{
@ -95,6 +99,8 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o)) router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), deviceAuthorizationHandler(o))
router.HandleFunc(o.UserCodeFormEndpoint().Relative(), userCodeFormHandler(o))
return router return router
} }
@ -121,14 +127,16 @@ type Config struct {
} }
type endpoints struct { type endpoints struct {
Authorization Endpoint Authorization Endpoint
Token Endpoint Token Endpoint
Introspection Endpoint Introspection Endpoint
Userinfo Endpoint Userinfo Endpoint
Revocation Endpoint Revocation Endpoint
EndSession Endpoint EndSession Endpoint
CheckSessionIframe Endpoint CheckSessionIframe Endpoint
JwksURI Endpoint JwksURI Endpoint
DeviceAuthorization Endpoint
UserCodeForm Endpoint
} }
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -242,6 +250,14 @@ func (o *Provider) EndSessionEndpoint() Endpoint {
return o.endpoints.EndSession return o.endpoints.EndSession
} }
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
return o.endpoints.DeviceAuthorization
}
func (o *Provider) UserCodeFormEndpoint() Endpoint {
return o.endpoints.UserCodeForm
}
func (o *Provider) KeysEndpoint() Endpoint { func (o *Provider) KeysEndpoint() Endpoint {
return o.endpoints.JwksURI return o.endpoints.JwksURI
} }
@ -275,6 +291,10 @@ func (o *Provider) GrantTypeJWTAuthorizationSupported() bool {
return true return true
} }
func (o *Provider) GrantTypeDeviceCodeSupported() bool {
return true
}
func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool { func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
return true return true
} }
@ -308,6 +328,10 @@ func (o *Provider) SupportedUILocales() []language.Tag {
return o.config.SupportedUILocales return o.config.SupportedUILocales
} }
func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig {
return DeviceAuthorizationConfig{}
}
func (o *Provider) Storage() Storage { func (o *Provider) Storage() Storage {
return o.storage return o.storage
} }

View file

@ -151,3 +151,35 @@ type EndSessionRequest struct {
ClientID string ClientID string
RedirectURI string RedirectURI string
} }
var ErrDuplicateUserCode = errors.New("user code already exists")
type DeviceCodeStorage interface {
// StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
// User code will be used by the user to complete the login flow and must be unique.
// ErrDuplicateUserCode signals the caller should try again with a new code.
//
// Note that user codes are low entropy keys and when many exist in the
// database, the change for collisions increases. Therefore implementers
// of this interface must make sure that user codes of completed or expired
// authentication flows are deleted.
StoreDeviceAuthorizationRequest(ctx context.Context, req *oidc.DeviceAuthorizationRequest, deviceCode, userCode string) error
// DeviceAccessPoll is called by the device untill the authorization flow is
// completed or expired.
//
// The following errors are defined for the Device Authorization workflow,
// that can be returned by this method:
// - oidc.ErrAuthorizationPending should be returned on each poll, while the flow is not completed by the user.
// - oidc.ErrSlowDown signals to the device that the polling interval is to be increased by 5 seconds.
// - oidc.ErrAccessDenied when the authorization request is denied.
// - oidc.ErrExpiredToken when the device code has expired.
//
// A token should be returned once the authorization flow is completed
// by the user.
DeviceAccessPoll(ctx context.Context, deviceCode string) (Client, error)
// ReleaseDeviceAccessToken releases DeviceAccessPoll to return the Access Token,
// destined for a user code.
ReleaseDeviceAccessToken(ctx context.Context, userCode string) error
}

View file

@ -19,6 +19,7 @@ type Exchanger interface {
GrantTypeTokenExchangeSupported() bool GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool GrantTypeJWTAuthorizationSupported() bool
GrantTypeClientCredentialsSupported() bool GrantTypeClientCredentialsSupported() bool
GrantTypeDeviceCodeSupported() bool
AccessTokenVerifier(context.Context) AccessTokenVerifier AccessTokenVerifier(context.Context) AccessTokenVerifier
IDTokenHintVerifier(context.Context) IDTokenHintVerifier IDTokenHintVerifier(context.Context) IDTokenHintVerifier
} }
@ -56,6 +57,11 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ClientCredentialsExchange(w, r, exchanger) ClientCredentialsExchange(w, r, exchanger)
return return
} }
case string(oidc.GrantTypeDeviceCode):
if exchanger.GrantTypeDeviceCodeSupported() {
DeviceAccessToken(w, r, exchanger)
return
}
case "": case "":
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing")) RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
return return