feat: enhance authentication response handling
- Introduced CodeResponseType struct to encapsulate response data. - Added handleFormPostResponse and handleRedirectResponse functions to manage different response modes. - Created BuildAuthResponseCodeResponsePayload and BuildAuthResponseCallbackURL functions for better modularity in response generation.
This commit is contained in:
parent
aeda5d7178
commit
9a6863a511
2 changed files with 410 additions and 27 deletions
|
@ -62,6 +62,12 @@ type AuthorizeValidator interface {
|
|||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
|
||||
}
|
||||
|
||||
type CodeResponseType struct {
|
||||
Code string `schema:"code"`
|
||||
State string `schema:"state,omitempty"`
|
||||
SessionState string `schema:"session_state,omitempty"`
|
||||
}
|
||||
|
||||
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
Authorize(w, r, authorizer)
|
||||
|
@ -477,48 +483,70 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
|
|||
AuthResponseToken(w, r, authReq, authorizer, client)
|
||||
}
|
||||
|
||||
// AuthResponseCode creates the successful code authentication response
|
||||
// AuthResponseCode handles the creation of a successful authentication response using an authorization code
|
||||
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||
ctx, span := tracer.Start(r.Context(), "AuthResponseCode")
|
||||
r = r.WithContext(ctx)
|
||||
defer span.End()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
var err error
|
||||
if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
|
||||
err = handleFormPostResponse(w, r, authReq, authorizer)
|
||||
} else {
|
||||
err = handleRedirectResponse(w, r, authReq, authorizer)
|
||||
}
|
||||
|
||||
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
var sessionState string
|
||||
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
|
||||
if ok {
|
||||
}
|
||||
|
||||
// handleFormPostResponse processes the authentication response using form post method
|
||||
func handleFormPostResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error {
|
||||
codeResponse, err := BuildAuthResponseCodeResponsePayload(r.Context(), authReq, authorizer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return AuthResponseFormPost(w, authReq.GetRedirectURI(), codeResponse, authorizer.Encoder())
|
||||
}
|
||||
|
||||
// handleRedirectResponse processes the authentication response using the redirect method
|
||||
func handleRedirectResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error {
|
||||
callbackURL, err := BuildAuthResponseCallbackURL(r.Context(), authReq, authorizer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
http.Redirect(w, r, callbackURL, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildAuthResponseCodeResponsePayload generates the authorization code response payload for the authentication request
|
||||
func BuildAuthResponseCodeResponsePayload(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (*CodeResponseType, error) {
|
||||
code, err := CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionState := ""
|
||||
if authRequestSessionState, ok := authReq.(AuthRequestSessionState); ok {
|
||||
sessionState = authRequestSessionState.GetSessionState()
|
||||
}
|
||||
codeResponse := struct {
|
||||
Code string `schema:"code"`
|
||||
State string `schema:"state,omitempty"`
|
||||
SessionState string `schema:"session_state,omitempty"`
|
||||
}{
|
||||
|
||||
return &CodeResponseType{
|
||||
Code: code,
|
||||
State: authReq.GetState(),
|
||||
SessionState: sessionState,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
|
||||
err := AuthResponseFormPost(w, authReq.GetRedirectURI(), &codeResponse, authorizer.Encoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
|
||||
// BuildAuthResponseCallbackURL generates the callback URL for a successful authorization code response
|
||||
func BuildAuthResponseCallbackURL(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (string, error) {
|
||||
codeResponse, err := BuildAuthResponseCodeResponsePayload(ctx, authReq, authorizer)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
return "", err
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
|
||||
return AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), codeResponse, authorizer.Encoder())
|
||||
}
|
||||
|
||||
// AuthResponseToken creates the successful token(s) authentication response
|
||||
|
|
|
@ -1225,6 +1225,133 @@ func Test_parseAuthorizeCallbackRequest(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthResponseCodeResponsePayload(t *testing.T) {
|
||||
type args struct {
|
||||
authReq op.AuthRequest
|
||||
authorizer func(*testing.T) op.Authorizer
|
||||
}
|
||||
type res struct {
|
||||
wantCode string
|
||||
wantState string
|
||||
wantSessionState string
|
||||
wantErr bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "create code error",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{
|
||||
returnErr: io.ErrClosedPipe,
|
||||
})
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with state",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
TransferState: "state1",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantCode: "id1",
|
||||
wantState: "state1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success without state",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
TransferState: "",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantCode: "id1",
|
||||
wantState: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with session_state",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequestWithSessionState{
|
||||
AuthRequest: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
TransferState: "state1",
|
||||
},
|
||||
SessionState: "session_state1",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantCode: "id1",
|
||||
wantState: "state1",
|
||||
wantSessionState: "session_state1",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := op.BuildAuthResponseCodeResponsePayload(context.Background(), tt.args.authReq, tt.args.authorizer(t))
|
||||
if tt.res.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.res.wantCode, got.Code)
|
||||
assert.Equal(t, tt.res.wantState, got.State)
|
||||
assert.Equal(t, tt.res.wantSessionState, got.SessionState)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthReqIDTokenHint(t *testing.T) {
|
||||
token, _ := tu.ValidIDToken()
|
||||
tests := []struct {
|
||||
|
@ -1255,3 +1382,231 @@ func TestValidateAuthReqIDTokenHint(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthResponseCallbackURL(t *testing.T) {
|
||||
type args struct {
|
||||
authReq op.AuthRequest
|
||||
authorizer func(*testing.T) op.Authorizer
|
||||
}
|
||||
type res struct {
|
||||
wantURL string
|
||||
wantErr bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "error when generating code response",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{
|
||||
returnErr: io.ErrClosedPipe,
|
||||
})
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error when generating callback URL",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
CallbackURI: "://invalid-url",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with state",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
CallbackURI: "https://example.com/callback",
|
||||
TransferState: "state1",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantURL: "https://example.com/callback?code=id1&state=state1",
|
||||
wantErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success without state",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
CallbackURI: "https://example.com/callback",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantURL: "https://example.com/callback?code=id1",
|
||||
wantErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with session_state",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequestWithSessionState{
|
||||
AuthRequest: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
CallbackURI: "https://example.com/callback",
|
||||
TransferState: "state1",
|
||||
},
|
||||
SessionState: "session_state1",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantURL: "https://example.com/callback?code=id1&session_state=session_state1&state=state1",
|
||||
wantErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with existing query parameters",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
CallbackURI: "https://example.com/callback?param=value",
|
||||
TransferState: "state1",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantURL: "https://example.com/callback?param=value&code=id1&state=state1",
|
||||
wantErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with fragment response mode",
|
||||
args: args{
|
||||
authReq: &storage.AuthRequest{
|
||||
ID: "id1",
|
||||
CallbackURI: "https://example.com/callback",
|
||||
TransferState: "state1",
|
||||
ResponseMode: "fragment",
|
||||
},
|
||||
authorizer: func(t *testing.T) op.Authorizer {
|
||||
ctrl := gomock.NewController(t)
|
||||
storage := mock.NewMockStorage(ctrl)
|
||||
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||
|
||||
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||
authorizer.EXPECT().Storage().Return(storage)
|
||||
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||
return authorizer
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
wantURL: "https://example.com/callback#code=id1&state=state1",
|
||||
wantErr: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := op.BuildAuthResponseCallbackURL(context.Background(), tt.args.authReq, tt.args.authorizer(t))
|
||||
if tt.res.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.res.wantURL != "" {
|
||||
// Parse the URLs to compare components instead of direct string comparison
|
||||
expectedURL, err := url.Parse(tt.res.wantURL)
|
||||
require.NoError(t, err)
|
||||
actualURL, err := url.Parse(got)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Compare the base parts (scheme, host, path)
|
||||
assert.Equal(t, expectedURL.Scheme, actualURL.Scheme)
|
||||
assert.Equal(t, expectedURL.Host, actualURL.Host)
|
||||
assert.Equal(t, expectedURL.Path, actualURL.Path)
|
||||
|
||||
// Compare the fragment if any
|
||||
assert.Equal(t, expectedURL.Fragment, actualURL.Fragment)
|
||||
|
||||
// For query parameters, compare them independently of order
|
||||
expectedQuery := expectedURL.Query()
|
||||
actualQuery := actualURL.Query()
|
||||
|
||||
assert.Equal(t, len(expectedQuery), len(actualQuery), "Query parameter count does not match")
|
||||
|
||||
for key, expectedValues := range expectedQuery {
|
||||
actualValues, exists := actualQuery[key]
|
||||
assert.True(t, exists, "Expected query parameter %s not found", key)
|
||||
assert.ElementsMatch(t, expectedValues, actualValues, "Values for parameter %s don't match", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue