feat: get issuer from context for device auth (#363)
* feat: get issuer from context for device auth * use distinct UserFormURL and UserFormPath - Properly deprecate UserFormURL and default to old behaviour, to prevent breaking change. - Refactor unit tests to test both cases. * update example
This commit is contained in:
parent
97bc09583d
commit
44f8403574
4 changed files with 90 additions and 32 deletions
|
@ -107,7 +107,7 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider,
|
||||||
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||||
Lifetime: 5 * time.Minute,
|
Lifetime: 5 * time.Minute,
|
||||||
PollInterval: 5 * time.Second,
|
PollInterval: 5 * time.Second,
|
||||||
UserFormURL: issuer + "device",
|
UserFormPath: "/device",
|
||||||
UserCode: op.UserCodeBase20,
|
UserCode: op.UserCodeBase20,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -18,7 +19,14 @@ import (
|
||||||
type DeviceAuthorizationConfig struct {
|
type DeviceAuthorizationConfig struct {
|
||||||
Lifetime time.Duration
|
Lifetime time.Duration
|
||||||
PollInterval time.Duration
|
PollInterval time.Duration
|
||||||
UserFormURL string // the URL where the user must go to authorize the device
|
|
||||||
|
// UserFormURL is the complete URL where the user must go to authorize the device.
|
||||||
|
// Deprecated: use UserFormPath instead.
|
||||||
|
UserFormURL string
|
||||||
|
|
||||||
|
// UserFormPath is the path where the user must go to authorize the device.
|
||||||
|
// The hostname for the URL is taken from the request by IssuerFromContext.
|
||||||
|
UserFormPath string
|
||||||
UserCode UserCodeConfig
|
UserCode UserCodeConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,15 +90,28 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var verification *url.URL
|
||||||
|
if config.UserFormURL != "" {
|
||||||
|
if verification, err = url.Parse(config.UserFormURL); err != nil {
|
||||||
|
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
|
||||||
|
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
|
||||||
|
}
|
||||||
|
verification.Path = config.UserFormPath
|
||||||
|
}
|
||||||
|
|
||||||
response := &oidc.DeviceAuthorizationResponse{
|
response := &oidc.DeviceAuthorizationResponse{
|
||||||
DeviceCode: deviceCode,
|
DeviceCode: deviceCode,
|
||||||
UserCode: userCode,
|
UserCode: userCode,
|
||||||
VerificationURI: config.UserFormURL,
|
VerificationURI: verification.String(),
|
||||||
ExpiresIn: int(config.Lifetime / time.Second),
|
ExpiresIn: int(config.Lifetime / time.Second),
|
||||||
Interval: int(config.PollInterval / time.Second),
|
Interval: int(config.PollInterval / time.Second),
|
||||||
}
|
}
|
||||||
|
|
||||||
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
|
verification.RawQuery = "user_code=" + userCode
|
||||||
|
response.VerificationURIComplete = verification.String()
|
||||||
|
|
||||||
httphelper.MarshalJSON(w, response)
|
httphelper.MarshalJSON(w, response)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/muhlemmer/gu"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
@ -20,29 +21,60 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_deviceAuthorizationHandler(t *testing.T) {
|
func Test_deviceAuthorizationHandler(t *testing.T) {
|
||||||
req := &oidc.DeviceAuthorizationRequest{
|
type conf struct {
|
||||||
Scopes: []string{"foo", "bar"},
|
UserFormURL string
|
||||||
ClientID: "web",
|
UserFormPath string
|
||||||
}
|
}
|
||||||
values := make(url.Values)
|
tests := []struct {
|
||||||
testProvider.Encoder().Encode(req, values)
|
name string
|
||||||
body := strings.NewReader(values.Encode())
|
conf conf
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "UserFormURL",
|
||||||
|
conf: conf{
|
||||||
|
UserFormURL: "https://localhost:9998/device",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserFormPath",
|
||||||
|
conf: conf{
|
||||||
|
UserFormPath: "/device",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
conf := gu.PtrCopy(testConfig)
|
||||||
|
conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL
|
||||||
|
conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath
|
||||||
|
provider := newTestProvider(conf)
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", body)
|
req := &oidc.DeviceAuthorizationRequest{
|
||||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
Scopes: []string{"foo", "bar"},
|
||||||
|
ClientID: "web",
|
||||||
|
}
|
||||||
|
values := make(url.Values)
|
||||||
|
testProvider.Encoder().Encode(req, values)
|
||||||
|
body := strings.NewReader(values.Encode())
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
r := httptest.NewRequest(http.MethodPost, "/", body)
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer))
|
||||||
|
|
||||||
runWithRandReader(mr.New(mr.NewSource(1)), func() {
|
w := httptest.NewRecorder()
|
||||||
op.DeviceAuthorizationHandler(testProvider)(w, r)
|
|
||||||
})
|
|
||||||
|
|
||||||
result := w.Result()
|
runWithRandReader(mr.New(mr.NewSource(1)), func() {
|
||||||
|
op.DeviceAuthorizationHandler(provider)(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
assert.Less(t, result.StatusCode, 300)
|
result := w.Result()
|
||||||
|
|
||||||
got, _ := io.ReadAll(result.Body)
|
assert.Less(t, result.StatusCode, 300)
|
||||||
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
|
|
||||||
|
got, _ := io.ReadAll(result.Body)
|
||||||
|
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseDeviceCodeRequest(t *testing.T) {
|
func TestParseDeviceCodeRequest(t *testing.T) {
|
||||||
|
|
|
@ -20,15 +20,9 @@ import (
|
||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testProvider op.OpenIDProvider
|
var (
|
||||||
|
testProvider op.OpenIDProvider
|
||||||
const (
|
testConfig = &op.Config{
|
||||||
testIssuer = "https://localhost:9998/"
|
|
||||||
pathLoggedOut = "/logged-out"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
config := &op.Config{
|
|
||||||
CryptoKey: sha256.Sum256([]byte("test")),
|
CryptoKey: sha256.Sum256([]byte("test")),
|
||||||
DefaultLogoutRedirectURI: pathLoggedOut,
|
DefaultLogoutRedirectURI: pathLoggedOut,
|
||||||
CodeMethodS256: true,
|
CodeMethodS256: true,
|
||||||
|
@ -40,24 +34,35 @@ func init() {
|
||||||
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||||
Lifetime: 5 * time.Minute,
|
Lifetime: 5 * time.Minute,
|
||||||
PollInterval: 5 * time.Second,
|
PollInterval: 5 * time.Second,
|
||||||
UserFormURL: testIssuer + "device",
|
UserFormPath: "/device",
|
||||||
UserCode: op.UserCodeBase20,
|
UserCode: op.UserCodeBase20,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testIssuer = "https://localhost:9998/"
|
||||||
|
pathLoggedOut = "/logged-out"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
storage.RegisterClients(
|
storage.RegisterClients(
|
||||||
storage.NativeClient("native"),
|
storage.NativeClient("native"),
|
||||||
storage.WebClient("web", "secret", "https://example.com"),
|
storage.WebClient("web", "secret", "https://example.com"),
|
||||||
storage.WebClient("api", "secret"),
|
storage.WebClient("api", "secret"),
|
||||||
)
|
)
|
||||||
|
|
||||||
var err error
|
testProvider = newTestProvider(testConfig)
|
||||||
testProvider, err = op.NewOpenIDProvider(testIssuer, config,
|
}
|
||||||
|
|
||||||
|
func newTestProvider(config *op.Config) op.OpenIDProvider {
|
||||||
|
provider, err := op.NewOpenIDProvider(testIssuer, config,
|
||||||
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
|
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
return provider
|
||||||
}
|
}
|
||||||
|
|
||||||
type routesTestStorage interface {
|
type routesTestStorage interface {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue