From 37b5de0e821cbaa29dd9cf56ef7f38fc77ae29d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 18 Aug 2023 16:03:51 +0300 Subject: [PATCH] fix(op): omit empty state from code flow redirect (#428) * chore(op): reproduce issue #415 * fix(op): omit empty state from code flow redirect Add test cases to reproduce the original bug, and it's resolution. closes #415 --- pkg/op/auth_request.go | 8 +-- pkg/op/auth_request_test.go | 131 +++++++++++++++++++++++++++++++++++- 2 files changed, 134 insertions(+), 5 deletions(-) diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index c264605..5621951 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -448,11 +448,11 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques return } codeResponse := struct { - code string - state string + Code string `schema:"code"` + State string `schema:"state,omitempty"` }{ - code: code, - state: authReq.GetState(), + Code: code, + State: authReq.GetState(), } callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) if err != nil { diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 2bba4e7..1fadffc 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -3,6 +3,7 @@ package op_test import ( "context" "errors" + "io" "net/http" "net/http/httptest" "net/url" @@ -13,7 +14,7 @@ import ( "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - + "github.com/zitadel/oidc/v2/example/server/storage" httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/op" @@ -942,3 +943,131 @@ func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error { } return nil } + +// mockCrypto implements the op.Crypto interface +// and in always equals out. (It doesn't crypt anything). +// When returnErr != nil, that error is always returned instread. +type mockCrypto struct { + returnErr error +} + +func (c *mockCrypto) Encrypt(s string) (string, error) { + if c.returnErr != nil { + return "", c.returnErr + } + return s, nil +} + +func (c *mockCrypto) Decrypt(s string) (string, error) { + if c.returnErr != nil { + return "", c.returnErr + } + return s, nil +} + +func TestAuthResponseCode(t *testing.T) { + type args struct { + authReq op.AuthRequest + authorizer func(*testing.T) op.Authorizer + } + type res struct { + wantCode int + wantLocationHeader string + wantBody string + } + tests := []struct { + name string + args args + res res + }{ + { + name: "create code error", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{ + returnErr: io.ErrClosedPipe, + }) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantCode: http.StatusBadRequest, + wantBody: "io: read/write on closed pipe\n", + }, + }, + { + name: "success with state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(context.Background(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantCode: http.StatusFound, + wantLocationHeader: "/auth/callback/?code=id1&state=state1", + wantBody: "", + }, + }, + { + name: "success without state", // reproduce issue #415 + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(context.Background(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantCode: http.StatusFound, + wantLocationHeader: "/auth/callback/?code=id1", + wantBody: "", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/auth/callback/", nil) + w := httptest.NewRecorder() + op.AuthResponseCode(w, r, tt.args.authReq, tt.args.authorizer(t)) + resp := w.Result() + defer resp.Body.Close() + assert.Equal(t, tt.res.wantCode, resp.StatusCode) + assert.Equal(t, tt.res.wantLocationHeader, resp.Header.Get("Location")) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, tt.res.wantBody, string(body)) + }) + } +}