chore: replace gorilla/mux with go-chi/chi (#332)

BREAKING CHANGE:
The returned router from `op.CreateRouter()` is now a `chi.Router`

Closes #301
This commit is contained in:
Tim Möhlmann 2023-03-17 17:36:02 +02:00 committed by GitHub
parent 62caf5dafe
commit 57fb9f77aa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 98 additions and 60 deletions

View file

@ -2,6 +2,7 @@ package op
import (
"context"
"errors"
"fmt"
"net"
"net/http"
@ -10,8 +11,6 @@ import (
"strings"
"time"
"github.com/gorilla/mux"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
str "github.com/zitadel/oidc/v2/pkg/strings"
@ -405,13 +404,11 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *
// AuthorizeCallback handles the callback after authentication in the Login UI
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
params := mux.Vars(r)
id := params["id"]
if id == "" {
AuthRequestError(w, r, nil, fmt.Errorf("auth request callback is missing id"), authorizer.Encoder())
id, err := ParseAuthorizeCallbackRequest(r)
if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder())
return
}
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder())
@ -426,6 +423,17 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
AuthResponse(authReq, authorizer, w, r)
}
func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) {
if err = r.ParseForm(); err != nil {
return "", fmt.Errorf("cannot parse form: %w", err)
}
id = r.Form.Get("id")
if id == "" {
return "", errors.New("auth request callback is missing id")
}
return id, nil
}
// AuthResponse creates the successful authentication response (either code or tokens)
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())

View file

@ -12,7 +12,6 @@ import (
"github.com/gorilla/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
@ -967,3 +966,40 @@ func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error {
}
return nil
}
func Test_parseAuthorizeCallbackRequest(t *testing.T) {
tests := []struct {
name string
url string
wantId string
wantErr bool
}{
{
name: "parse error",
url: "/?id;=99",
wantErr: true,
},
{
name: "missing id",
url: "/",
wantErr: true,
},
{
name: "ok",
url: "/?id=99",
wantId: "99",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
gotId, err := op.ParseAuthorizeCallbackRequest(r)
if tt.wantErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.wantId, gotId)
})
}
}

View file

@ -6,7 +6,7 @@ import (
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/go-chi/chi"
"github.com/gorilla/schema"
"github.com/rs/cors"
"golang.org/x/text/language"
@ -68,6 +68,7 @@ var (
)
type OpenIDProvider interface {
http.Handler
Configuration
Storage() Storage
Decoder() httphelper.Decoder
@ -77,20 +78,22 @@ type OpenIDProvider interface {
Crypto() Crypto
DefaultLogoutRedirectURI() string
Probes() []ProbesFn
// Deprecated: Provider now implements http.Handler directly.
HttpHandler() http.Handler
}
type HttpInterceptor func(http.Handler) http.Handler
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router {
router := mux.NewRouter()
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(intercept(o.IssuerFromRequest, interceptors...))
router.HandleFunc(healthEndpoint, healthHandler)
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o))
router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o))
router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o))
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
@ -184,7 +187,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
return nil, err
}
o.httpHandler = CreateRouter(o, o.interceptors...)
o.Handler = CreateRouter(o, o.interceptors...)
o.decoder = schema.NewDecoder()
o.decoder.IgnoreUnknownKeys(true)
@ -200,6 +203,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
}
type Provider struct {
http.Handler
config *Config
issuer IssuerFromRequest
insecure bool
@ -207,7 +211,6 @@ type Provider struct {
storage Storage
keySet *openIDKeySet
crypto Crypto
httpHandler http.Handler
decoder *schema.Decoder
encoder *schema.Encoder
interceptors []HttpInterceptor
@ -372,8 +375,9 @@ func (o *Provider) Probes() []ProbesFn {
}
}
// Deprecated: Provider now implements http.Handler directly.
func (o *Provider) HttpHandler() http.Handler {
return o.httpHandler
return o
}
type openIDKeySet struct {

View file

@ -365,7 +365,7 @@ func TestRoutes(t *testing.T) {
}
rec := httptest.NewRecorder()
testProvider.HttpHandler().ServeHTTP(rec, req)
testProvider.ServeHTTP(rec, req)
resp := rec.Result()
require.NoError(t, err)