diff --git a/pkg/http/http.go b/pkg/http/http.go index 2512707..b0d1ef7 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -77,14 +77,13 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e return nil } -func URLEncodeResponse(resp interface{}, encoder Encoder) (string, error) { +func URLEncodeParams(resp interface{}, encoder Encoder) (url.Values, error) { values := make(map[string][]string) err := encoder.Encode(resp, values) if err != nil { - return "", err + return nil, err } - v := url.Values(values) - return v.Encode(), nil + return values, nil } func StartServer(ctx context.Context, port string) { diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 7b5b812..9c8c48b 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -396,6 +396,7 @@ type AccessTokenResponse struct { RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` + State string `json:"state,omitempty" schema:"state,omitempty"` } type JWTProfileAssertionClaims interface { diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 8a2b8b9..2ebedb5 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -465,18 +465,41 @@ func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) { //AuthResponseURL encodes the authorization response (successful and error) and sets it as query or fragment values //depending on the response_mode and response_type func AuthResponseURL(redirectURI string, responseType oidc.ResponseType, responseMode oidc.ResponseMode, response interface{}, encoder httphelper.Encoder) (string, error) { - params, err := httphelper.URLEncodeResponse(response, encoder) + uri, err := url.Parse(redirectURI) if err != nil { return "", oidc.ErrServerError().WithParent(err) } + params, err := httphelper.URLEncodeParams(response, encoder) + if err != nil { + return "", oidc.ErrServerError().WithParent(err) + } + //return explicitly requested mode if responseMode == oidc.ResponseModeQuery { - return redirectURI + "?" + params, nil + return mergeQueryParams(uri, params), nil } if responseMode == oidc.ResponseModeFragment { - return redirectURI + "#" + params, nil + return setFragment(uri, params), nil } - if responseType == "" || responseType == oidc.ResponseTypeCode { - return redirectURI + "?" + params, nil + //implicit must use fragment mode is not specified by client + if responseType == oidc.ResponseTypeIDToken || responseType == oidc.ResponseTypeIDTokenOnly { + return setFragment(uri, params), nil } - return redirectURI + "#" + params, nil + //if we get here it's code flow: defaults to query + return mergeQueryParams(uri, params), nil +} + +func setFragment(uri *url.URL, params url.Values) string { + uri.Fragment = params.Encode() + return uri.String() +} + +func mergeQueryParams(uri *url.URL, params url.Values) string { + queries := uri.Query() + for param, values := range params { + for _, value := range values { + queries.Add(param, value) + } + } + uri.RawQuery = queries.Encode() + return uri.String() } diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index f9ba4de..9023011 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -793,6 +793,90 @@ func TestAuthResponseURL(t *testing.T) { nil, }, }, + { + "with query", + args{ + "uri?param=value", + oidc.ResponseTypeCode, + "", + map[string][]string{"test": {"test"}}, + &mockEncoder{}, + }, + res{ + "uri?param=value&test=test", + nil, + }, + }, + { + "with query response type id token", + args{ + "uri?param=value", + oidc.ResponseTypeIDToken, + "", + map[string][]string{"test": {"test"}}, + &mockEncoder{}, + }, + res{ + "uri?param=value#test=test", + nil, + }, + }, + { + "with existing query", + args{ + "uri?test=value", + oidc.ResponseTypeCode, + "", + map[string][]string{"test": {"test"}}, + &mockEncoder{}, + }, + res{ + "uri?test=value&test=test", + nil, + }, + }, + { + "with existing query response type id token", + args{ + "uri?test=value", + oidc.ResponseTypeIDToken, + "", + map[string][]string{"test": {"test"}}, + &mockEncoder{}, + }, + res{ + "uri?test=value#test=test", + nil, + }, + }, + { + "with existing query and multiple values", + args{ + "uri?test=value", + oidc.ResponseTypeCode, + "", + map[string][]string{"test": {"test", "test2"}}, + &mockEncoder{}, + }, + res{ + "uri?test=value&test=test&test=test2", + nil, + }, + }, + { + "with existing query and multiple values response type id token", + args{ + "uri?test=value", + oidc.ResponseTypeIDToken, + "", + map[string][]string{"test": {"test", "test2"}}, + &mockEncoder{}, + }, + res{ + "uri?test=value#test=test&test=test2", + nil, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/op/token.go b/pkg/op/token.go index 7f6f599..3a72261 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -37,11 +37,13 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli return nil, err } + var state string if authRequest, ok := request.(AuthRequest); ok { err = creator.Storage().DeleteAuthRequest(ctx, authRequest.GetID()) if err != nil { return nil, err } + state = authRequest.GetState() } exp := uint64(validity.Seconds()) @@ -51,6 +53,7 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli RefreshToken: newRefreshToken, TokenType: oidc.BearerToken, ExpiresIn: exp, + State: state, }, nil }