feat: get issuer from context for device auth (#363)
* feat: get issuer from context for device auth * use distinct UserFormURL and UserFormPath - Properly deprecate UserFormURL and default to old behaviour, to prevent breaking change. - Refactor unit tests to test both cases. * update example
This commit is contained in:
parent
97bc09583d
commit
44f8403574
4 changed files with 90 additions and 32 deletions
|
@ -107,7 +107,7 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider,
|
|||
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||
Lifetime: 5 * time.Minute,
|
||||
PollInterval: 5 * time.Second,
|
||||
UserFormURL: issuer + "device",
|
||||
UserFormPath: "/device",
|
||||
UserCode: op.UserCodeBase20,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -18,7 +19,14 @@ import (
|
|||
type DeviceAuthorizationConfig struct {
|
||||
Lifetime time.Duration
|
||||
PollInterval time.Duration
|
||||
UserFormURL string // the URL where the user must go to authorize the device
|
||||
|
||||
// UserFormURL is the complete URL where the user must go to authorize the device.
|
||||
// Deprecated: use UserFormPath instead.
|
||||
UserFormURL string
|
||||
|
||||
// UserFormPath is the path where the user must go to authorize the device.
|
||||
// The hostname for the URL is taken from the request by IssuerFromContext.
|
||||
UserFormPath string
|
||||
UserCode UserCodeConfig
|
||||
}
|
||||
|
||||
|
@ -82,15 +90,28 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
|
|||
return err
|
||||
}
|
||||
|
||||
var verification *url.URL
|
||||
if config.UserFormURL != "" {
|
||||
if verification, err = url.Parse(config.UserFormURL); err != nil {
|
||||
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
|
||||
}
|
||||
} else {
|
||||
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
|
||||
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
|
||||
}
|
||||
verification.Path = config.UserFormPath
|
||||
}
|
||||
|
||||
response := &oidc.DeviceAuthorizationResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: config.UserFormURL,
|
||||
VerificationURI: verification.String(),
|
||||
ExpiresIn: int(config.Lifetime / time.Second),
|
||||
Interval: int(config.PollInterval / time.Second),
|
||||
}
|
||||
|
||||
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
|
||||
verification.RawQuery = "user_code=" + userCode
|
||||
response.VerificationURIComplete = verification.String()
|
||||
|
||||
httphelper.MarshalJSON(w, response)
|
||||
return nil
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
|
@ -20,29 +21,60 @@ import (
|
|||
)
|
||||
|
||||
func Test_deviceAuthorizationHandler(t *testing.T) {
|
||||
req := &oidc.DeviceAuthorizationRequest{
|
||||
Scopes: []string{"foo", "bar"},
|
||||
ClientID: "web",
|
||||
type conf struct {
|
||||
UserFormURL string
|
||||
UserFormPath string
|
||||
}
|
||||
values := make(url.Values)
|
||||
testProvider.Encoder().Encode(req, values)
|
||||
body := strings.NewReader(values.Encode())
|
||||
tests := []struct {
|
||||
name string
|
||||
conf conf
|
||||
}{
|
||||
{
|
||||
name: "UserFormURL",
|
||||
conf: conf{
|
||||
UserFormURL: "https://localhost:9998/device",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UserFormPath",
|
||||
conf: conf{
|
||||
UserFormPath: "/device",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conf := gu.PtrCopy(testConfig)
|
||||
conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL
|
||||
conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath
|
||||
provider := newTestProvider(conf)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req := &oidc.DeviceAuthorizationRequest{
|
||||
Scopes: []string{"foo", "bar"},
|
||||
ClientID: "web",
|
||||
}
|
||||
values := make(url.Values)
|
||||
testProvider.Encoder().Encode(req, values)
|
||||
body := strings.NewReader(values.Encode())
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer))
|
||||
|
||||
runWithRandReader(mr.New(mr.NewSource(1)), func() {
|
||||
op.DeviceAuthorizationHandler(testProvider)(w, r)
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
result := w.Result()
|
||||
runWithRandReader(mr.New(mr.NewSource(1)), func() {
|
||||
op.DeviceAuthorizationHandler(provider)(w, r)
|
||||
})
|
||||
|
||||
assert.Less(t, result.StatusCode, 300)
|
||||
result := w.Result()
|
||||
|
||||
got, _ := io.ReadAll(result.Body)
|
||||
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
|
||||
assert.Less(t, result.StatusCode, 300)
|
||||
|
||||
got, _ := io.ReadAll(result.Body)
|
||||
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeviceCodeRequest(t *testing.T) {
|
||||
|
|
|
@ -20,15 +20,9 @@ import (
|
|||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
var testProvider op.OpenIDProvider
|
||||
|
||||
const (
|
||||
testIssuer = "https://localhost:9998/"
|
||||
pathLoggedOut = "/logged-out"
|
||||
)
|
||||
|
||||
func init() {
|
||||
config := &op.Config{
|
||||
var (
|
||||
testProvider op.OpenIDProvider
|
||||
testConfig = &op.Config{
|
||||
CryptoKey: sha256.Sum256([]byte("test")),
|
||||
DefaultLogoutRedirectURI: pathLoggedOut,
|
||||
CodeMethodS256: true,
|
||||
|
@ -40,24 +34,35 @@ func init() {
|
|||
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||
Lifetime: 5 * time.Minute,
|
||||
PollInterval: 5 * time.Second,
|
||||
UserFormURL: testIssuer + "device",
|
||||
UserFormPath: "/device",
|
||||
UserCode: op.UserCodeBase20,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
testIssuer = "https://localhost:9998/"
|
||||
pathLoggedOut = "/logged-out"
|
||||
)
|
||||
|
||||
func init() {
|
||||
storage.RegisterClients(
|
||||
storage.NativeClient("native"),
|
||||
storage.WebClient("web", "secret", "https://example.com"),
|
||||
storage.WebClient("api", "secret"),
|
||||
)
|
||||
|
||||
var err error
|
||||
testProvider, err = op.NewOpenIDProvider(testIssuer, config,
|
||||
testProvider = newTestProvider(testConfig)
|
||||
}
|
||||
|
||||
func newTestProvider(config *op.Config) op.OpenIDProvider {
|
||||
provider, err := op.NewOpenIDProvider(testIssuer, config,
|
||||
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
type routesTestStorage interface {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue