diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 7f9e3b6..a8ca0e0 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -469,13 +469,12 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques } if authReq.GetResponseMode() == oidc.ResponseModeFormPost { - res, err := AuthResponseFormPost(authReq.GetRedirectURI(), &codeResponse, authorizer.Encoder()) + err := AuthResponseFormPost(w, authReq.GetRedirectURI(), &codeResponse, authorizer.Encoder()) if err != nil { AuthRequestError(w, r, authReq, err, authorizer) return } - res.WriteTo(w) return } @@ -501,13 +500,12 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque } if authReq.GetResponseMode() == oidc.ResponseModeFormPost { - res, err := AuthResponseFormPost(authReq.GetRedirectURI(), resp, authorizer.Encoder()) + err := AuthResponseFormPost(w, authReq.GetRedirectURI(), resp, authorizer.Encoder()) if err != nil { AuthRequestError(w, r, authReq, err, authorizer) return } - res.WriteTo(w) return } @@ -568,11 +566,11 @@ var formPostHtmlTemplate string var formPostTmpl = template.Must(template.New("form_post").Parse(formPostHtmlTemplate)) // AuthResponseFormPost responds a html page that automatically submits the form which contains the auth response parameters -func AuthResponseFormPost(redirectURI string, response any, encoder httphelper.Encoder) (*bytes.Buffer, error) { +func AuthResponseFormPost(res http.ResponseWriter, redirectURI string, response any, encoder httphelper.Encoder) error { values := make(map[string][]string) err := encoder.Encode(response, values) if err != nil { - return nil, oidc.ErrServerError().WithParent(err) + return oidc.ErrServerError().WithParent(err) } params := &struct { @@ -586,10 +584,14 @@ func AuthResponseFormPost(redirectURI string, response any, encoder httphelper.E var buf bytes.Buffer err = formPostTmpl.Execute(&buf, params) if err != nil { - return nil, oidc.ErrServerError().WithParent(err) + return oidc.ErrServerError().WithParent(err) } - return &buf, nil + res.Header().Set("Cache-Control", "no-store") + res.WriteHeader(http.StatusOK) + buf.WriteTo(res) + + return nil } func setFragment(uri *url.URL, params url.Values) string { diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 7a363ff..76cb00d 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -1027,9 +1027,10 @@ func TestAuthResponseCode(t *testing.T) { authorizer func(*testing.T) op.Authorizer } type res struct { - wantCode int - wantLocationHeader string - wantBody string + wantCode int + wantLocationHeader string + wantCacheControlHeader string + wantBody string } tests := []struct { name string @@ -1133,9 +1134,9 @@ func TestAuthResponseCode(t *testing.T) { }, }, res: res{ - wantCode: http.StatusOK, - wantLocationHeader: "", - wantBody: "\n\n
\n\n\n\n", + wantCode: http.StatusOK, + wantCacheControlHeader: "no-store", + wantBody: "\n\n\n\n\n\n", }, }, } @@ -1148,6 +1149,7 @@ func TestAuthResponseCode(t *testing.T) { defer resp.Body.Close() assert.Equal(t, tt.res.wantCode, resp.StatusCode) assert.Equal(t, tt.res.wantLocationHeader, resp.Header.Get("Location")) + assert.Equal(t, tt.res.wantCacheControlHeader, resp.Header.Get("Cache-Control")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, tt.res.wantBody, string(body))