chore: replace gorilla/mux with go-chi/chi (#332)
BREAKING CHANGE: The returned router from `op.CreateRouter()` is now a `chi.Router` Closes #301
This commit is contained in:
parent
62caf5dafe
commit
57fb9f77aa
12 changed files with 98 additions and 60 deletions
|
@ -2,6 +2,7 @@ package op
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -10,8 +11,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
str "github.com/zitadel/oidc/v2/pkg/strings"
|
||||
|
@ -405,13 +404,11 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *
|
|||
|
||||
// AuthorizeCallback handles the callback after authentication in the Login UI
|
||||
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
params := mux.Vars(r)
|
||||
id := params["id"]
|
||||
if id == "" {
|
||||
AuthRequestError(w, r, nil, fmt.Errorf("auth request callback is missing id"), authorizer.Encoder())
|
||||
id, err := ParseAuthorizeCallbackRequest(r)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
|
||||
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
|
@ -426,6 +423,17 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
|||
AuthResponse(authReq, authorizer, w, r)
|
||||
}
|
||||
|
||||
func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) {
|
||||
if err = r.ParseForm(); err != nil {
|
||||
return "", fmt.Errorf("cannot parse form: %w", err)
|
||||
}
|
||||
id = r.Form.Get("id")
|
||||
if id == "" {
|
||||
return "", errors.New("auth request callback is missing id")
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// AuthResponse creates the successful authentication response (either code or tokens)
|
||||
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"github.com/gorilla/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v2/pkg/http"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
|
@ -967,3 +966,40 @@ func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_parseAuthorizeCallbackRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantId string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "parse error",
|
||||
url: "/?id;=99",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing id",
|
||||
url: "/",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
url: "/?id=99",
|
||||
wantId: "99",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
|
||||
gotId, err := op.ParseAuthorizeCallbackRequest(r)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.wantId, gotId)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
18
pkg/op/op.go
18
pkg/op/op.go
|
@ -6,7 +6,7 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/gorilla/schema"
|
||||
"github.com/rs/cors"
|
||||
"golang.org/x/text/language"
|
||||
|
@ -68,6 +68,7 @@ var (
|
|||
)
|
||||
|
||||
type OpenIDProvider interface {
|
||||
http.Handler
|
||||
Configuration
|
||||
Storage() Storage
|
||||
Decoder() httphelper.Decoder
|
||||
|
@ -77,20 +78,22 @@ type OpenIDProvider interface {
|
|||
Crypto() Crypto
|
||||
DefaultLogoutRedirectURI() string
|
||||
Probes() []ProbesFn
|
||||
|
||||
// Deprecated: Provider now implements http.Handler directly.
|
||||
HttpHandler() http.Handler
|
||||
}
|
||||
|
||||
type HttpInterceptor func(http.Handler) http.Handler
|
||||
|
||||
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
|
||||
router := chi.NewRouter()
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
router.Use(intercept(o.IssuerFromRequest, interceptors...))
|
||||
router.HandleFunc(healthEndpoint, healthHandler)
|
||||
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
|
||||
router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
|
||||
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o))
|
||||
router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o))
|
||||
router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o))
|
||||
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
|
||||
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
|
||||
|
@ -184,7 +187,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
|
|||
return nil, err
|
||||
}
|
||||
|
||||
o.httpHandler = CreateRouter(o, o.interceptors...)
|
||||
o.Handler = CreateRouter(o, o.interceptors...)
|
||||
|
||||
o.decoder = schema.NewDecoder()
|
||||
o.decoder.IgnoreUnknownKeys(true)
|
||||
|
@ -200,6 +203,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
|
|||
}
|
||||
|
||||
type Provider struct {
|
||||
http.Handler
|
||||
config *Config
|
||||
issuer IssuerFromRequest
|
||||
insecure bool
|
||||
|
@ -207,7 +211,6 @@ type Provider struct {
|
|||
storage Storage
|
||||
keySet *openIDKeySet
|
||||
crypto Crypto
|
||||
httpHandler http.Handler
|
||||
decoder *schema.Decoder
|
||||
encoder *schema.Encoder
|
||||
interceptors []HttpInterceptor
|
||||
|
@ -372,8 +375,9 @@ func (o *Provider) Probes() []ProbesFn {
|
|||
}
|
||||
}
|
||||
|
||||
// Deprecated: Provider now implements http.Handler directly.
|
||||
func (o *Provider) HttpHandler() http.Handler {
|
||||
return o.httpHandler
|
||||
return o
|
||||
}
|
||||
|
||||
type openIDKeySet struct {
|
||||
|
|
|
@ -365,7 +365,7 @@ func TestRoutes(t *testing.T) {
|
|||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
testProvider.HttpHandler().ServeHTTP(rec, req)
|
||||
testProvider.ServeHTTP(rec, req)
|
||||
|
||||
resp := rec.Result()
|
||||
require.NoError(t, err)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue