From f26e155208064740e814fa119945e20832f040bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Sat, 25 Feb 2023 00:31:22 +0100 Subject: [PATCH] extract client authentication from introspection reuse the client authentication code for device authorization and introspection. --- pkg/op/client.go | 122 +++++++++++ pkg/op/client_test.go | 392 +++++++++++++++++++++++++++++++++++ pkg/op/device.go | 77 ++++--- pkg/op/token_intospection.go | 38 +--- 4 files changed, 571 insertions(+), 58 deletions(-) create mode 100644 pkg/op/client_test.go diff --git a/pkg/op/client.go b/pkg/op/client.go index e8a3347..48c6241 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -1,8 +1,14 @@ package op import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" "time" + httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" ) @@ -57,3 +63,119 @@ func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseT func IsConfidentialType(c Client) bool { return c.ApplicationType() == ApplicationTypeWeb } + +var ( + ErrInvalidAuthHeader = errors.New("invalid basic auth header") + ErrNoClientCredentials = errors.New("no client credentials provided") + ErrMissingClientID = errors.New("client_id missing from request") +) + +type ClientJWTProfile interface { + JWTProfileVerifier(context.Context) JWTProfileVerifier +} + +func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { + if ca.ClientAssertion == "" { + return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) + } + + profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx)) + if err != nil { + return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed") + } + return profile.Issuer, nil +} + +func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) { + clientID, clientSecret, ok := r.BasicAuth() + if !ok { + return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) + } + clientID, err = url.QueryUnescape(clientID) + if err != nil { + return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader) + } + clientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader) + } + if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil { + return "", oidc.ErrUnauthorizedClient().WithParent(err) + } + return clientID, nil +} + +type ClientProvider interface { + Decoder() httphelper.Decoder + Storage() Storage +} + +type clientData struct { + ClientID string `schema:"client_id"` + oidc.ClientAssertionParams +} + +// ClientIDFromRequest parses the request form and tries to obtain the client ID +// and reports if it is authenticated, using a JWT or static client secrets over +// http basic auth. +// +// If the Provider implements IntrospectorJWTProfile and "client_assertion" is +// present in the form data, JWT assertion will be verified and the +// client ID is taken from there. +// If any of them is absent, basic auth is attempted. +// In absence of basic auth data, the unauthenticated client id from the form +// data is returned. +// +// If no client id can be obtained by any method, oidc.ErrInvalidClient +// is returned with ErrMissingClientID wrapped in it. +func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) { + err = r.ParseForm() + if err != nil { + return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) + } + + data := new(clientData) + if err = p.Decoder().Decode(data, r.PostForm); err != nil { + return "", false, err + } + + JWTProfile, ok := p.(ClientJWTProfile) + if ok { + clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile) + } + if !ok || errors.Is(err, ErrNoClientCredentials) { + clientID, err = ClientBasicAuth(r, p.Storage()) + } + if err == nil { + return clientID, true, nil + } + + if data.ClientID == "" { + return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID) + } + return data.ClientID, false, nil +} + +// ClientFromRequest wraps ClientIDFromRequest and obtains the Client from storage. +// If the client id was not authenticated, the client from storage does not have +// oidc.AuthMethodNone set, an error is returned. +func ClientFromRequest(r *http.Request, p ClientProvider) (Client, error) { + clientID, authenticated, err := ClientIDFromRequest(r, p) + if err != nil { + return nil, err + } + + client, err := p.Storage().GetClientByClientID(r.Context(), clientID) + if err != nil { + return nil, err + } + + if !authenticated { + if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client? + return nil, oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials). + WithDescription(fmt.Sprintf("required client auth method: %s", m)) + } + } + + return client, err +} diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go new file mode 100644 index 0000000..f42c647 --- /dev/null +++ b/pkg/op/client_test.go @@ -0,0 +1,392 @@ +package op_test + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/golang/mock/gomock" + "github.com/gorilla/schema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + httphelper "github.com/zitadel/oidc/v2/pkg/http" + "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/oidc/v2/pkg/op/mock" +) + +type testClientJWTProfile struct{} + +func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil } + +func TestClientJWTAuth(t *testing.T) { + type args struct { + ctx context.Context + ca oidc.ClientAssertionParams + verifier op.ClientJWTProfile + } + tests := []struct { + name string + args args + wantClientID string + wantErr error + }{ + { + name: "empty assertion", + args: args{ + context.Background(), + oidc.ClientAssertionParams{}, + testClientJWTProfile{}, + }, + wantErr: op.ErrNoClientCredentials, + }, + { + name: "verification error", + args: args{ + context.Background(), + oidc.ClientAssertionParams{ + ClientAssertion: "foo", + }, + testClientJWTProfile{}, + }, + wantErr: oidc.ErrParse, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantClientID, gotClientID) + }) + } +} + +func TestClientBasicAuth(t *testing.T) { + errWrong := errors.New("wrong secret") + + type args struct { + username string + password string + } + tests := []struct { + name string + args *args + storage op.Storage + wantClientID string + wantErr error + }{ + { + name: "no args", + wantErr: op.ErrNoClientCredentials, + }, + { + name: "username unescape err", + args: &args{ + username: "%", + password: "bar", + }, + wantErr: op.ErrInvalidAuthHeader, + }, + { + name: "password unescape err", + args: &args{ + username: "foo", + password: "%", + }, + wantErr: op.ErrInvalidAuthHeader, + }, + { + name: "auth error", + args: &args{ + username: "foo", + password: "wrong", + }, + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong) + return s + }(), + wantErr: errWrong, + }, + { + name: "auth error", + args: &args{ + username: "foo", + password: "bar", + }, + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + return s + }(), + wantClientID: "foo", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/foo", nil) + if tt.args != nil { + r.SetBasicAuth(tt.args.username, tt.args.password) + } + + gotClientID, err := op.ClientBasicAuth(r, tt.storage) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantClientID, gotClientID) + }) + } +} + +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { + return 0, io.ErrNoProgress +} + +type testClientProvider struct { + storage op.Storage +} + +func (testClientProvider) Decoder() httphelper.Decoder { + return schema.NewDecoder() +} + +func (p testClientProvider) Storage() op.Storage { + return p.storage +} + +func TestClientIDFromRequest(t *testing.T) { + type args struct { + body io.Reader + p op.ClientProvider + } + type basicAuth struct { + username string + password string + } + tests := []struct { + name string + args args + basicAuth *basicAuth + wantClientID string + wantAuthenticated bool + wantErr bool + }{ + { + name: "parse error", + args: args{ + body: errReader{}, + }, + wantErr: true, + }, + { + name: "unauthenticated", + args: args{ + body: strings.NewReader( + url.Values{ + "client_id": []string{"foo"}, + }.Encode(), + ), + p: testClientProvider{ + storage: mock.NewStorage(t), + }, + }, + wantClientID: "foo", + wantAuthenticated: false, + }, + { + name: "unauthenticated", + args: args{ + body: strings.NewReader( + url.Values{}.Encode(), + ), + p: testClientProvider{ + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + return s + }(), + }, + }, + basicAuth: &basicAuth{ + username: "foo", + password: "bar", + }, + wantClientID: "foo", + wantAuthenticated: true, + }, + { + name: "missing client id", + args: args{ + body: strings.NewReader( + url.Values{}.Encode(), + ), + p: testClientProvider{ + storage: mock.NewStorage(t), + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if tt.basicAuth != nil { + r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password) + } + + gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantClientID, gotClientID) + assert.Equal(t, tt.wantAuthenticated, gotAuthenticated) + }) + } +} + +func TestClientFromRequest(t *testing.T) { + publicClient := func() op.Client { + c := mock.NewMockClient(gomock.NewController(t)) + c.EXPECT().AuthMethod().Return(oidc.AuthMethodNone) + return c + } + privateClient := func() op.Client { + c := mock.NewMockClient(gomock.NewController(t)) + c.EXPECT().AuthMethod().Return(oidc.AuthMethodPrivateKeyJWT) + return c + } + + type args struct { + body io.Reader + p op.ClientProvider + } + type basicAuth struct { + username string + password string + } + tests := []struct { + name string + args args + basicAuth *basicAuth + wantClient bool + wantErr bool + }{ + { + name: "missing client id", + args: args{ + body: strings.NewReader( + url.Values{}.Encode(), + ), + p: testClientProvider{ + storage: mock.NewStorage(t), + }, + }, + wantErr: true, + }, + { + name: "get client error", + args: args{ + body: strings.NewReader( + url.Values{}.Encode(), + ), + p: testClientProvider{ + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(nil, errors.New("something")) + return s + }(), + }, + }, + basicAuth: &basicAuth{ + username: "foo", + password: "bar", + }, + wantErr: true, + }, + { + name: "authenticated", + args: args{ + body: strings.NewReader( + url.Values{}.Encode(), + ), + p: testClientProvider{ + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(mock.NewClient(t), nil) + return s + }(), + }, + }, + basicAuth: &basicAuth{ + username: "foo", + password: "bar", + }, + wantClient: true, + }, + { + name: "public", + args: args{ + body: strings.NewReader( + url.Values{ + "client_id": []string{"foo"}, + }.Encode(), + ), + p: testClientProvider{ + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(publicClient(), nil) + return s + }(), + }, + }, + wantClient: true, + }, + { + name: "false public", + args: args{ + body: strings.NewReader( + url.Values{ + "client_id": []string{"foo"}, + }.Encode(), + ), + p: testClientProvider{ + storage: func() op.Storage { + s := mock.NewMockStorage(gomock.NewController(t)) + s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(privateClient(), nil) + return s + }(), + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if tt.basicAuth != nil { + r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password) + } + + got, err := op.ClientFromRequest(r, tt.args.p) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + if tt.wantClient { + assert.NotNil(t, got) + } + }) + } +} diff --git a/pkg/op/device.go b/pkg/op/device.go index 48108a0..bac80a4 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -48,41 +48,38 @@ var ( func deviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - DeviceAuthorization(w, r, o) + if err := DeviceAuthorization(w, r, o); err != nil { + RequestError(w, r, err) + } } } -func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) { +func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error { storage, err := assertDeviceStorage(o.Storage()) if err != nil { - RequestError(w, r, err) - return + return err } - req, err := ParseDeviceCodeRequest(r, o.Decoder()) + req, err := ParseDeviceCodeRequest(r, o) if err != nil { - RequestError(w, r, err) - return + return err } config := o.DeviceAuthorization() deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) if err != nil { - RequestError(w, r, err) - return + return err } userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount) if err != nil { - RequestError(w, r, err) - return + return err } expires := time.Now().Add(time.Duration(config.Lifetime) * time.Second) err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes) if err != nil { - RequestError(w, r, err) - return + return err } response := &oidc.DeviceAuthorizationResponse{ @@ -95,19 +92,22 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", endpoint, userCode) httphelper.MarshalJSON(w, response) + return nil } -func ParseDeviceCodeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.DeviceAuthorizationRequest, error) { - if err := r.ParseForm(); err != nil { - return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) +func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) { + clientID, _, err := ClientIDFromRequest(r, o) + if err != nil { + return nil, err } - devReq := new(oidc.DeviceAuthorizationRequest) - if err := decoder.Decode(devReq, r.Form); err != nil { + req := new(oidc.DeviceAuthorizationRequest) + if err := o.Decoder().Decode(req, r.Form); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err) } + req.ClientID = clientID - return devReq, nil + return req, nil } // 16 bytes gives 128 bit of entropy. @@ -167,35 +167,54 @@ func (r *deviceAccessTokenRequest) GetScopes() []string { } func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { - req := new(oidc.DeviceAccessTokenRequest) - if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil { + if err := deviceAccessToken(w, r, exchanger); err != nil { RequestError(w, r, err) - return } +} +func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error { // use a limited context timeout shorter as the default // poll interval of 5 seconds. ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second) defer cancel() + r = r.WithContext(ctx) + client, err := ClientFromRequest(r, exchanger) + if err != nil { + return err + } + req, err := ParseDeviceAccessTokenRequest(r, exchanger) + if err != nil { + return err + } state, err := CheckDeviceAuthorizationState(ctx, req, exchanger) if err != nil { - RequestError(w, r, err) - return + return err } - tokenRequest := &deviceAccessTokenRequest{ subject: state.Subject, audience: []string{req.ClientID}, scopes: state.Scopes, } - - resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, &jwtProfileClient{id: req.ClientID}) + resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client) if err != nil { - RequestError(w, r, err) - return + return err } + httphelper.MarshalJSON(w, resp) + return nil +} + +func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) { + req := new(struct { + oidc.DeviceAccessTokenRequest + }) + err := exchanger.Decoder().Decode(req, r.PostForm) + if err != nil { + return nil, err + } + + return &req.DeviceAccessTokenRequest, err } func CheckDeviceAuthorizationState(ctx context.Context, req *oidc.DeviceAccessTokenRequest, exchanger Exchanger) (*DeviceAuthorizationState, error) { diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index dfc8954..e7ca7c4 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net/http" - "net/url" httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" @@ -50,38 +49,19 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto } func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) { - err = r.ParseForm() + clientID, authenticated, err := ClientIDFromRequest(r, introspector) if err != nil { - return "", "", errors.New("unable to parse request") + return "", "", err } - req := new(struct { - oidc.IntrospectionRequest - oidc.ClientAssertionParams - }) + if !authenticated { + return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) + } + + req := new(oidc.IntrospectionRequest) err = introspector.Decoder().Decode(req, r.Form) if err != nil { return "", "", errors.New("unable to parse request") } - if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" { - profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier(r.Context())) - if err == nil { - return req.Token, profile.Issuer, nil - } - } - clientID, clientSecret, ok := r.BasicAuth() - if ok { - clientID, err = url.QueryUnescape(clientID) - if err != nil { - return "", "", errors.New("invalid basic auth header") - } - clientSecret, err = url.QueryUnescape(clientSecret) - if err != nil { - return "", "", errors.New("invalid basic auth header") - } - if err := introspector.Storage().AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil { - return "", "", err - } - return req.Token, clientID, nil - } - return "", "", errors.New("invalid authorization") + + return req.Token, clientID, nil }