feat: get issuer from context for device auth (#363)

* feat: get issuer from context for device auth

* use distinct UserFormURL and UserFormPath

- Properly deprecate UserFormURL and default to old behaviour,
to prevent breaking change.

- Refactor unit tests to test both cases.

* update example
This commit is contained in:
Tim Möhlmann 2023-04-11 21:29:17 +03:00 committed by GitHub
parent 97bc09583d
commit 44f8403574
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 32 deletions

View file

@ -107,7 +107,7 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider,
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormURL: issuer + "device",
UserFormPath: "/device",
UserCode: op.UserCodeBase20,
},
}

View file

@ -8,6 +8,7 @@ import (
"fmt"
"math/big"
"net/http"
"net/url"
"strings"
"time"
@ -18,7 +19,14 @@ import (
type DeviceAuthorizationConfig struct {
Lifetime time.Duration
PollInterval time.Duration
UserFormURL string // the URL where the user must go to authorize the device
// UserFormURL is the complete URL where the user must go to authorize the device.
// Deprecated: use UserFormPath instead.
UserFormURL string
// UserFormPath is the path where the user must go to authorize the device.
// The hostname for the URL is taken from the request by IssuerFromContext.
UserFormPath string
UserCode UserCodeConfig
}
@ -82,15 +90,28 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
return err
}
var verification *url.URL
if config.UserFormURL != "" {
if verification, err = url.Parse(config.UserFormURL); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
}
} else {
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
}
verification.Path = config.UserFormPath
}
response := &oidc.DeviceAuthorizationResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: config.UserFormURL,
VerificationURI: verification.String(),
ExpiresIn: int(config.Lifetime / time.Second),
Interval: int(config.PollInterval / time.Second),
}
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
verification.RawQuery = "user_code=" + userCode
response.VerificationURIComplete = verification.String()
httphelper.MarshalJSON(w, response)
return nil

View file

@ -13,6 +13,7 @@ import (
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc"
@ -20,29 +21,60 @@ import (
)
func Test_deviceAuthorizationHandler(t *testing.T) {
req := &oidc.DeviceAuthorizationRequest{
Scopes: []string{"foo", "bar"},
ClientID: "web",
type conf struct {
UserFormURL string
UserFormPath string
}
values := make(url.Values)
testProvider.Encoder().Encode(req, values)
body := strings.NewReader(values.Encode())
tests := []struct {
name string
conf conf
}{
{
name: "UserFormURL",
conf: conf{
UserFormURL: "https://localhost:9998/device",
},
},
{
name: "UserFormPath",
conf: conf{
UserFormPath: "/device",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := gu.PtrCopy(testConfig)
conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL
conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath
provider := newTestProvider(conf)
r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req := &oidc.DeviceAuthorizationRequest{
Scopes: []string{"foo", "bar"},
ClientID: "web",
}
values := make(url.Values)
testProvider.Encoder().Encode(req, values)
body := strings.NewReader(values.Encode())
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/", body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer))
runWithRandReader(mr.New(mr.NewSource(1)), func() {
op.DeviceAuthorizationHandler(testProvider)(w, r)
})
w := httptest.NewRecorder()
result := w.Result()
runWithRandReader(mr.New(mr.NewSource(1)), func() {
op.DeviceAuthorizationHandler(provider)(w, r)
})
assert.Less(t, result.StatusCode, 300)
result := w.Result()
got, _ := io.ReadAll(result.Body)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
assert.Less(t, result.StatusCode, 300)
got, _ := io.ReadAll(result.Body)
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
})
}
}
func TestParseDeviceCodeRequest(t *testing.T) {

View file

@ -20,15 +20,9 @@ import (
"golang.org/x/text/language"
)
var testProvider op.OpenIDProvider
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
config := &op.Config{
var (
testProvider op.OpenIDProvider
testConfig = &op.Config{
CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true,
@ -40,24 +34,35 @@ func init() {
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormURL: testIssuer + "device",
UserFormPath: "/device",
UserCode: op.UserCodeBase20,
},
}
)
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
storage.RegisterClients(
storage.NativeClient("native"),
storage.WebClient("web", "secret", "https://example.com"),
storage.WebClient("api", "secret"),
)
var err error
testProvider, err = op.NewOpenIDProvider(testIssuer, config,
testProvider = newTestProvider(testConfig)
}
func newTestProvider(config *op.Config) op.OpenIDProvider {
provider, err := op.NewOpenIDProvider(testIssuer, config,
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
)
if err != nil {
panic(err)
}
return provider
}
type routesTestStorage interface {