From 3a165049908c75ad46bebbc8e61cd4dc87c8e763 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Wed, 13 Dec 2023 18:05:37 +0200 Subject: [PATCH] feat(op): ID token for device authorization grant closes #426 --- example/server/storage/oidc.go | 4 +- example/server/storage/storage.go | 2 +- pkg/op/device.go | 17 +++++- pkg/op/device_test.go | 98 +++++++++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 5 deletions(-) diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go index 63afcf9..0ebaca6 100644 --- a/example/server/storage/oidc.go +++ b/example/server/storage/oidc.go @@ -39,7 +39,7 @@ type AuthRequest struct { CodeChallenge *OIDCCodeChallenge done bool - authTime time.Time + AuthTime time.Time } // LogValue allows you to define which fields will be logged. @@ -76,7 +76,7 @@ func (a *AuthRequest) GetAudience() []string { } func (a *AuthRequest) GetAuthTime() time.Time { - return a.authTime + return a.AuthTime } func (a *AuthRequest) GetClientID() string { diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go index b556828..742b7a1 100644 --- a/example/server/storage/storage.go +++ b/example/server/storage/storage.go @@ -771,7 +771,7 @@ func (s *Storage) getTokenExchangeClaims(ctx context.Context, request op.TokenEx func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Time, amr []string) { authReq, ok := req.(*AuthRequest) // Code Flow (with scope offline_access) if ok { - return authReq.ApplicationID, authReq.authTime, authReq.GetAMR() + return authReq.ApplicationID, authReq.AuthTime, authReq.GetAMR() } refreshReq, ok := req.(*RefreshTokenRequest) // Refresh Token Request if ok { diff --git a/pkg/op/device.go b/pkg/op/device.go index 813c3f5..5226aef 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -14,6 +14,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slices" ) type DeviceAuthorizationConfig struct { @@ -291,15 +292,27 @@ func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode str } func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) { + ctx, span := tracer.Start(ctx, "CreateDeviceTokenResponse") + defer span.End() + accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "") if err != nil { return nil, err } - return &oidc.AccessTokenResponse{ + response := &oidc.AccessTokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: oidc.BearerToken, ExpiresIn: uint64(validity.Seconds()), - }, nil + } + + if idTokenRequest, ok := tokenRequest.(IDTokenRequest); ok && slices.Contains(tokenRequest.GetScopes(), oidc.ScopeOpenID) { + response.IDToken, err = CreateIDToken(ctx, IssuerFromContext(ctx), idTokenRequest, client.IDTokenLifetime(), accessToken, "", creator.Storage(), client) + if err != nil { + return nil, err + } + } + + return response, nil } diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index f5452f9..2400598 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -453,3 +453,101 @@ func TestCheckDeviceAuthorizationState(t *testing.T) { }) } } + +func TestCreateDeviceTokenResponse(t *testing.T) { + tests := []struct { + name string + tokenRequest op.TokenRequest + wantAccessToken bool + wantRefreshToken bool + wantIDToken bool + wantErr bool + }{ + { + name: "access token", + tokenRequest: &storage.AuthRequest{ + ID: "auth1", + AuthTime: time.Now(), + ApplicationID: "app1", + ResponseType: oidc.ResponseTypeCode, + UserID: "id1", + }, + wantAccessToken: true, + }, + { + name: "access and refresh tokens", + tokenRequest: &storage.AuthRequest{ + ID: "auth1", + AuthTime: time.Now(), + ApplicationID: "app1", + ResponseType: oidc.ResponseTypeCode, + UserID: "id1", + Scopes: []string{oidc.ScopeOfflineAccess}, + }, + wantAccessToken: true, + wantRefreshToken: true, + }, + { + name: "access and id token", + tokenRequest: &storage.AuthRequest{ + ID: "auth1", + AuthTime: time.Now(), + ApplicationID: "app1", + ResponseType: oidc.ResponseTypeCode, + UserID: "id1", + Scopes: []string{oidc.ScopeOpenID}, + }, + wantAccessToken: true, + wantIDToken: true, + }, + { + name: "access, refresh and id token", + tokenRequest: &storage.AuthRequest{ + ID: "auth1", + AuthTime: time.Now(), + ApplicationID: "app1", + ResponseType: oidc.ResponseTypeCode, + UserID: "id1", + Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID}, + }, + wantAccessToken: true, + wantRefreshToken: true, + wantIDToken: true, + }, + { + name: "id token creation error", + tokenRequest: &storage.AuthRequest{ + ID: "auth1", + AuthTime: time.Now(), + ApplicationID: "app1", + ResponseType: oidc.ResponseTypeCode, + UserID: "foobar", + Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := testProvider.Storage().GetClientByClientID(context.Background(), "native") + require.NoError(t, err) + + got, err := op.CreateDeviceTokenResponse(context.Background(), tt.tokenRequest, testProvider, client) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.InDelta(t, 300, got.ExpiresIn, 2) + if tt.wantAccessToken { + assert.NotEmpty(t, got.AccessToken, "access token") + } + if tt.wantRefreshToken { + assert.NotEmpty(t, got.RefreshToken, "refresh token") + } + if tt.wantIDToken { + assert.NotEmpty(t, got.IDToken, "id token") + } + }) + } +}