Merge branch 'main' into main-to-next
This commit is contained in:
commit
8dff7ddee0
27 changed files with 308 additions and 146 deletions
|
@ -67,7 +67,7 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *
|
|||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
|
@ -273,9 +273,9 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
|
|||
return scopes, nil
|
||||
}
|
||||
|
||||
// checkURIAginstRedirects just checks aginst the valid redirect URIs and ignores
|
||||
// checkURIAgainstRedirects just checks aginst the valid redirect URIs and ignores
|
||||
// other factors.
|
||||
func checkURIAginstRedirects(client Client, uri string) error {
|
||||
func checkURIAgainstRedirects(client Client, uri string) error {
|
||||
if str.Contains(client.RedirectURIs(), uri) {
|
||||
return nil
|
||||
}
|
||||
|
@ -302,12 +302,12 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
|
|||
"Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
|
||||
}
|
||||
if strings.HasPrefix(uri, "https://") {
|
||||
return checkURIAginstRedirects(client, uri)
|
||||
return checkURIAgainstRedirects(client, uri)
|
||||
}
|
||||
if client.ApplicationType() == ApplicationTypeNative {
|
||||
return validateAuthReqRedirectURINative(client, uri, responseType)
|
||||
}
|
||||
if err := checkURIAginstRedirects(client, uri); err != nil {
|
||||
if err := checkURIAgainstRedirects(client, uri); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.HasPrefix(uri, "http://") {
|
||||
|
@ -328,7 +328,7 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
|
|||
func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error {
|
||||
parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri)
|
||||
isCustomSchema := !strings.HasPrefix(uri, "http://")
|
||||
if err := checkURIAginstRedirects(client, uri); err == nil {
|
||||
if err := checkURIAgainstRedirects(client, uri); err == nil {
|
||||
if client.DevMode() {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
tu "github.com/zitadel/oidc/v3/internal/testutil"
|
||||
|
@ -19,60 +20,34 @@ import (
|
|||
"github.com/zitadel/schema"
|
||||
)
|
||||
|
||||
//
|
||||
// TOOD: tests will be implemented in branch for service accounts
|
||||
// func TestAuthorize(t *testing.T) {
|
||||
// // testCallback := func(t *testing.T, clienID string) callbackHandler {
|
||||
// // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) {
|
||||
// // // require.Equal(t, clientID, client.)
|
||||
// // }
|
||||
// // }
|
||||
// // testErr := func(t *testing.T, expected error) errorHandler {
|
||||
// // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
|
||||
// // require.Equal(t, expected, err)
|
||||
// // }
|
||||
// // }
|
||||
// type args struct {
|
||||
// w http.ResponseWriter
|
||||
// r *http.Request
|
||||
// authorizer op.Authorizer
|
||||
// }
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// args args
|
||||
// }{
|
||||
// {
|
||||
// "parsing fails",
|
||||
// args{
|
||||
// httptest.NewRecorder(),
|
||||
// &http.Request{Method: "POST", Body: nil},
|
||||
// mock.NewAuthorizerExpectValid(t, true),
|
||||
// // testCallback(t, ""),
|
||||
// // testErr(t, ErrInvalidRequest("cannot parse form")),
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// "decoding fails",
|
||||
// args{
|
||||
// httptest.NewRecorder(),
|
||||
// func() *http.Request {
|
||||
// r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo"))
|
||||
// r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
// return r
|
||||
// }(),
|
||||
// mock.NewAuthorizerExpectValid(t, true),
|
||||
// // testCallback(t, ""),
|
||||
// // testErr(t, ErrInvalidRequest("cannot parse auth request")),
|
||||
// },
|
||||
// },
|
||||
// // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}},
|
||||
// }
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer)
|
||||
// })
|
||||
// }
|
||||
//}
|
||||
func TestAuthorize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
expect func(a *mock.MockAuthorizerMockRecorder)
|
||||
}{
|
||||
{
|
||||
name: "parse error", // used to panic, see issue #315
|
||||
req: httptest.NewRequest(http.MethodPost, "/?;", nil),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
authorizer := mock.NewMockAuthorizer(gomock.NewController(t))
|
||||
|
||||
expect := authorizer.EXPECT()
|
||||
expect.Decoder().Return(schema.NewDecoder())
|
||||
expect.Encoder().Return(schema.NewEncoder())
|
||||
|
||||
if tt.expect != nil {
|
||||
tt.expect(expect)
|
||||
}
|
||||
|
||||
op.Authorize(w, tt.req, authorizer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuthorizeRequest(t *testing.T) {
|
||||
type args struct {
|
||||
|
|
|
@ -56,6 +56,12 @@ type Client interface {
|
|||
// interpretation. Redirect URIs that match either the non-glob version or the
|
||||
// glob version will be accepted. Glob URIs are only partially supported for native
|
||||
// clients: "http://" is not allowed except for loopback or in dev mode.
|
||||
//
|
||||
// Note that globbing / wildcards are not permitted by the OIDC
|
||||
// standard and implementing this interface can have security implications.
|
||||
// It is advised to only return a client of this type in rare cases,
|
||||
// such as DevMode for the client being enabled.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
type HasRedirectGlobs interface {
|
||||
RedirectURIGlobs() []string
|
||||
PostLogoutRedirectURIGlobs() []string
|
||||
|
@ -145,21 +151,30 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
|
|||
}
|
||||
|
||||
data := new(clientData)
|
||||
if err = p.Decoder().Decode(data, r.PostForm); err != nil {
|
||||
if err = p.Decoder().Decode(data, r.Form); err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
JWTProfile, ok := p.(ClientJWTProfile)
|
||||
if ok {
|
||||
if ok && data.ClientAssertion != "" {
|
||||
// if JWTProfile is supported and client sent an assertion, check it and use it as response
|
||||
// regardless if it succeeded or failed
|
||||
clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
|
||||
return clientID, err == nil, err
|
||||
}
|
||||
if !ok || errors.Is(err, ErrNoClientCredentials) {
|
||||
clientID, err = ClientBasicAuth(r, p.Storage())
|
||||
}
|
||||
// try basic auth
|
||||
clientID, err = ClientBasicAuth(r, p.Storage())
|
||||
// if that succeeded, use it
|
||||
if err == nil {
|
||||
return clientID, true, nil
|
||||
}
|
||||
// if the client did not send a Basic Auth Header, ignore the `ErrNoClientCredentials`
|
||||
// but return other errors immediately
|
||||
if err != nil && !errors.Is(err, ErrNoClientCredentials) {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
// if the client did not authenticate (public clients) it must at least send a client_id
|
||||
if data.ClientID == "" {
|
||||
return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -18,7 +19,14 @@ import (
|
|||
type DeviceAuthorizationConfig struct {
|
||||
Lifetime time.Duration
|
||||
PollInterval time.Duration
|
||||
UserFormURL string // the URL where the user must go to authorize the device
|
||||
|
||||
// UserFormURL is the complete URL where the user must go to authorize the device.
|
||||
// Deprecated: use UserFormPath instead.
|
||||
UserFormURL string
|
||||
|
||||
// UserFormPath is the path where the user must go to authorize the device.
|
||||
// The hostname for the URL is taken from the request by IssuerFromContext.
|
||||
UserFormPath string
|
||||
UserCode UserCodeConfig
|
||||
}
|
||||
|
||||
|
@ -82,15 +90,28 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
|
|||
return err
|
||||
}
|
||||
|
||||
var verification *url.URL
|
||||
if config.UserFormURL != "" {
|
||||
if verification, err = url.Parse(config.UserFormURL); err != nil {
|
||||
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
|
||||
}
|
||||
} else {
|
||||
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
|
||||
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
|
||||
}
|
||||
verification.Path = config.UserFormPath
|
||||
}
|
||||
|
||||
response := &oidc.DeviceAuthorizationResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: config.UserFormURL,
|
||||
VerificationURI: verification.String(),
|
||||
ExpiresIn: int(config.Lifetime / time.Second),
|
||||
Interval: int(config.PollInterval / time.Second),
|
||||
}
|
||||
|
||||
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode)
|
||||
verification.RawQuery = "user_code=" + userCode
|
||||
response.VerificationURIComplete = verification.String()
|
||||
|
||||
httphelper.MarshalJSON(w, response)
|
||||
return nil
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
@ -20,29 +21,60 @@ import (
|
|||
)
|
||||
|
||||
func Test_deviceAuthorizationHandler(t *testing.T) {
|
||||
req := &oidc.DeviceAuthorizationRequest{
|
||||
Scopes: []string{"foo", "bar"},
|
||||
ClientID: "web",
|
||||
type conf struct {
|
||||
UserFormURL string
|
||||
UserFormPath string
|
||||
}
|
||||
values := make(url.Values)
|
||||
testProvider.Encoder().Encode(req, values)
|
||||
body := strings.NewReader(values.Encode())
|
||||
tests := []struct {
|
||||
name string
|
||||
conf conf
|
||||
}{
|
||||
{
|
||||
name: "UserFormURL",
|
||||
conf: conf{
|
||||
UserFormURL: "https://localhost:9998/device",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UserFormPath",
|
||||
conf: conf{
|
||||
UserFormPath: "/device",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conf := gu.PtrCopy(testConfig)
|
||||
conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL
|
||||
conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath
|
||||
provider := newTestProvider(conf)
|
||||
|
||||
r := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req := &oidc.DeviceAuthorizationRequest{
|
||||
Scopes: []string{"foo", "bar"},
|
||||
ClientID: "web",
|
||||
}
|
||||
values := make(url.Values)
|
||||
testProvider.Encoder().Encode(req, values)
|
||||
body := strings.NewReader(values.Encode())
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodPost, "/", body)
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer))
|
||||
|
||||
runWithRandReader(mr.New(mr.NewSource(1)), func() {
|
||||
op.DeviceAuthorizationHandler(testProvider)(w, r)
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
result := w.Result()
|
||||
runWithRandReader(mr.New(mr.NewSource(1)), func() {
|
||||
op.DeviceAuthorizationHandler(provider)(w, r)
|
||||
})
|
||||
|
||||
assert.Less(t, result.StatusCode, 300)
|
||||
result := w.Result()
|
||||
|
||||
got, _ := io.ReadAll(result.Body)
|
||||
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
|
||||
assert.Less(t, result.StatusCode, 300)
|
||||
|
||||
got, _ := io.ReadAll(result.Body)
|
||||
assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDeviceCodeRequest(t *testing.T) {
|
||||
|
|
10
pkg/op/op.go
10
pkg/op/op.go
|
@ -480,6 +480,16 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
o.endpoints.DeviceAuthorization = endpoint
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
o.endpoints.Authorization = auth
|
||||
|
|
|
@ -20,15 +20,9 @@ import (
|
|||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
var testProvider op.OpenIDProvider
|
||||
|
||||
const (
|
||||
testIssuer = "https://localhost:9998/"
|
||||
pathLoggedOut = "/logged-out"
|
||||
)
|
||||
|
||||
func init() {
|
||||
config := &op.Config{
|
||||
var (
|
||||
testProvider op.OpenIDProvider
|
||||
testConfig = &op.Config{
|
||||
CryptoKey: sha256.Sum256([]byte("test")),
|
||||
DefaultLogoutRedirectURI: pathLoggedOut,
|
||||
CodeMethodS256: true,
|
||||
|
@ -40,24 +34,35 @@ func init() {
|
|||
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||
Lifetime: 5 * time.Minute,
|
||||
PollInterval: 5 * time.Second,
|
||||
UserFormURL: testIssuer + "device",
|
||||
UserFormPath: "/device",
|
||||
UserCode: op.UserCodeBase20,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
testIssuer = "https://localhost:9998/"
|
||||
pathLoggedOut = "/logged-out"
|
||||
)
|
||||
|
||||
func init() {
|
||||
storage.RegisterClients(
|
||||
storage.NativeClient("native"),
|
||||
storage.WebClient("web", "secret", "https://example.com"),
|
||||
storage.WebClient("api", "secret"),
|
||||
)
|
||||
|
||||
var err error
|
||||
testProvider, err = op.NewOpenIDProvider(testIssuer, config,
|
||||
testProvider = newTestProvider(testConfig)
|
||||
}
|
||||
|
||||
func newTestProvider(config *op.Config) op.OpenIDProvider {
|
||||
provider, err := op.NewOpenIDProvider(testIssuer, config,
|
||||
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
type routesTestStorage interface {
|
||||
|
|
|
@ -113,6 +113,8 @@ type OPStorage interface {
|
|||
// handle the current request.
|
||||
GetClientByClientID(ctx context.Context, clientID string) (Client, error)
|
||||
AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
|
||||
// SetUserinfoFromScopes is deprecated and should have an empty implementation for now.
|
||||
// Implement SetUserinfoFromRequest instead.
|
||||
SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error
|
||||
SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error
|
||||
SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error
|
||||
|
@ -127,6 +129,13 @@ type JWTProfileTokenStorage interface {
|
|||
JWTProfileTokenType(ctx context.Context, request TokenRequest) (AccessTokenType, error)
|
||||
}
|
||||
|
||||
// CanSetUserinfoFromRequest is an optional additional interface that may be implemented by
|
||||
// implementors of Storage. It allows additional data to be set in id_tokens based on the
|
||||
// request.
|
||||
type CanSetUserinfoFromRequest interface {
|
||||
SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request IDTokenRequest, scopes []string) error
|
||||
}
|
||||
|
||||
// Storage is a required parameter for NewOpenIDProvider(). In addition to the
|
||||
// embedded interfaces below, if the passed Storage implements ClientCredentialsStorage
|
||||
// then the grant type "client_credentials" will be supported. In that case, the access
|
||||
|
|
|
@ -190,6 +190,12 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if fromRequest, ok := storage.(CanSetUserinfoFromRequest); ok {
|
||||
err := fromRequest.SetUserinfoFromRequest(ctx, userInfo, request, scopes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
claims.SetUserInfo(userInfo)
|
||||
}
|
||||
if code != "" {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue