diff --git a/example/client/api/api.go b/example/client/api/api.go index 8093b63..9f654a9 100644 --- a/example/client/api/api.go +++ b/example/client/api/api.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/sirupsen/logrus" "github.com/zitadel/oidc/v2/pkg/client/rs" @@ -32,7 +32,7 @@ func main() { logrus.Fatalf("error creating provider %s", err.Error()) } - router := mux.NewRouter() + router := chi.NewRouter() // public url accessible without any authorization // will print `OK` and current timestamp @@ -73,9 +73,9 @@ func main() { http.Error(w, err.Error(), http.StatusForbidden) return } - params := mux.Vars(r) - requestedClaim := params["claim"] - requestedValue := params["value"] + requestedClaim := chi.URLParam(r, "claim") + requestedValue := chi.URLParam(r, "value") + value, ok := resp.Claims[requestedClaim].(string) if !ok || value == "" || value != requestedValue { http.Error(w, "claim does not match", http.StatusForbidden) diff --git a/example/server/dynamic/login.go b/example/server/dynamic/login.go index e7c6e5f..eb5340e 100644 --- a/example/server/dynamic/login.go +++ b/example/server/dynamic/login.go @@ -6,7 +6,7 @@ import ( "html/template" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/zitadel/oidc/v2/pkg/op" ) @@ -43,7 +43,7 @@ var ( type login struct { authenticate authenticate - router *mux.Router + router chi.Router callback func(context.Context, string) string } @@ -57,9 +57,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string) } func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) { - l.router = mux.NewRouter() - l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) - l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler)) + l.router = chi.NewRouter() + l.router.Get("/username", l.loginHandler) + l.router.With(issuerInterceptor.Handler).Post("/username", l.checkLoginHandler) } type authenticate interface { diff --git a/example/server/dynamic/op.go b/example/server/dynamic/op.go index 783c75c..2bb6832 100644 --- a/example/server/dynamic/op.go +++ b/example/server/dynamic/op.go @@ -7,7 +7,7 @@ import ( "log" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "golang.org/x/text/language" "github.com/zitadel/oidc/v2/example/server/storage" @@ -47,7 +47,7 @@ func main() { //be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) - router := mux.NewRouter() + router := chi.NewRouter() //for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { @@ -76,7 +76,7 @@ func main() { //regardless of how many pages / steps there are in the process, the UI must be registered in the router, //so we will direct all calls to /login to the login UI - router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + router.Mount("/login/", http.StripPrefix("/login", l.router)) //we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) //is served on the correct path @@ -84,7 +84,7 @@ func main() { //if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), //then you would have to set the path prefix (/custom/path/): //router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler())) - router.PathPrefix("/").Handler(provider.HttpHandler()) + router.Mount("/", provider) server := &http.Server{ Addr: ":" + port, diff --git a/example/server/exampleop/device.go b/example/server/exampleop/device.go index ae2e8f2..59c2196 100644 --- a/example/server/exampleop/device.go +++ b/example/server/exampleop/device.go @@ -7,7 +7,7 @@ import ( "net/http" "net/url" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/gorilla/securecookie" "github.com/sirupsen/logrus" "github.com/zitadel/oidc/v2/pkg/op" @@ -23,14 +23,14 @@ type deviceLogin struct { cookie *securecookie.SecureCookie } -func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) { +func registerDeviceAuth(storage deviceAuthenticate, router chi.Router) { l := &deviceLogin{ storage: storage, cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil), } - router.HandleFunc("", l.userCodeHandler) - router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler) + router.HandleFunc("/", l.userCodeHandler) + router.Post("/login", l.loginHandler) router.HandleFunc("/confirm", l.confirmHandler) } diff --git a/example/server/exampleop/login.go b/example/server/exampleop/login.go index c014c9a..9facb90 100644 --- a/example/server/exampleop/login.go +++ b/example/server/exampleop/login.go @@ -5,12 +5,12 @@ import ( "fmt" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi" ) type login struct { authenticate authenticate - router *mux.Router + router chi.Router callback func(context.Context, string) string } @@ -24,9 +24,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string) } func (l *login) createRouter() { - l.router = mux.NewRouter() - l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) - l.router.Path("/username").Methods("POST").HandlerFunc(l.checkLoginHandler) + l.router = chi.NewRouter() + l.router.Get("/username", l.loginHandler) + l.router.Post("/username", l.checkLoginHandler) } type authenticate interface { diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 5604483..077244c 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -6,7 +6,7 @@ import ( "net/http" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "golang.org/x/text/language" "github.com/zitadel/oidc/v2/example/server/storage" @@ -34,12 +34,12 @@ type Storage interface { // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. -func SetupServer(issuer string, storage Storage) *mux.Router { +func SetupServer(issuer string, storage Storage) chi.Router { // the OpenID Provider requires a 32-byte key for (token) encryption // be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) - router := mux.NewRouter() + router := chi.NewRouter() // for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { @@ -61,17 +61,18 @@ func SetupServer(issuer string, storage Storage) *mux.Router { // regardless of how many pages / steps there are in the process, the UI must be registered in the router, // so we will direct all calls to /login to the login UI - router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + router.Mount("/login/", http.StripPrefix("/login", l.router)) - router.PathPrefix("/device").Subrouter() - registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter()) + router.Route("/device", func(r chi.Router) { + registerDeviceAuth(storage, r) + }) // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // is served on the correct path // // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), // then you would have to set the path prefix (/custom/path/) - router.PathPrefix("/").Handler(provider.HttpHandler()) + router.Mount("/", provider) return router } diff --git a/go.mod b/go.mod index 7594264..a636250 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module github.com/zitadel/oidc/v2 go 1.18 require ( + github.com/go-chi/chi v1.5.4 github.com/golang/mock v1.6.0 github.com/google/go-github/v31 v31.0.0 github.com/google/uuid v1.3.0 - github.com/gorilla/mux v1.8.0 github.com/gorilla/schema v1.2.0 github.com/gorilla/securecookie v1.1.1 github.com/jeremija/gosubmit v0.2.7 diff --git a/go.sum b/go.sum index e4e5c6c..a5ba642 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi v1.5.4 h1:QHdzF2szwjqVV4wmByUnTcsbIg7UGaQ0tPF2t5GcAIs= +github.com/go-chi/chi v1.5.4/go.mod h1:uaf8YgoFazUOkPBG7fxPftUylNumIev9awIWOENIuEg= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -19,8 +21,6 @@ github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= @@ -34,8 +34,6 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM= github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/cors v1.8.3 h1:O+qNyWn7Z+F9M0ILBHgMVPuB1xTOucVd5gtaYyXBpRo= @@ -50,9 +48,6 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -84,12 +79,6 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index b312098..4c48363 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -2,6 +2,7 @@ package op import ( "context" + "errors" "fmt" "net" "net/http" @@ -10,8 +11,6 @@ import ( "strings" "time" - "github.com/gorilla/mux" - httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" str "github.com/zitadel/oidc/v2/pkg/strings" @@ -405,13 +404,11 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r * // AuthorizeCallback handles the callback after authentication in the Login UI func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { - params := mux.Vars(r) - id := params["id"] - if id == "" { - AuthRequestError(w, r, nil, fmt.Errorf("auth request callback is missing id"), authorizer.Encoder()) + id, err := ParseAuthorizeCallbackRequest(r) + if err != nil { + AuthRequestError(w, r, nil, err, authorizer.Encoder()) return } - authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) if err != nil { AuthRequestError(w, r, nil, err, authorizer.Encoder()) @@ -426,6 +423,17 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author AuthResponse(authReq, authorizer, w, r) } +func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) { + if err = r.ParseForm(); err != nil { + return "", fmt.Errorf("cannot parse form: %w", err) + } + id = r.Form.Get("id") + if id == "" { + return "", errors.New("auth request callback is missing id") + } + return id, nil +} + // AuthResponse creates the successful authentication response (either code or tokens) func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 7a9701b..542f2e2 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -12,7 +12,6 @@ import ( "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - httphelper "github.com/zitadel/oidc/v2/pkg/http" "github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/op" @@ -967,3 +966,40 @@ func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error { } return nil } + +func Test_parseAuthorizeCallbackRequest(t *testing.T) { + tests := []struct { + name string + url string + wantId string + wantErr bool + }{ + { + name: "parse error", + url: "/?id;=99", + wantErr: true, + }, + { + name: "missing id", + url: "/", + wantErr: true, + }, + { + name: "ok", + url: "/?id=99", + wantId: "99", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, tt.url, nil) + gotId, err := op.ParseAuthorizeCallbackRequest(r) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantId, gotId) + }) + } +} diff --git a/pkg/op/op.go b/pkg/op/op.go index ecb753e..0536bbc 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -6,7 +6,7 @@ import ( "net/http" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi" "github.com/gorilla/schema" "github.com/rs/cors" "golang.org/x/text/language" @@ -68,6 +68,7 @@ var ( ) type OpenIDProvider interface { + http.Handler Configuration Storage() Storage Decoder() httphelper.Decoder @@ -77,20 +78,22 @@ type OpenIDProvider interface { Crypto() Crypto DefaultLogoutRedirectURI() string Probes() []ProbesFn + + // Deprecated: Provider now implements http.Handler directly. HttpHandler() http.Handler } type HttpInterceptor func(http.Handler) http.Handler -func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router { - router := mux.NewRouter() +func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { + router := chi.NewRouter() router.Use(cors.New(defaultCORSOptions).Handler) router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage())) router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o)) - router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o)) + router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o)) router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o)) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) @@ -184,7 +187,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR return nil, err } - o.httpHandler = CreateRouter(o, o.interceptors...) + o.Handler = CreateRouter(o, o.interceptors...) o.decoder = schema.NewDecoder() o.decoder.IgnoreUnknownKeys(true) @@ -200,6 +203,7 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR } type Provider struct { + http.Handler config *Config issuer IssuerFromRequest insecure bool @@ -207,7 +211,6 @@ type Provider struct { storage Storage keySet *openIDKeySet crypto Crypto - httpHandler http.Handler decoder *schema.Decoder encoder *schema.Encoder interceptors []HttpInterceptor @@ -372,8 +375,9 @@ func (o *Provider) Probes() []ProbesFn { } } +// Deprecated: Provider now implements http.Handler directly. func (o *Provider) HttpHandler() http.Handler { - return o.httpHandler + return o } type openIDKeySet struct { diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index ba3570b..8429212 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -365,7 +365,7 @@ func TestRoutes(t *testing.T) { } rec := httptest.NewRecorder() - testProvider.HttpHandler().ServeHTTP(rec, req) + testProvider.ServeHTTP(rec, req) resp := rec.Result() require.NoError(t, err)