extract client authentication from introspection

reuse the client authentication code for device authorization
and introspection.
This commit is contained in:
Tim Möhlmann 2023-02-25 00:31:22 +01:00
parent 0f9ec46aaa
commit f26e155208
4 changed files with 571 additions and 58 deletions

View file

@ -1,8 +1,14 @@
package op package op
import ( import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"time" "time"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/oidc"
) )
@ -57,3 +63,119 @@ func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseT
func IsConfidentialType(c Client) bool { func IsConfidentialType(c Client) bool {
return c.ApplicationType() == ApplicationTypeWeb return c.ApplicationType() == ApplicationTypeWeb
} }
var (
ErrInvalidAuthHeader = errors.New("invalid basic auth header")
ErrNoClientCredentials = errors.New("no client credentials provided")
ErrMissingClientID = errors.New("client_id missing from request")
)
type ClientJWTProfile interface {
JWTProfileVerifier(context.Context) JWTProfileVerifier
}
func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
if ca.ClientAssertion == "" {
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
}
profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx))
if err != nil {
return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed")
}
return profile.Issuer, nil
}
func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) {
clientID, clientSecret, ok := r.BasicAuth()
if !ok {
return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
}
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
}
if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
return "", oidc.ErrUnauthorizedClient().WithParent(err)
}
return clientID, nil
}
type ClientProvider interface {
Decoder() httphelper.Decoder
Storage() Storage
}
type clientData struct {
ClientID string `schema:"client_id"`
oidc.ClientAssertionParams
}
// ClientIDFromRequest parses the request form and tries to obtain the client ID
// and reports if it is authenticated, using a JWT or static client secrets over
// http basic auth.
//
// If the Provider implements IntrospectorJWTProfile and "client_assertion" is
// present in the form data, JWT assertion will be verified and the
// client ID is taken from there.
// If any of them is absent, basic auth is attempted.
// In absence of basic auth data, the unauthenticated client id from the form
// data is returned.
//
// If no client id can be obtained by any method, oidc.ErrInvalidClient
// is returned with ErrMissingClientID wrapped in it.
func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) {
err = r.ParseForm()
if err != nil {
return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
}
data := new(clientData)
if err = p.Decoder().Decode(data, r.PostForm); err != nil {
return "", false, err
}
JWTProfile, ok := p.(ClientJWTProfile)
if ok {
clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
}
if !ok || errors.Is(err, ErrNoClientCredentials) {
clientID, err = ClientBasicAuth(r, p.Storage())
}
if err == nil {
return clientID, true, nil
}
if data.ClientID == "" {
return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
}
return data.ClientID, false, nil
}
// ClientFromRequest wraps ClientIDFromRequest and obtains the Client from storage.
// If the client id was not authenticated, the client from storage does not have
// oidc.AuthMethodNone set, an error is returned.
func ClientFromRequest(r *http.Request, p ClientProvider) (Client, error) {
clientID, authenticated, err := ClientIDFromRequest(r, p)
if err != nil {
return nil, err
}
client, err := p.Storage().GetClientByClientID(r.Context(), clientID)
if err != nil {
return nil, err
}
if !authenticated {
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
return nil, oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
WithDescription(fmt.Sprintf("required client auth method: %s", m))
}
}
return client, err
}

392
pkg/op/client_test.go Normal file
View file

@ -0,0 +1,392 @@
package op_test
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/oidc/v2/pkg/op/mock"
)
type testClientJWTProfile struct{}
func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil }
func TestClientJWTAuth(t *testing.T) {
type args struct {
ctx context.Context
ca oidc.ClientAssertionParams
verifier op.ClientJWTProfile
}
tests := []struct {
name string
args args
wantClientID string
wantErr error
}{
{
name: "empty assertion",
args: args{
context.Background(),
oidc.ClientAssertionParams{},
testClientJWTProfile{},
},
wantErr: op.ErrNoClientCredentials,
},
{
name: "verification error",
args: args{
context.Background(),
oidc.ClientAssertionParams{
ClientAssertion: "foo",
},
testClientJWTProfile{},
},
wantErr: oidc.ErrParse,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantClientID, gotClientID)
})
}
}
func TestClientBasicAuth(t *testing.T) {
errWrong := errors.New("wrong secret")
type args struct {
username string
password string
}
tests := []struct {
name string
args *args
storage op.Storage
wantClientID string
wantErr error
}{
{
name: "no args",
wantErr: op.ErrNoClientCredentials,
},
{
name: "username unescape err",
args: &args{
username: "%",
password: "bar",
},
wantErr: op.ErrInvalidAuthHeader,
},
{
name: "password unescape err",
args: &args{
username: "foo",
password: "%",
},
wantErr: op.ErrInvalidAuthHeader,
},
{
name: "auth error",
args: &args{
username: "foo",
password: "wrong",
},
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong)
return s
}(),
wantErr: errWrong,
},
{
name: "auth error",
args: &args{
username: "foo",
password: "bar",
},
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
return s
}(),
wantClientID: "foo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/foo", nil)
if tt.args != nil {
r.SetBasicAuth(tt.args.username, tt.args.password)
}
gotClientID, err := op.ClientBasicAuth(r, tt.storage)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantClientID, gotClientID)
})
}
}
type errReader struct{}
func (errReader) Read([]byte) (int, error) {
return 0, io.ErrNoProgress
}
type testClientProvider struct {
storage op.Storage
}
func (testClientProvider) Decoder() httphelper.Decoder {
return schema.NewDecoder()
}
func (p testClientProvider) Storage() op.Storage {
return p.storage
}
func TestClientIDFromRequest(t *testing.T) {
type args struct {
body io.Reader
p op.ClientProvider
}
type basicAuth struct {
username string
password string
}
tests := []struct {
name string
args args
basicAuth *basicAuth
wantClientID string
wantAuthenticated bool
wantErr bool
}{
{
name: "parse error",
args: args{
body: errReader{},
},
wantErr: true,
},
{
name: "unauthenticated",
args: args{
body: strings.NewReader(
url.Values{
"client_id": []string{"foo"},
}.Encode(),
),
p: testClientProvider{
storage: mock.NewStorage(t),
},
},
wantClientID: "foo",
wantAuthenticated: false,
},
{
name: "unauthenticated",
args: args{
body: strings.NewReader(
url.Values{}.Encode(),
),
p: testClientProvider{
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
return s
}(),
},
},
basicAuth: &basicAuth{
username: "foo",
password: "bar",
},
wantClientID: "foo",
wantAuthenticated: true,
},
{
name: "missing client id",
args: args{
body: strings.NewReader(
url.Values{}.Encode(),
),
p: testClientProvider{
storage: mock.NewStorage(t),
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if tt.basicAuth != nil {
r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
}
gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.wantClientID, gotClientID)
assert.Equal(t, tt.wantAuthenticated, gotAuthenticated)
})
}
}
func TestClientFromRequest(t *testing.T) {
publicClient := func() op.Client {
c := mock.NewMockClient(gomock.NewController(t))
c.EXPECT().AuthMethod().Return(oidc.AuthMethodNone)
return c
}
privateClient := func() op.Client {
c := mock.NewMockClient(gomock.NewController(t))
c.EXPECT().AuthMethod().Return(oidc.AuthMethodPrivateKeyJWT)
return c
}
type args struct {
body io.Reader
p op.ClientProvider
}
type basicAuth struct {
username string
password string
}
tests := []struct {
name string
args args
basicAuth *basicAuth
wantClient bool
wantErr bool
}{
{
name: "missing client id",
args: args{
body: strings.NewReader(
url.Values{}.Encode(),
),
p: testClientProvider{
storage: mock.NewStorage(t),
},
},
wantErr: true,
},
{
name: "get client error",
args: args{
body: strings.NewReader(
url.Values{}.Encode(),
),
p: testClientProvider{
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(nil, errors.New("something"))
return s
}(),
},
},
basicAuth: &basicAuth{
username: "foo",
password: "bar",
},
wantErr: true,
},
{
name: "authenticated",
args: args{
body: strings.NewReader(
url.Values{}.Encode(),
),
p: testClientProvider{
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(mock.NewClient(t), nil)
return s
}(),
},
},
basicAuth: &basicAuth{
username: "foo",
password: "bar",
},
wantClient: true,
},
{
name: "public",
args: args{
body: strings.NewReader(
url.Values{
"client_id": []string{"foo"},
}.Encode(),
),
p: testClientProvider{
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(publicClient(), nil)
return s
}(),
},
},
wantClient: true,
},
{
name: "false public",
args: args{
body: strings.NewReader(
url.Values{
"client_id": []string{"foo"},
}.Encode(),
),
p: testClientProvider{
storage: func() op.Storage {
s := mock.NewMockStorage(gomock.NewController(t))
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(privateClient(), nil)
return s
}(),
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if tt.basicAuth != nil {
r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
}
got, err := op.ClientFromRequest(r, tt.args.p)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
if tt.wantClient {
assert.NotNil(t, got)
}
})
}
}

View file

@ -48,41 +48,38 @@ var (
func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
DeviceAuthorization(w, r, o) if err := DeviceAuthorization(w, r, o); err != nil {
RequestError(w, r, err)
}
} }
} }
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
storage, err := assertDeviceStorage(o.Storage()) storage, err := assertDeviceStorage(o.Storage())
if err != nil { if err != nil {
RequestError(w, r, err) return err
return
} }
req, err := ParseDeviceCodeRequest(r, o.Decoder()) req, err := ParseDeviceCodeRequest(r, o)
if err != nil { if err != nil {
RequestError(w, r, err) return err
return
} }
config := o.DeviceAuthorization() config := o.DeviceAuthorization()
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
if err != nil { if err != nil {
RequestError(w, r, err) return err
return
} }
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount) userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount)
if err != nil { if err != nil {
RequestError(w, r, err) return err
return
} }
expires := time.Now().Add(time.Duration(config.Lifetime) * time.Second) expires := time.Now().Add(time.Duration(config.Lifetime) * time.Second)
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes) err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
if err != nil { if err != nil {
RequestError(w, r, err) return err
return
} }
response := &oidc.DeviceAuthorizationResponse{ response := &oidc.DeviceAuthorizationResponse{
@ -95,19 +92,22 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode) response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode)
httphelper.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
return nil
} }
func ParseDeviceCodeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.DeviceAuthorizationRequest, error) { func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
if err := r.ParseForm(); err != nil { clientID, _, err := ClientIDFromRequest(r, o)
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) if err != nil {
return nil, err
} }
devReq := new(oidc.DeviceAuthorizationRequest) req := new(oidc.DeviceAuthorizationRequest)
if err := decoder.Decode(devReq, r.Form); err != nil { if err := o.Decoder().Decode(req, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
} }
req.ClientID = clientID
return devReq, nil return req, nil
} }
// 16 bytes gives 128 bit of entropy. // 16 bytes gives 128 bit of entropy.
@ -167,35 +167,54 @@ func (r *deviceAccessTokenRequest) GetScopes() []string {
} }
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
req := new(oidc.DeviceAccessTokenRequest) if err := deviceAccessToken(w, r, exchanger); err != nil {
if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return }
} }
func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
// use a limited context timeout shorter as the default // use a limited context timeout shorter as the default
// poll interval of 5 seconds. // poll interval of 5 seconds.
ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second) ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
defer cancel() defer cancel()
r = r.WithContext(ctx)
client, err := ClientFromRequest(r, exchanger)
if err != nil {
return err
}
req, err := ParseDeviceAccessTokenRequest(r, exchanger)
if err != nil {
return err
}
state, err := CheckDeviceAuthorizationState(ctx, req, exchanger) state, err := CheckDeviceAuthorizationState(ctx, req, exchanger)
if err != nil { if err != nil {
RequestError(w, r, err) return err
return
} }
tokenRequest := &deviceAccessTokenRequest{ tokenRequest := &deviceAccessTokenRequest{
subject: state.Subject, subject: state.Subject,
audience: []string{req.ClientID}, audience: []string{req.ClientID},
scopes: state.Scopes, scopes: state.Scopes,
} }
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, &jwtProfileClient{id: req.ClientID})
if err != nil { if err != nil {
RequestError(w, r, err) return err
return
} }
httphelper.MarshalJSON(w, resp) httphelper.MarshalJSON(w, resp)
return nil
}
func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
req := new(struct {
oidc.DeviceAccessTokenRequest
})
err := exchanger.Decoder().Decode(req, r.PostForm)
if err != nil {
return nil, err
}
return &req.DeviceAccessTokenRequest, err
} }
func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) { func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) {

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"net/url"
httphelper "github.com/zitadel/oidc/v2/pkg/http" httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/oidc"
@ -50,38 +49,19 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
} }
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) { func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
err = r.ParseForm() clientID, authenticated, err := ClientIDFromRequest(r, introspector)
if err != nil { if err != nil {
return "", "", errors.New("unable to parse request") return "", "", err
} }
req := new(struct { if !authenticated {
oidc.IntrospectionRequest return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
oidc.ClientAssertionParams }
})
req := new(oidc.IntrospectionRequest)
err = introspector.Decoder().Decode(req, r.Form) err = introspector.Decoder().Decode(req, r.Form)
if err != nil { if err != nil {
return "", "", errors.New("unable to parse request") return "", "", errors.New("unable to parse request")
} }
if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" {
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier(r.Context()))
if err == nil {
return req.Token, profile.Issuer, nil
}
}
clientID, clientSecret, ok := r.BasicAuth()
if ok {
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return "", "", errors.New("invalid basic auth header")
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return "", "", errors.New("invalid basic auth header")
}
if err := introspector.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
return "", "", err
}
return req.Token, clientID, nil return req.Token, clientID, nil
} }
return "", "", errors.New("invalid authorization")
}