From 5913c5a07482829532d831b168a82255cba0e8cc Mon Sep 17 00:00:00 2001 From: Masahito Osako <43847020+m11o@users.noreply.github.com> Date: Tue, 29 Apr 2025 23:17:28 +0900 Subject: [PATCH] feat: enhance authentication response handling (#728) - Introduced CodeResponseType struct to encapsulate response data. - Added handleFormPostResponse and handleRedirectResponse functions to manage different response modes. - Created BuildAuthResponseCodeResponsePayload and BuildAuthResponseCallbackURL functions for better modularity in response generation. --- pkg/op/auth_request.go | 82 ++++++--- pkg/op/auth_request_test.go | 355 ++++++++++++++++++++++++++++++++++++ 2 files changed, 410 insertions(+), 27 deletions(-) diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 82f1b58..2c013aa 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -62,6 +62,12 @@ type AuthorizeValidator interface { ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error) } +type CodeResponseType struct { + Code string `schema:"code"` + State string `schema:"state,omitempty"` + SessionState string `schema:"session_state,omitempty"` +} + func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { Authorize(w, r, authorizer) @@ -477,48 +483,70 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri AuthResponseToken(w, r, authReq, authorizer, client) } -// AuthResponseCode creates the successful code authentication response +// AuthResponseCode handles the creation of a successful authentication response using an authorization code func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { ctx, span := tracer.Start(r.Context(), "AuthResponseCode") - r = r.WithContext(ctx) defer span.End() + r = r.WithContext(ctx) + + var err error + if authReq.GetResponseMode() == oidc.ResponseModeFormPost { + err = handleFormPostResponse(w, r, authReq, authorizer) + } else { + err = handleRedirectResponse(w, r, authReq, authorizer) + } - code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) if err != nil { AuthRequestError(w, r, authReq, err, authorizer) - return } - var sessionState string - authRequestSessionState, ok := authReq.(AuthRequestSessionState) - if ok { +} + +// handleFormPostResponse processes the authentication response using form post method +func handleFormPostResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error { + codeResponse, err := BuildAuthResponseCodeResponsePayload(r.Context(), authReq, authorizer) + if err != nil { + return err + } + return AuthResponseFormPost(w, authReq.GetRedirectURI(), codeResponse, authorizer.Encoder()) +} + +// handleRedirectResponse processes the authentication response using the redirect method +func handleRedirectResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error { + callbackURL, err := BuildAuthResponseCallbackURL(r.Context(), authReq, authorizer) + if err != nil { + return err + } + http.Redirect(w, r, callbackURL, http.StatusFound) + return nil +} + +// BuildAuthResponseCodeResponsePayload generates the authorization code response payload for the authentication request +func BuildAuthResponseCodeResponsePayload(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (*CodeResponseType, error) { + code, err := CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto()) + if err != nil { + return nil, err + } + + sessionState := "" + if authRequestSessionState, ok := authReq.(AuthRequestSessionState); ok { sessionState = authRequestSessionState.GetSessionState() } - codeResponse := struct { - Code string `schema:"code"` - State string `schema:"state,omitempty"` - SessionState string `schema:"session_state,omitempty"` - }{ + + return &CodeResponseType{ Code: code, State: authReq.GetState(), SessionState: sessionState, - } + }, nil +} - if authReq.GetResponseMode() == oidc.ResponseModeFormPost { - err := AuthResponseFormPost(w, authReq.GetRedirectURI(), &codeResponse, authorizer.Encoder()) - if err != nil { - AuthRequestError(w, r, authReq, err, authorizer) - return - } - - return - } - - callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) +// BuildAuthResponseCallbackURL generates the callback URL for a successful authorization code response +func BuildAuthResponseCallbackURL(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (string, error) { + codeResponse, err := BuildAuthResponseCodeResponsePayload(ctx, authReq, authorizer) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer) - return + return "", err } - http.Redirect(w, r, callback, http.StatusFound) + + return AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), codeResponse, authorizer.Encoder()) } // AuthResponseToken creates the successful token(s) authentication response diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 4878f5e..f0c4ef1 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -1225,6 +1225,133 @@ func Test_parseAuthorizeCallbackRequest(t *testing.T) { } } +func TestBuildAuthResponseCodeResponsePayload(t *testing.T) { + type args struct { + authReq op.AuthRequest + authorizer func(*testing.T) op.Authorizer + } + type res struct { + wantCode string + wantState string + wantSessionState string + wantErr bool + } + tests := []struct { + name string + args args + res res + }{ + { + name: "create code error", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{ + returnErr: io.ErrClosedPipe, + }) + return authorizer + }, + }, + res: res{ + wantErr: true, + }, + }, + { + name: "success with state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + return authorizer + }, + }, + res: res{ + wantCode: "id1", + wantState: "state1", + }, + }, + { + name: "success without state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + return authorizer + }, + }, + res: res{ + wantCode: "id1", + wantState: "", + }, + }, + { + name: "success with session_state", + args: args{ + authReq: &storage.AuthRequestWithSessionState{ + AuthRequest: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + SessionState: "session_state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + return authorizer + }, + }, + res: res{ + wantCode: "id1", + wantState: "state1", + wantSessionState: "session_state1", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.BuildAuthResponseCodeResponsePayload(context.Background(), tt.args.authReq, tt.args.authorizer(t)) + if tt.res.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.res.wantCode, got.Code) + assert.Equal(t, tt.res.wantState, got.State) + assert.Equal(t, tt.res.wantSessionState, got.SessionState) + }) + } +} + func TestValidateAuthReqIDTokenHint(t *testing.T) { token, _ := tu.ValidIDToken() tests := []struct { @@ -1255,3 +1382,231 @@ func TestValidateAuthReqIDTokenHint(t *testing.T) { }) } } + +func TestBuildAuthResponseCallbackURL(t *testing.T) { + type args struct { + authReq op.AuthRequest + authorizer func(*testing.T) op.Authorizer + } + type res struct { + wantURL string + wantErr bool + } + tests := []struct { + name string + args args + res res + }{ + { + name: "error when generating code response", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{ + returnErr: io.ErrClosedPipe, + }) + return authorizer + }, + }, + res: res{ + wantErr: true, + }, + }, + { + name: "error when generating callback URL", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "://invalid-url", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantErr: true, + }, + }, + { + name: "success with state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback?code=id1&state=state1", + wantErr: false, + }, + }, + { + name: "success without state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback?code=id1", + wantErr: false, + }, + }, + { + name: "success with session_state", + args: args{ + authReq: &storage.AuthRequestWithSessionState{ + AuthRequest: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback", + TransferState: "state1", + }, + SessionState: "session_state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback?code=id1&session_state=session_state1&state=state1", + wantErr: false, + }, + }, + { + name: "success with existing query parameters", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback?param=value", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback?param=value&code=id1&state=state1", + wantErr: false, + }, + }, + { + name: "success with fragment response mode", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback", + TransferState: "state1", + ResponseMode: "fragment", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback#code=id1&state=state1", + wantErr: false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.BuildAuthResponseCallbackURL(context.Background(), tt.args.authReq, tt.args.authorizer(t)) + if tt.res.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + + if tt.res.wantURL != "" { + // Parse the URLs to compare components instead of direct string comparison + expectedURL, err := url.Parse(tt.res.wantURL) + require.NoError(t, err) + actualURL, err := url.Parse(got) + require.NoError(t, err) + + // Compare the base parts (scheme, host, path) + assert.Equal(t, expectedURL.Scheme, actualURL.Scheme) + assert.Equal(t, expectedURL.Host, actualURL.Host) + assert.Equal(t, expectedURL.Path, actualURL.Path) + + // Compare the fragment if any + assert.Equal(t, expectedURL.Fragment, actualURL.Fragment) + + // For query parameters, compare them independently of order + expectedQuery := expectedURL.Query() + actualQuery := actualURL.Query() + + assert.Equal(t, len(expectedQuery), len(actualQuery), "Query parameter count does not match") + + for key, expectedValues := range expectedQuery { + actualValues, exists := actualQuery[key] + assert.True(t, exists, "Expected query parameter %s not found", key) + assert.ElementsMatch(t, expectedValues, actualValues, "Values for parameter %s don't match", key) + } + } + }) + } +}