Merge branch 'main' into end-session-parameters
This commit is contained in:
commit
710b14ab9a
8 changed files with 505 additions and 42 deletions
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
|
@ -27,7 +27,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
go-version: ${{ matrix.go }}
|
go-version: ${{ matrix.go }}
|
||||||
- run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/...
|
- run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/...
|
||||||
- uses: codecov/codecov-action@v5.4.2
|
- uses: codecov/codecov-action@v5.4.3
|
||||||
with:
|
with:
|
||||||
file: ./profile.cov
|
file: ./profile.cov
|
||||||
name: codecov-go
|
name: codecov-go
|
||||||
|
|
4
go.mod
4
go.mod
|
@ -21,8 +21,8 @@ require (
|
||||||
github.com/zitadel/logging v0.6.2
|
github.com/zitadel/logging v0.6.2
|
||||||
github.com/zitadel/schema v1.3.1
|
github.com/zitadel/schema v1.3.1
|
||||||
go.opentelemetry.io/otel v1.29.0
|
go.opentelemetry.io/otel v1.29.0
|
||||||
golang.org/x/oauth2 v0.29.0
|
golang.org/x/oauth2 v0.30.0
|
||||||
golang.org/x/text v0.24.0
|
golang.org/x/text v0.25.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
|
8
go.sum
8
go.sum
|
@ -73,8 +73,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
|
||||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98=
|
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||||
golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
@ -88,8 +88,8 @@ golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
|
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
|
||||||
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
|
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||||
|
|
|
@ -62,6 +62,12 @@ type AuthorizeValidator interface {
|
||||||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
|
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CodeResponseType struct {
|
||||||
|
Code string `schema:"code"`
|
||||||
|
State string `schema:"state,omitempty"`
|
||||||
|
SessionState string `schema:"session_state,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
|
func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
Authorize(w, r, authorizer)
|
Authorize(w, r, authorizer)
|
||||||
|
@ -477,48 +483,70 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
|
||||||
AuthResponseToken(w, r, authReq, authorizer, client)
|
AuthResponseToken(w, r, authReq, authorizer, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthResponseCode creates the successful code authentication response
|
// AuthResponseCode handles the creation of a successful authentication response using an authorization code
|
||||||
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||||
ctx, span := tracer.Start(r.Context(), "AuthResponseCode")
|
ctx, span := tracer.Start(r.Context(), "AuthResponseCode")
|
||||||
r = r.WithContext(ctx)
|
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
|
||||||
|
err = handleFormPostResponse(w, r, authReq, authorizer)
|
||||||
|
} else {
|
||||||
|
err = handleRedirectResponse(w, r, authReq, authorizer)
|
||||||
|
}
|
||||||
|
|
||||||
code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer)
|
AuthRequestError(w, r, authReq, err, authorizer)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
var sessionState string
|
}
|
||||||
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
|
|
||||||
if ok {
|
// handleFormPostResponse processes the authentication response using form post method
|
||||||
|
func handleFormPostResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error {
|
||||||
|
codeResponse, err := BuildAuthResponseCodeResponsePayload(r.Context(), authReq, authorizer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return AuthResponseFormPost(w, authReq.GetRedirectURI(), codeResponse, authorizer.Encoder())
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRedirectResponse processes the authentication response using the redirect method
|
||||||
|
func handleRedirectResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error {
|
||||||
|
callbackURL, err := BuildAuthResponseCallbackURL(r.Context(), authReq, authorizer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, callbackURL, http.StatusFound)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildAuthResponseCodeResponsePayload generates the authorization code response payload for the authentication request
|
||||||
|
func BuildAuthResponseCodeResponsePayload(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (*CodeResponseType, error) {
|
||||||
|
code, err := CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionState := ""
|
||||||
|
if authRequestSessionState, ok := authReq.(AuthRequestSessionState); ok {
|
||||||
sessionState = authRequestSessionState.GetSessionState()
|
sessionState = authRequestSessionState.GetSessionState()
|
||||||
}
|
}
|
||||||
codeResponse := struct {
|
|
||||||
Code string `schema:"code"`
|
return &CodeResponseType{
|
||||||
State string `schema:"state,omitempty"`
|
|
||||||
SessionState string `schema:"session_state,omitempty"`
|
|
||||||
}{
|
|
||||||
Code: code,
|
Code: code,
|
||||||
State: authReq.GetState(),
|
State: authReq.GetState(),
|
||||||
SessionState: sessionState,
|
SessionState: sessionState,
|
||||||
}
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
|
// BuildAuthResponseCallbackURL generates the callback URL for a successful authorization code response
|
||||||
err := AuthResponseFormPost(w, authReq.GetRedirectURI(), &codeResponse, authorizer.Encoder())
|
func BuildAuthResponseCallbackURL(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (string, error) {
|
||||||
if err != nil {
|
codeResponse, err := BuildAuthResponseCodeResponsePayload(ctx, authReq, authorizer)
|
||||||
AuthRequestError(w, r, authReq, err, authorizer)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer)
|
return "", err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, callback, http.StatusFound)
|
|
||||||
|
return AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), codeResponse, authorizer.Encoder())
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthResponseToken creates the successful token(s) authentication response
|
// AuthResponseToken creates the successful token(s) authentication response
|
||||||
|
|
|
@ -1225,6 +1225,133 @@ func Test_parseAuthorizeCallbackRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthResponseCodeResponsePayload(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
authReq op.AuthRequest
|
||||||
|
authorizer func(*testing.T) op.Authorizer
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
wantCode string
|
||||||
|
wantState string
|
||||||
|
wantSessionState string
|
||||||
|
wantErr bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
res res
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "create code error",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{
|
||||||
|
returnErr: io.ErrClosedPipe,
|
||||||
|
})
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success with state",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
TransferState: "state1",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantCode: "id1",
|
||||||
|
wantState: "state1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success without state",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
TransferState: "",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantCode: "id1",
|
||||||
|
wantState: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success with session_state",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequestWithSessionState{
|
||||||
|
AuthRequest: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
TransferState: "state1",
|
||||||
|
},
|
||||||
|
SessionState: "session_state1",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantCode: "id1",
|
||||||
|
wantState: "state1",
|
||||||
|
wantSessionState: "session_state1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := op.BuildAuthResponseCodeResponsePayload(context.Background(), tt.args.authReq, tt.args.authorizer(t))
|
||||||
|
if tt.res.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.res.wantCode, got.Code)
|
||||||
|
assert.Equal(t, tt.res.wantState, got.State)
|
||||||
|
assert.Equal(t, tt.res.wantSessionState, got.SessionState)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateAuthReqIDTokenHint(t *testing.T) {
|
func TestValidateAuthReqIDTokenHint(t *testing.T) {
|
||||||
token, _ := tu.ValidIDToken()
|
token, _ := tu.ValidIDToken()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
@ -1255,3 +1382,231 @@ func TestValidateAuthReqIDTokenHint(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthResponseCallbackURL(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
authReq op.AuthRequest
|
||||||
|
authorizer func(*testing.T) op.Authorizer
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
wantURL string
|
||||||
|
wantErr bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
res res
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "error when generating code response",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{
|
||||||
|
returnErr: io.ErrClosedPipe,
|
||||||
|
})
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error when generating callback URL",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
CallbackURI: "://invalid-url",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success with state",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
CallbackURI: "https://example.com/callback",
|
||||||
|
TransferState: "state1",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantURL: "https://example.com/callback?code=id1&state=state1",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success without state",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
CallbackURI: "https://example.com/callback",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantURL: "https://example.com/callback?code=id1",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success with session_state",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequestWithSessionState{
|
||||||
|
AuthRequest: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
CallbackURI: "https://example.com/callback",
|
||||||
|
TransferState: "state1",
|
||||||
|
},
|
||||||
|
SessionState: "session_state1",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantURL: "https://example.com/callback?code=id1&session_state=session_state1&state=state1",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success with existing query parameters",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
CallbackURI: "https://example.com/callback?param=value",
|
||||||
|
TransferState: "state1",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantURL: "https://example.com/callback?param=value&code=id1&state=state1",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success with fragment response mode",
|
||||||
|
args: args{
|
||||||
|
authReq: &storage.AuthRequest{
|
||||||
|
ID: "id1",
|
||||||
|
CallbackURI: "https://example.com/callback",
|
||||||
|
TransferState: "state1",
|
||||||
|
ResponseMode: "fragment",
|
||||||
|
},
|
||||||
|
authorizer: func(t *testing.T) op.Authorizer {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
storage := mock.NewMockStorage(ctrl)
|
||||||
|
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
|
||||||
|
|
||||||
|
authorizer := mock.NewMockAuthorizer(ctrl)
|
||||||
|
authorizer.EXPECT().Storage().Return(storage)
|
||||||
|
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
|
||||||
|
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
|
||||||
|
return authorizer
|
||||||
|
},
|
||||||
|
},
|
||||||
|
res: res{
|
||||||
|
wantURL: "https://example.com/callback#code=id1&state=state1",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := op.BuildAuthResponseCallbackURL(context.Background(), tt.args.authReq, tt.args.authorizer(t))
|
||||||
|
if tt.res.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if tt.res.wantURL != "" {
|
||||||
|
// Parse the URLs to compare components instead of direct string comparison
|
||||||
|
expectedURL, err := url.Parse(tt.res.wantURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
actualURL, err := url.Parse(got)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Compare the base parts (scheme, host, path)
|
||||||
|
assert.Equal(t, expectedURL.Scheme, actualURL.Scheme)
|
||||||
|
assert.Equal(t, expectedURL.Host, actualURL.Host)
|
||||||
|
assert.Equal(t, expectedURL.Path, actualURL.Path)
|
||||||
|
|
||||||
|
// Compare the fragment if any
|
||||||
|
assert.Equal(t, expectedURL.Fragment, actualURL.Fragment)
|
||||||
|
|
||||||
|
// For query parameters, compare them independently of order
|
||||||
|
expectedQuery := expectedURL.Query()
|
||||||
|
actualQuery := actualURL.Query()
|
||||||
|
|
||||||
|
assert.Equal(t, len(expectedQuery), len(actualQuery), "Query parameter count does not match")
|
||||||
|
|
||||||
|
for key, expectedValues := range expectedQuery {
|
||||||
|
actualValues, exists := actualQuery[key]
|
||||||
|
assert.True(t, exists, "Expected query parameter %s not found", key)
|
||||||
|
assert.ElementsMatch(t, expectedValues, actualValues, "Values for parameter %s don't match", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -80,12 +80,9 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
|
||||||
}
|
}
|
||||||
|
|
||||||
codeChallenge := request.GetCodeChallenge()
|
codeChallenge := request.GetCodeChallenge()
|
||||||
if codeChallenge != nil {
|
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, codeChallenge)
|
||||||
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, codeChallenge)
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
|
if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
|
||||||
|
|
|
@ -132,11 +132,19 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string,
|
||||||
// AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent
|
// AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent
|
||||||
// code_challenge of the auth request (PKCE)
|
// code_challenge of the auth request (PKCE)
|
||||||
func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error {
|
func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error {
|
||||||
|
if challenge == nil {
|
||||||
|
if codeVerifier != "" {
|
||||||
|
return oidc.ErrInvalidRequest().WithDescription("code_verifier unexpectedly provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if codeVerifier == "" {
|
if codeVerifier == "" {
|
||||||
return oidc.ErrInvalidRequest().WithDescription("code_challenge required")
|
return oidc.ErrInvalidRequest().WithDescription("code_verifier required")
|
||||||
}
|
}
|
||||||
if !oidc.VerifyCodeChallenge(challenge, codeVerifier) {
|
if !oidc.VerifyCodeChallenge(challenge, codeVerifier) {
|
||||||
return oidc.ErrInvalidGrant().WithDescription("invalid code challenge")
|
return oidc.ErrInvalidGrant().WithDescription("invalid code_verifier")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
75
pkg/op/token_request_test.go
Normal file
75
pkg/op/token_request_test.go
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
package op_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthorizeCodeChallenge(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
codeVerifier string
|
||||||
|
codeChallenge *oidc.CodeChallenge
|
||||||
|
want func(t *testing.T, err error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing both code_verifier and code_challenge",
|
||||||
|
codeVerifier: "",
|
||||||
|
codeChallenge: nil,
|
||||||
|
want: func(t *testing.T, err error) {
|
||||||
|
assert.Nil(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid code_verifier",
|
||||||
|
codeVerifier: "Hello World!",
|
||||||
|
codeChallenge: &oidc.CodeChallenge{
|
||||||
|
Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk",
|
||||||
|
Method: oidc.CodeChallengeMethodS256,
|
||||||
|
},
|
||||||
|
want: func(t *testing.T, err error) {
|
||||||
|
assert.Nil(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid code_verifier",
|
||||||
|
codeVerifier: "Hi World!",
|
||||||
|
codeChallenge: &oidc.CodeChallenge{
|
||||||
|
Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk",
|
||||||
|
Method: oidc.CodeChallengeMethodS256,
|
||||||
|
},
|
||||||
|
want: func(t *testing.T, err error) {
|
||||||
|
assert.ErrorContains(t, err, "invalid code_verifier")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "code_verifier provided without code_challenge",
|
||||||
|
codeVerifier: "code_verifier",
|
||||||
|
codeChallenge: nil,
|
||||||
|
want: func(t *testing.T, err error) {
|
||||||
|
assert.ErrorContains(t, err, "code_verifier unexpectedly provided")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty code_verifier",
|
||||||
|
codeVerifier: "",
|
||||||
|
codeChallenge: &oidc.CodeChallenge{
|
||||||
|
Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk",
|
||||||
|
Method: oidc.CodeChallengeMethodS256,
|
||||||
|
},
|
||||||
|
want: func(t *testing.T, err error) {
|
||||||
|
assert.ErrorContains(t, err, "code_verifier required")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := op.AuthorizeCodeChallenge(tt.codeVerifier, tt.codeChallenge)
|
||||||
|
|
||||||
|
tt.want(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue