extract client authentication from introspection
reuse the client authentication code for device authorization and introspection.
This commit is contained in:
parent
0f9ec46aaa
commit
f26e155208
4 changed files with 571 additions and 58 deletions
122
pkg/op/client.go
122
pkg/op/client.go
|
@ -1,8 +1,14 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -57,3 +63,119 @@ func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseT
|
||||||
func IsConfidentialType(c Client) bool {
|
func IsConfidentialType(c Client) bool {
|
||||||
return c.ApplicationType() == ApplicationTypeWeb
|
return c.ApplicationType() == ApplicationTypeWeb
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidAuthHeader = errors.New("invalid basic auth header")
|
||||||
|
ErrNoClientCredentials = errors.New("no client credentials provided")
|
||||||
|
ErrMissingClientID = errors.New("client_id missing from request")
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClientJWTProfile interface {
|
||||||
|
JWTProfileVerifier(context.Context) JWTProfileVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
|
||||||
|
if ca.ClientAssertion == "" {
|
||||||
|
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx))
|
||||||
|
if err != nil {
|
||||||
|
return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed")
|
||||||
|
}
|
||||||
|
return profile.Issuer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) {
|
||||||
|
clientID, clientSecret, ok := r.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
|
||||||
|
}
|
||||||
|
clientID, err = url.QueryUnescape(clientID)
|
||||||
|
if err != nil {
|
||||||
|
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
|
||||||
|
}
|
||||||
|
clientSecret, err = url.QueryUnescape(clientSecret)
|
||||||
|
if err != nil {
|
||||||
|
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
|
||||||
|
}
|
||||||
|
if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
|
||||||
|
return "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||||
|
}
|
||||||
|
return clientID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientProvider interface {
|
||||||
|
Decoder() httphelper.Decoder
|
||||||
|
Storage() Storage
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientData struct {
|
||||||
|
ClientID string `schema:"client_id"`
|
||||||
|
oidc.ClientAssertionParams
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientIDFromRequest parses the request form and tries to obtain the client ID
|
||||||
|
// and reports if it is authenticated, using a JWT or static client secrets over
|
||||||
|
// http basic auth.
|
||||||
|
//
|
||||||
|
// If the Provider implements IntrospectorJWTProfile and "client_assertion" is
|
||||||
|
// present in the form data, JWT assertion will be verified and the
|
||||||
|
// client ID is taken from there.
|
||||||
|
// If any of them is absent, basic auth is attempted.
|
||||||
|
// In absence of basic auth data, the unauthenticated client id from the form
|
||||||
|
// data is returned.
|
||||||
|
//
|
||||||
|
// If no client id can be obtained by any method, oidc.ErrInvalidClient
|
||||||
|
// is returned with ErrMissingClientID wrapped in it.
|
||||||
|
func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) {
|
||||||
|
err = r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := new(clientData)
|
||||||
|
if err = p.Decoder().Decode(data, r.PostForm); err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
JWTProfile, ok := p.(ClientJWTProfile)
|
||||||
|
if ok {
|
||||||
|
clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
|
||||||
|
}
|
||||||
|
if !ok || errors.Is(err, ErrNoClientCredentials) {
|
||||||
|
clientID, err = ClientBasicAuth(r, p.Storage())
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
return clientID, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.ClientID == "" {
|
||||||
|
return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
|
||||||
|
}
|
||||||
|
return data.ClientID, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientFromRequest wraps ClientIDFromRequest and obtains the Client from storage.
|
||||||
|
// If the client id was not authenticated, the client from storage does not have
|
||||||
|
// oidc.AuthMethodNone set, an error is returned.
|
||||||
|
func ClientFromRequest(r *http.Request, p ClientProvider) (Client, error) {
|
||||||
|
clientID, authenticated, err := ClientIDFromRequest(r, p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := p.Storage().GetClientByClientID(r.Context(), clientID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !authenticated {
|
||||||
|
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
|
||||||
|
return nil, oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
||||||
|
WithDescription(fmt.Sprintf("required client auth method: %s", m))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, err
|
||||||
|
}
|
||||||
|
|
392
pkg/op/client_test.go
Normal file
392
pkg/op/client_test.go
Normal file
|
@ -0,0 +1,392 @@
|
||||||
|
package op_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/gorilla/schema"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/op"
|
||||||
|
"github.com/zitadel/oidc/v2/pkg/op/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testClientJWTProfile struct{}
|
||||||
|
|
||||||
|
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
|
||||||
|
|
||||||
|
func TestClientJWTAuth(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
ca oidc.ClientAssertionParams
|
||||||
|
verifier op.ClientJWTProfile
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantClientID string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty assertion",
|
||||||
|
args: args{
|
||||||
|
context.Background(),
|
||||||
|
oidc.ClientAssertionParams{},
|
||||||
|
testClientJWTProfile{},
|
||||||
|
},
|
||||||
|
wantErr: op.ErrNoClientCredentials,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "verification error",
|
||||||
|
args: args{
|
||||||
|
context.Background(),
|
||||||
|
oidc.ClientAssertionParams{
|
||||||
|
ClientAssertion: "foo",
|
||||||
|
},
|
||||||
|
testClientJWTProfile{},
|
||||||
|
},
|
||||||
|
wantErr: oidc.ErrParse,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier)
|
||||||
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
|
assert.Equal(t, tt.wantClientID, gotClientID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientBasicAuth(t *testing.T) {
|
||||||
|
errWrong := errors.New("wrong secret")
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args *args
|
||||||
|
storage op.Storage
|
||||||
|
wantClientID string
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no args",
|
||||||
|
wantErr: op.ErrNoClientCredentials,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username unescape err",
|
||||||
|
args: &args{
|
||||||
|
username: "%",
|
||||||
|
password: "bar",
|
||||||
|
},
|
||||||
|
wantErr: op.ErrInvalidAuthHeader,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "password unescape err",
|
||||||
|
args: &args{
|
||||||
|
username: "foo",
|
||||||
|
password: "%",
|
||||||
|
},
|
||||||
|
wantErr: op.ErrInvalidAuthHeader,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auth error",
|
||||||
|
args: &args{
|
||||||
|
username: "foo",
|
||||||
|
password: "wrong",
|
||||||
|
},
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong)
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
wantErr: errWrong,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auth error",
|
||||||
|
args: &args{
|
||||||
|
username: "foo",
|
||||||
|
password: "bar",
|
||||||
|
},
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
wantClientID: "foo",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/foo", nil)
|
||||||
|
if tt.args != nil {
|
||||||
|
r.SetBasicAuth(tt.args.username, tt.args.password)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotClientID, err := op.ClientBasicAuth(r, tt.storage)
|
||||||
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
|
assert.Equal(t, tt.wantClientID, gotClientID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type errReader struct{}
|
||||||
|
|
||||||
|
func (errReader) Read([]byte) (int, error) {
|
||||||
|
return 0, io.ErrNoProgress
|
||||||
|
}
|
||||||
|
|
||||||
|
type testClientProvider struct {
|
||||||
|
storage op.Storage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (testClientProvider) Decoder() httphelper.Decoder {
|
||||||
|
return schema.NewDecoder()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p testClientProvider) Storage() op.Storage {
|
||||||
|
return p.storage
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientIDFromRequest(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
body io.Reader
|
||||||
|
p op.ClientProvider
|
||||||
|
}
|
||||||
|
type basicAuth struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
basicAuth *basicAuth
|
||||||
|
wantClientID string
|
||||||
|
wantAuthenticated bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "parse error",
|
||||||
|
args: args{
|
||||||
|
body: errReader{},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthenticated",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{
|
||||||
|
"client_id": []string{"foo"},
|
||||||
|
}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: mock.NewStorage(t),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantClientID: "foo",
|
||||||
|
wantAuthenticated: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthenticated",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
basicAuth: &basicAuth{
|
||||||
|
username: "foo",
|
||||||
|
password: "bar",
|
||||||
|
},
|
||||||
|
wantClientID: "foo",
|
||||||
|
wantAuthenticated: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing client id",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: mock.NewStorage(t),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
if tt.basicAuth != nil {
|
||||||
|
r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.wantClientID, gotClientID)
|
||||||
|
assert.Equal(t, tt.wantAuthenticated, gotAuthenticated)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientFromRequest(t *testing.T) {
|
||||||
|
publicClient := func() op.Client {
|
||||||
|
c := mock.NewMockClient(gomock.NewController(t))
|
||||||
|
c.EXPECT().AuthMethod().Return(oidc.AuthMethodNone)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
privateClient := func() op.Client {
|
||||||
|
c := mock.NewMockClient(gomock.NewController(t))
|
||||||
|
c.EXPECT().AuthMethod().Return(oidc.AuthMethodPrivateKeyJWT)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
body io.Reader
|
||||||
|
p op.ClientProvider
|
||||||
|
}
|
||||||
|
type basicAuth struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
basicAuth *basicAuth
|
||||||
|
wantClient bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing client id",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: mock.NewStorage(t),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "get client error",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||||
|
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(nil, errors.New("something"))
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
basicAuth: &basicAuth{
|
||||||
|
username: "foo",
|
||||||
|
password: "bar",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "authenticated",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||||
|
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(mock.NewClient(t), nil)
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
basicAuth: &basicAuth{
|
||||||
|
username: "foo",
|
||||||
|
password: "bar",
|
||||||
|
},
|
||||||
|
wantClient: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "public",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{
|
||||||
|
"client_id": []string{"foo"},
|
||||||
|
}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(publicClient(), nil)
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantClient: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "false public",
|
||||||
|
args: args{
|
||||||
|
body: strings.NewReader(
|
||||||
|
url.Values{
|
||||||
|
"client_id": []string{"foo"},
|
||||||
|
}.Encode(),
|
||||||
|
),
|
||||||
|
p: testClientProvider{
|
||||||
|
storage: func() op.Storage {
|
||||||
|
s := mock.NewMockStorage(gomock.NewController(t))
|
||||||
|
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(privateClient(), nil)
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
|
||||||
|
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
if tt.basicAuth != nil {
|
||||||
|
r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := op.ClientFromRequest(r, tt.args.p)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
if tt.wantClient {
|
||||||
|
assert.NotNil(t, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -48,41 +48,38 @@ var (
|
||||||
|
|
||||||
func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
|
func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
DeviceAuthorization(w, r, o)
|
if err := DeviceAuthorization(w, r, o); err != nil {
|
||||||
|
RequestError(w, r, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) {
|
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
|
||||||
storage, err := assertDeviceStorage(o.Storage())
|
storage, err := assertDeviceStorage(o.Storage())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := ParseDeviceCodeRequest(r, o.Decoder())
|
req, err := ParseDeviceCodeRequest(r, o)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config := o.DeviceAuthorization()
|
config := o.DeviceAuthorization()
|
||||||
|
|
||||||
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
|
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount)
|
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expires := time.Now().Add(time.Duration(config.Lifetime) * time.Second)
|
expires := time.Now().Add(time.Duration(config.Lifetime) * time.Second)
|
||||||
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
|
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
response := &oidc.DeviceAuthorizationResponse{
|
response := &oidc.DeviceAuthorizationResponse{
|
||||||
|
@ -95,19 +92,22 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
|
||||||
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode)
|
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode)
|
||||||
|
|
||||||
httphelper.MarshalJSON(w, response)
|
httphelper.MarshalJSON(w, response)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseDeviceCodeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.DeviceAuthorizationRequest, error) {
|
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
|
||||||
if err := r.ParseForm(); err != nil {
|
clientID, _, err := ClientIDFromRequest(r, o)
|
||||||
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
devReq := new(oidc.DeviceAuthorizationRequest)
|
req := new(oidc.DeviceAuthorizationRequest)
|
||||||
if err := decoder.Decode(devReq, r.Form); err != nil {
|
if err := o.Decoder().Decode(req, r.Form); err != nil {
|
||||||
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
|
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
|
||||||
}
|
}
|
||||||
|
req.ClientID = clientID
|
||||||
|
|
||||||
return devReq, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 16 bytes gives 128 bit of entropy.
|
// 16 bytes gives 128 bit of entropy.
|
||||||
|
@ -167,35 +167,54 @@ func (r *deviceAccessTokenRequest) GetScopes() []string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
req := new(oidc.DeviceAccessTokenRequest)
|
if err := deviceAccessToken(w, r, exchanger); err != nil {
|
||||||
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
|
|
||||||
RequestError(w, r, err)
|
RequestError(w, r, err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
|
||||||
// use a limited context timeout shorter as the default
|
// use a limited context timeout shorter as the default
|
||||||
// poll interval of 5 seconds.
|
// poll interval of 5 seconds.
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
client, err := ClientFromRequest(r, exchanger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
state, err := CheckDeviceAuthorizationState(ctx, req, exchanger)
|
state, err := CheckDeviceAuthorizationState(ctx, req, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenRequest := &deviceAccessTokenRequest{
|
tokenRequest := &deviceAccessTokenRequest{
|
||||||
subject: state.Subject,
|
subject: state.Subject,
|
||||||
audience: []string{req.ClientID},
|
audience: []string{req.ClientID},
|
||||||
scopes: state.Scopes,
|
scopes: state.Scopes,
|
||||||
}
|
}
|
||||||
|
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
|
||||||
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, &jwtProfileClient{id: req.ClientID})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RequestError(w, r, err)
|
return err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
httphelper.MarshalJSON(w, resp)
|
httphelper.MarshalJSON(w, resp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
|
||||||
|
req := new(struct {
|
||||||
|
oidc.DeviceAccessTokenRequest
|
||||||
|
})
|
||||||
|
err := exchanger.Decoder().Decode(req, r.PostForm)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &req.DeviceAccessTokenRequest, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
|
|
||||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
|
@ -50,38 +49,19 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
|
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
|
||||||
err = r.ParseForm()
|
clientID, authenticated, err := ClientIDFromRequest(r, introspector)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", errors.New("unable to parse request")
|
return "", "", err
|
||||||
}
|
}
|
||||||
req := new(struct {
|
if !authenticated {
|
||||||
oidc.IntrospectionRequest
|
return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
|
||||||
oidc.ClientAssertionParams
|
}
|
||||||
})
|
|
||||||
|
req := new(oidc.IntrospectionRequest)
|
||||||
err = introspector.Decoder().Decode(req, r.Form)
|
err = introspector.Decoder().Decode(req, r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", errors.New("unable to parse request")
|
return "", "", errors.New("unable to parse request")
|
||||||
}
|
}
|
||||||
if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" {
|
|
||||||
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier(r.Context()))
|
return req.Token, clientID, nil
|
||||||
if err == nil {
|
|
||||||
return req.Token, profile.Issuer, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
clientID, clientSecret, ok := r.BasicAuth()
|
|
||||||
if ok {
|
|
||||||
clientID, err = url.QueryUnescape(clientID)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", errors.New("invalid basic auth header")
|
|
||||||
}
|
|
||||||
clientSecret, err = url.QueryUnescape(clientSecret)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", errors.New("invalid basic auth header")
|
|
||||||
}
|
|
||||||
if err := introspector.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
return req.Token, clientID, nil
|
|
||||||
}
|
|
||||||
return "", "", errors.New("invalid authorization")
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue