feat(op): ID token for device authorization grant

closes #426
This commit is contained in:
Tim Möhlmann 2023-12-13 18:05:37 +02:00
parent 9d12d1d900
commit 3a16504990
4 changed files with 116 additions and 5 deletions

View file

@ -39,7 +39,7 @@ type AuthRequest struct {
CodeChallenge *OIDCCodeChallenge CodeChallenge *OIDCCodeChallenge
done bool done bool
authTime time.Time AuthTime time.Time
} }
// LogValue allows you to define which fields will be logged. // LogValue allows you to define which fields will be logged.
@ -76,7 +76,7 @@ func (a *AuthRequest) GetAudience() []string {
} }
func (a *AuthRequest) GetAuthTime() time.Time { func (a *AuthRequest) GetAuthTime() time.Time {
return a.authTime return a.AuthTime
} }
func (a *AuthRequest) GetClientID() string { func (a *AuthRequest) GetClientID() string {

View file

@ -771,7 +771,7 @@ func (s *Storage) getTokenExchangeClaims(ctx context.Context, request op.TokenEx
func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Time, amr []string) { func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Time, amr []string) {
authReq, ok := req.(*AuthRequest) // Code Flow (with scope offline_access) authReq, ok := req.(*AuthRequest) // Code Flow (with scope offline_access)
if ok { if ok {
return authReq.ApplicationID, authReq.authTime, authReq.GetAMR() return authReq.ApplicationID, authReq.AuthTime, authReq.GetAMR()
} }
refreshReq, ok := req.(*RefreshTokenRequest) // Refresh Token Request refreshReq, ok := req.(*RefreshTokenRequest) // Refresh Token Request
if ok { if ok {

View file

@ -14,6 +14,7 @@ import (
httphelper "github.com/zitadel/oidc/v3/pkg/http" httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/exp/slices"
) )
type DeviceAuthorizationConfig struct { type DeviceAuthorizationConfig struct {
@ -291,15 +292,27 @@ func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode str
} }
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) { func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) {
ctx, span := tracer.Start(ctx, "CreateDeviceTokenResponse")
defer span.End()
accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "") accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &oidc.AccessTokenResponse{ response := &oidc.AccessTokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
TokenType: oidc.BearerToken, TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()), ExpiresIn: uint64(validity.Seconds()),
}, nil }
if idTokenRequest, ok := tokenRequest.(IDTokenRequest); ok && slices.Contains(tokenRequest.GetScopes(), oidc.ScopeOpenID) {
response.IDToken, err = CreateIDToken(ctx, IssuerFromContext(ctx), idTokenRequest, client.IDTokenLifetime(), accessToken, "", creator.Storage(), client)
if err != nil {
return nil, err
}
}
return response, nil
} }

View file

@ -453,3 +453,101 @@ func TestCheckDeviceAuthorizationState(t *testing.T) {
}) })
} }
} }
func TestCreateDeviceTokenResponse(t *testing.T) {
tests := []struct {
name string
tokenRequest op.TokenRequest
wantAccessToken bool
wantRefreshToken bool
wantIDToken bool
wantErr bool
}{
{
name: "access token",
tokenRequest: &storage.AuthRequest{
ID: "auth1",
AuthTime: time.Now(),
ApplicationID: "app1",
ResponseType: oidc.ResponseTypeCode,
UserID: "id1",
},
wantAccessToken: true,
},
{
name: "access and refresh tokens",
tokenRequest: &storage.AuthRequest{
ID: "auth1",
AuthTime: time.Now(),
ApplicationID: "app1",
ResponseType: oidc.ResponseTypeCode,
UserID: "id1",
Scopes: []string{oidc.ScopeOfflineAccess},
},
wantAccessToken: true,
wantRefreshToken: true,
},
{
name: "access and id token",
tokenRequest: &storage.AuthRequest{
ID: "auth1",
AuthTime: time.Now(),
ApplicationID: "app1",
ResponseType: oidc.ResponseTypeCode,
UserID: "id1",
Scopes: []string{oidc.ScopeOpenID},
},
wantAccessToken: true,
wantIDToken: true,
},
{
name: "access, refresh and id token",
tokenRequest: &storage.AuthRequest{
ID: "auth1",
AuthTime: time.Now(),
ApplicationID: "app1",
ResponseType: oidc.ResponseTypeCode,
UserID: "id1",
Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID},
},
wantAccessToken: true,
wantRefreshToken: true,
wantIDToken: true,
},
{
name: "id token creation error",
tokenRequest: &storage.AuthRequest{
ID: "auth1",
AuthTime: time.Now(),
ApplicationID: "app1",
ResponseType: oidc.ResponseTypeCode,
UserID: "foobar",
Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, err := testProvider.Storage().GetClientByClientID(context.Background(), "native")
require.NoError(t, err)
got, err := op.CreateDeviceTokenResponse(context.Background(), tt.tokenRequest, testProvider, client)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.InDelta(t, 300, got.ExpiresIn, 2)
if tt.wantAccessToken {
assert.NotEmpty(t, got.AccessToken, "access token")
}
if tt.wantRefreshToken {
assert.NotEmpty(t, got.RefreshToken, "refresh token")
}
if tt.wantIDToken {
assert.NotEmpty(t, got.IDToken, "id token")
}
})
}
}