diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 5604483..1dc8bd1 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -107,7 +107,7 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, DeviceAuthorization: op.DeviceAuthorizationConfig{ Lifetime: 5 * time.Minute, PollInterval: 5 * time.Second, - UserFormURL: issuer + "device", + UserFormPath: "/device", UserCode: op.UserCodeBase20, }, } diff --git a/pkg/op/device.go b/pkg/op/device.go index 04c06f2..f7691ca 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -8,6 +8,7 @@ import ( "fmt" "math/big" "net/http" + "net/url" "strings" "time" @@ -18,7 +19,14 @@ import ( type DeviceAuthorizationConfig struct { Lifetime time.Duration PollInterval time.Duration - UserFormURL string // the URL where the user must go to authorize the device + + // UserFormURL is the complete URL where the user must go to authorize the device. + // Deprecated: use UserFormPath instead. + UserFormURL string + + // UserFormPath is the path where the user must go to authorize the device. + // The hostname for the URL is taken from the request by IssuerFromContext. + UserFormPath string UserCode UserCodeConfig } @@ -82,15 +90,28 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide return err } + var verification *url.URL + if config.UserFormURL != "" { + if verification, err = url.Parse(config.UserFormURL); err != nil { + return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + } + } else { + if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil { + return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + } + verification.Path = config.UserFormPath + } + response := &oidc.DeviceAuthorizationResponse{ DeviceCode: deviceCode, UserCode: userCode, - VerificationURI: config.UserFormURL, + VerificationURI: verification.String(), ExpiresIn: int(config.Lifetime / time.Second), Interval: int(config.PollInterval / time.Second), } - response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode) + verification.RawQuery = "user_code=" + userCode + response.VerificationURIComplete = verification.String() httphelper.MarshalJSON(w, response) return nil diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index 69ba102..cf94c3f 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zitadel/oidc/v2/pkg/oidc" @@ -20,29 +21,60 @@ import ( ) func Test_deviceAuthorizationHandler(t *testing.T) { - req := &oidc.DeviceAuthorizationRequest{ - Scopes: []string{"foo", "bar"}, - ClientID: "web", + type conf struct { + UserFormURL string + UserFormPath string } - values := make(url.Values) - testProvider.Encoder().Encode(req, values) - body := strings.NewReader(values.Encode()) + tests := []struct { + name string + conf conf + }{ + { + name: "UserFormURL", + conf: conf{ + UserFormURL: "https://localhost:9998/device", + }, + }, + { + name: "UserFormPath", + conf: conf{ + UserFormPath: "/device", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf := gu.PtrCopy(testConfig) + conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL + conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath + provider := newTestProvider(conf) - r := httptest.NewRequest(http.MethodPost, "/", body) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req := &oidc.DeviceAuthorizationRequest{ + Scopes: []string{"foo", "bar"}, + ClientID: "web", + } + values := make(url.Values) + testProvider.Encoder().Encode(req, values) + body := strings.NewReader(values.Encode()) - w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/", body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer)) - runWithRandReader(mr.New(mr.NewSource(1)), func() { - op.DeviceAuthorizationHandler(testProvider)(w, r) - }) + w := httptest.NewRecorder() - result := w.Result() + runWithRandReader(mr.New(mr.NewSource(1)), func() { + op.DeviceAuthorizationHandler(provider)(w, r) + }) - assert.Less(t, result.StatusCode, 300) + result := w.Result() - got, _ := io.ReadAll(result.Body) - assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got)) + assert.Less(t, result.StatusCode, 300) + + got, _ := io.ReadAll(result.Body) + assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got)) + }) + } } func TestParseDeviceCodeRequest(t *testing.T) { diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index ba3570b..3e6377f 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -20,15 +20,9 @@ import ( "golang.org/x/text/language" ) -var testProvider op.OpenIDProvider - -const ( - testIssuer = "https://localhost:9998/" - pathLoggedOut = "/logged-out" -) - -func init() { - config := &op.Config{ +var ( + testProvider op.OpenIDProvider + testConfig = &op.Config{ CryptoKey: sha256.Sum256([]byte("test")), DefaultLogoutRedirectURI: pathLoggedOut, CodeMethodS256: true, @@ -40,24 +34,35 @@ func init() { DeviceAuthorization: op.DeviceAuthorizationConfig{ Lifetime: 5 * time.Minute, PollInterval: 5 * time.Second, - UserFormURL: testIssuer + "device", + UserFormPath: "/device", UserCode: op.UserCodeBase20, }, } +) +const ( + testIssuer = "https://localhost:9998/" + pathLoggedOut = "/logged-out" +) + +func init() { storage.RegisterClients( storage.NativeClient("native"), storage.WebClient("web", "secret", "https://example.com"), storage.WebClient("api", "secret"), ) - var err error - testProvider, err = op.NewOpenIDProvider(testIssuer, config, + testProvider = newTestProvider(testConfig) +} + +func newTestProvider(config *op.Config) op.OpenIDProvider { + provider, err := op.NewOpenIDProvider(testIssuer, config, storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(), ) if err != nil { panic(err) } + return provider } type routesTestStorage interface {