feat: get issuer from context for device auth (#363)

* feat: get issuer from context for device auth

* use distinct UserFormURL and UserFormPath

- Properly deprecate UserFormURL and default to old behaviour,
to prevent breaking change.

- Refactor unit tests to test both cases.

* update example
This commit is contained in:
Tim Möhlmann 2023-04-11 21:29:17 +03:00 committed by GitHub
parent 97bc09583d
commit 44f8403574
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 32 deletions

View file

@ -107,7 +107,7 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider,
DeviceAuthorization: op.DeviceAuthorizationConfig{ DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute, Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second, PollInterval: 5 * time.Second,
UserFormURL: issuer + "device", UserFormPath: "/device",
UserCode: op.UserCodeBase20, UserCode: op.UserCodeBase20,
}, },
} }

View file

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@ -18,7 +19,14 @@ import (
type DeviceAuthorizationConfig struct { type DeviceAuthorizationConfig struct {
Lifetime time.Duration Lifetime time.Duration
PollInterval time.Duration PollInterval time.Duration
UserFormURL string // the URL where the user must go to authorize the device
// UserFormURL is the complete URL where the user must go to authorize the device.
// Deprecated: use UserFormPath instead.
UserFormURL string
// UserFormPath is the path where the user must go to authorize the device.
// The hostname for the URL is taken from the request by IssuerFromContext.
UserFormPath string
UserCode UserCodeConfig UserCode UserCodeConfig
} }
@ -82,15 +90,28 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
return err return err
} }
var verification *url.URL
if config.UserFormURL != "" {
if verification, err = url.Parse(config.UserFormURL); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
}
} else {
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
}
verification.Path = config.UserFormPath
}
response := &oidc.DeviceAuthorizationResponse{ response := &oidc.DeviceAuthorizationResponse{
DeviceCode: deviceCode, DeviceCode: deviceCode,
UserCode: userCode, UserCode: userCode,
VerificationURI: config.UserFormURL, VerificationURI: verification.String(),
ExpiresIn: int(config.Lifetime / time.Second), ExpiresIn: int(config.Lifetime / time.Second),
Interval: int(config.PollInterval / time.Second), Interval: int(config.PollInterval / time.Second),
} }
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode) verification.RawQuery = "user_code=" + userCode
response.VerificationURIComplete = verification.String()
httphelper.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
return nil return nil

View file

@ -13,6 +13,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/oidc"
@ -20,29 +21,60 @@ import (
) )
func Test_deviceAuthorizationHandler(t *testing.T) { func Test_deviceAuthorizationHandler(t *testing.T) {
req := &oidc.DeviceAuthorizationRequest{ type conf struct {
Scopes: []string{"foo", "bar"}, UserFormURL string
ClientID: "web", UserFormPath string
} }
values := make(url.Values) tests := []struct {
testProvider.Encoder().Encode(req, values) name string
body := strings.NewReader(values.Encode()) conf conf
}{
{
name: "UserFormURL",
conf: conf{
UserFormURL: "https://localhost:9998/device",
},
},
{
name: "UserFormPath",
conf: conf{
UserFormPath: "/device",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := gu.PtrCopy(testConfig)
conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL
conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath
provider := newTestProvider(conf)
r := httptest.NewRequest(http.MethodPost, "/", body) req := &oidc.DeviceAuthorizationRequest{
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") Scopes: []string{"foo", "bar"},
ClientID: "web",
}
values := make(url.Values)
testProvider.Encoder().Encode(req, values)
body := strings.NewReader(values.Encode())
w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer))
runWithRandReader(mr.New(mr.NewSource(1)), func() { w := httptest.NewRecorder()
op.DeviceAuthorizationHandler(testProvider)(w, r)
})
result := w.Result() runWithRandReader(mr.New(mr.NewSource(1)), func() {
op.DeviceAuthorizationHandler(provider)(w, r)
})
assert.Less(t, result.StatusCode, 300) result := w.Result()
got, _ := io.ReadAll(result.Body) assert.Less(t, result.StatusCode, 300)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
got, _ := io.ReadAll(result.Body)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
})
}
} }
func TestParseDeviceCodeRequest(t *testing.T) { func TestParseDeviceCodeRequest(t *testing.T) {

View file

@ -20,15 +20,9 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
) )
var testProvider op.OpenIDProvider var (
testProvider op.OpenIDProvider
const ( testConfig = &op.Config{
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
config := &op.Config{
CryptoKey: sha256.Sum256([]byte("test")), CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut, DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true, CodeMethodS256: true,
@ -40,24 +34,35 @@ func init() {
DeviceAuthorization: op.DeviceAuthorizationConfig{ DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute, Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second, PollInterval: 5 * time.Second,
UserFormURL: testIssuer + "device", UserFormPath: "/device",
UserCode: op.UserCodeBase20, UserCode: op.UserCodeBase20,
}, },
} }
)
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
storage.RegisterClients( storage.RegisterClients(
storage.NativeClient("native"), storage.NativeClient("native"),
storage.WebClient("web", "secret", "https://example.com"), storage.WebClient("web", "secret", "https://example.com"),
storage.WebClient("api", "secret"), storage.WebClient("api", "secret"),
) )
var err error testProvider = newTestProvider(testConfig)
testProvider, err = op.NewOpenIDProvider(testIssuer, config, }
func newTestProvider(config *op.Config) op.OpenIDProvider {
provider, err := op.NewOpenIDProvider(testIssuer, config,
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(), storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
) )
if err != nil { if err != nil {
panic(err) panic(err)
} }
return provider
} }
type routesTestStorage interface { type routesTestStorage interface {