feat(op): Server interface (#447)
* first draft of a new server interface * allow any response type * complete interface docs * refelct the format from the proposal * intermediate commit with some methods implemented * implement remaining token grant type methods * implement remaining server methods * error handling * rewrite auth request validation * define handlers, routes * input validation and concrete handlers * check if client credential client is authenticated * copy and modify the routes test for the legacy server * run integration tests against both Server and Provider * remove unuse ValidateAuthRequestV2 function * unit tests for error handling * cleanup tokenHandler * move server routest test * unit test authorize * handle client credentials in VerifyClient * change code exchange route test * finish http unit tests * review server interface docs and spelling * add withClient unit test * server options * cleanup unused GrantType method * resolve typo comments * make endpoints pointers to enable/disable them * jwt profile base work * jwt: correct the test expect --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
parent
daf82a5e04
commit
0f8a0585bf
28 changed files with 3654 additions and 126 deletions
|
@ -40,7 +40,7 @@ var counter atomic.Int64
|
|||
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
|
||||
//
|
||||
// Use one of the pre-made clients in storage/clients.go or register a new one.
|
||||
func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router {
|
||||
func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool) chi.Router {
|
||||
// the OpenID Provider requires a 32-byte key for (token) encryption
|
||||
// be sure to create a proper crypto random key and manage it securely!
|
||||
key := sha256.Sum256([]byte("test"))
|
||||
|
@ -77,12 +77,17 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router
|
|||
registerDeviceAuth(storage, r)
|
||||
})
|
||||
|
||||
handler := http.Handler(provider)
|
||||
if wrapServer {
|
||||
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
|
||||
}
|
||||
|
||||
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
||||
// is served on the correct path
|
||||
//
|
||||
// if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
|
||||
// then you would have to set the path prefix (/custom/path/)
|
||||
router.Mount("/", provider)
|
||||
router.Mount("/", handler)
|
||||
|
||||
return router
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ func main() {
|
|||
Level: slog.LevelDebug,
|
||||
}),
|
||||
)
|
||||
router := exampleop.SetupServer(issuer, storage, logger)
|
||||
router := exampleop.SetupServer(issuer, storage, logger, false)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: ":" + port,
|
||||
|
|
|
@ -185,7 +185,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
|
|||
authMethod: oidc.AuthMethodBasic,
|
||||
loginURL: defaultLoginURL,
|
||||
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
|
||||
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken},
|
||||
grantTypes: oidc.AllGrantTypes,
|
||||
accessTokenType: op.AccessTokenTypeBearer,
|
||||
devMode: false,
|
||||
idTokenUserinfoClaimsAssertion: false,
|
||||
|
|
|
@ -3,6 +3,7 @@ package client_test
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
|
@ -50,6 +51,14 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
func TestRelyingPartySession(t *testing.T) {
|
||||
for _, wrapServer := range []bool{false, true} {
|
||||
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
|
||||
testRelyingPartySession(t, wrapServer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRelyingPartySession(t *testing.T, wrapServer bool) {
|
||||
t.Log("------- start example OP ------")
|
||||
targetURL := "http://local-site"
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||
|
@ -57,7 +66,7 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
|
@ -101,6 +110,14 @@ func TestRelyingPartySession(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestResourceServerTokenExchange(t *testing.T) {
|
||||
for _, wrapServer := range []bool{false, true} {
|
||||
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
|
||||
testResourceServerTokenExchange(t, wrapServer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
|
||||
t.Log("------- start example OP ------")
|
||||
targetURL := "http://local-site"
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||
|
@ -108,7 +125,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
|
|||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
|
||||
|
||||
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
|
||||
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
|
||||
|
|
|
@ -74,7 +74,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
}
|
||||
ctx := r.Context()
|
||||
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
|
||||
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
|
||||
err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return
|
||||
|
@ -130,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A
|
|||
|
||||
// ParseRequestObject parse the `request` parameter, validates the token including the signature
|
||||
// and copies the token claims into the auth request
|
||||
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) {
|
||||
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error {
|
||||
requestObject := new(oidc.RequestObject)
|
||||
payload, err := oidc.ParseToken(authReq.RequestParam, requestObject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
|
||||
return authReq, oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest()
|
||||
}
|
||||
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
|
||||
return authReq, oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest()
|
||||
}
|
||||
if requestObject.Issuer != requestObject.ClientID {
|
||||
return authReq, oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest()
|
||||
}
|
||||
if !str.Contains(requestObject.Audience, issuer) {
|
||||
return authReq, oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest()
|
||||
}
|
||||
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
|
||||
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
|
||||
return authReq, err
|
||||
return err
|
||||
}
|
||||
CopyRequestObjectToAuthRequest(authReq, requestObject)
|
||||
return authReq, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request
|
||||
|
|
|
@ -180,3 +180,10 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
|
|||
}
|
||||
return data.ClientID, false, nil
|
||||
}
|
||||
|
||||
type ClientCredentials struct {
|
||||
ClientID string `schema:"client_id"`
|
||||
ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body
|
||||
ClientAssertion string `schema:"client_assertion"` // JWT
|
||||
ClientAssertionType string `schema:"client_assertion_type"`
|
||||
}
|
||||
|
|
|
@ -20,14 +20,14 @@ var (
|
|||
type Configuration interface {
|
||||
IssuerFromRequest(r *http.Request) string
|
||||
Insecure() bool
|
||||
AuthorizationEndpoint() Endpoint
|
||||
TokenEndpoint() Endpoint
|
||||
IntrospectionEndpoint() Endpoint
|
||||
UserinfoEndpoint() Endpoint
|
||||
RevocationEndpoint() Endpoint
|
||||
EndSessionEndpoint() Endpoint
|
||||
KeysEndpoint() Endpoint
|
||||
DeviceAuthorizationEndpoint() Endpoint
|
||||
AuthorizationEndpoint() *Endpoint
|
||||
TokenEndpoint() *Endpoint
|
||||
IntrospectionEndpoint() *Endpoint
|
||||
UserinfoEndpoint() *Endpoint
|
||||
RevocationEndpoint() *Endpoint
|
||||
EndSessionEndpoint() *Endpoint
|
||||
KeysEndpoint() *Endpoint
|
||||
DeviceAuthorizationEndpoint() *Endpoint
|
||||
|
||||
AuthMethodPostSupported() bool
|
||||
CodeMethodS256Supported() bool
|
||||
|
|
|
@ -63,41 +63,51 @@ func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *htt
|
|||
}
|
||||
|
||||
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
|
||||
storage, err := assertDeviceStorage(o.Storage())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := ParseDeviceCodeRequest(r, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httphelper.MarshalJSON(w, response)
|
||||
return nil
|
||||
}
|
||||
|
||||
func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) {
|
||||
storage, err := assertDeviceStorage(o.Storage())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config := o.DeviceAuthorization()
|
||||
|
||||
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
expires := time.Now().Add(config.Lifetime)
|
||||
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
|
||||
err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var verification *url.URL
|
||||
if config.UserFormURL != "" {
|
||||
if verification, err = url.Parse(config.UserFormURL); err != nil {
|
||||
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
|
||||
err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
|
||||
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
|
||||
if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil {
|
||||
err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
verification.Path = config.UserFormPath
|
||||
}
|
||||
|
@ -112,9 +122,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
|
|||
|
||||
verification.RawQuery = "user_code=" + userCode
|
||||
response.VerificationURIComplete = verification.String()
|
||||
|
||||
httphelper.MarshalJSON(w, response)
|
||||
return nil
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
|
||||
|
|
|
@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{
|
|||
|
||||
func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
Discover(w, CreateDiscoveryConfig(r, c, s))
|
||||
Discover(w, CreateDiscoveryConfig(r.Context(), c, s))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
|
|||
httphelper.MarshalJSON(w, config)
|
||||
}
|
||||
|
||||
func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
|
||||
issuer := config.IssuerFromRequest(r)
|
||||
func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
|
||||
issuer := IssuerFromContext(ctx)
|
||||
return &oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer),
|
||||
|
@ -49,7 +49,38 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
|
|||
ResponseTypesSupported: ResponseTypes(config),
|
||||
GrantTypesSupported: GrantTypes(config),
|
||||
SubjectTypesSupported: SubjectTypes(config),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
|
||||
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
|
||||
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
|
||||
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
|
||||
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
|
||||
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
|
||||
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
|
||||
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
|
||||
ClaimsSupported: SupportedClaims(config),
|
||||
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
|
||||
UILocalesSupported: config.SupportedUILocales(),
|
||||
RequestParameterSupported: config.RequestObjectSupported(),
|
||||
}
|
||||
}
|
||||
|
||||
func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
|
||||
issuer := IssuerFromContext(ctx)
|
||||
return &oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
|
||||
TokenEndpoint: endpoints.Token.Absolute(issuer),
|
||||
IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
|
||||
UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
|
||||
RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
|
||||
EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
|
||||
JwksURI: endpoints.JwksURI.Absolute(issuer),
|
||||
DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
|
||||
ScopesSupported: Scopes(config),
|
||||
ResponseTypesSupported: ResponseTypes(config),
|
||||
GrantTypesSupported: GrantTypes(config),
|
||||
SubjectTypesSupported: SubjectTypes(config),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
|
||||
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
|
||||
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
|
||||
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
|
||||
|
|
|
@ -48,9 +48,9 @@ func TestDiscover(t *testing.T) {
|
|||
|
||||
func TestCreateDiscoveryConfig(t *testing.T) {
|
||||
type args struct {
|
||||
request *http.Request
|
||||
c op.Configuration
|
||||
s op.DiscoverStorage
|
||||
ctx context.Context
|
||||
c op.Configuration
|
||||
s op.DiscoverStorage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s)
|
||||
got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,32 +1,46 @@
|
|||
package op
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Endpoint struct {
|
||||
path string
|
||||
url string
|
||||
}
|
||||
|
||||
func NewEndpoint(path string) Endpoint {
|
||||
return Endpoint{path: path}
|
||||
func NewEndpoint(path string) *Endpoint {
|
||||
return &Endpoint{path: path}
|
||||
}
|
||||
|
||||
func NewEndpointWithURL(path, url string) Endpoint {
|
||||
return Endpoint{path: path, url: url}
|
||||
func NewEndpointWithURL(path, url string) *Endpoint {
|
||||
return &Endpoint{path: path, url: url}
|
||||
}
|
||||
|
||||
func (e Endpoint) Relative() string {
|
||||
func (e *Endpoint) Relative() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return relativeEndpoint(e.path)
|
||||
}
|
||||
|
||||
func (e Endpoint) Absolute(host string) string {
|
||||
func (e *Endpoint) Absolute(host string) string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.url != "" {
|
||||
return e.url
|
||||
}
|
||||
return absoluteEndpoint(host, e.path)
|
||||
}
|
||||
|
||||
func (e Endpoint) Validate() error {
|
||||
var ErrNilEndpoint = errors.New("nil endpoint")
|
||||
|
||||
func (e *Endpoint) Validate() error {
|
||||
if e == nil {
|
||||
return ErrNilEndpoint
|
||||
}
|
||||
return nil // TODO:
|
||||
}
|
||||
|
||||
|
|
|
@ -3,13 +3,14 @@ package op_test
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func TestEndpoint_Path(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
e *op.Endpoint
|
||||
want string
|
||||
}{
|
||||
{
|
||||
|
@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
|
|||
op.NewEndpointWithURL("/test", "http://test.com/test"),
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
|
|||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
e *op.Endpoint
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
|
@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
|
|||
args{"https://host"},
|
||||
"https://test.com/test",
|
||||
},
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
args{"https://host"},
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) {
|
|||
func TestEndpoint_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
wantErr bool
|
||||
e *op.Endpoint
|
||||
wantErr error
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
op.ErrNilEndpoint,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
err := tt.e.Validate()
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
101
pkg/op/error.go
101
pkg/op/error.go
|
@ -1,6 +1,9 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
|
@ -66,3 +69,101 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slo
|
|||
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
|
||||
httphelper.MarshalJSONWithStatus(w, e, status)
|
||||
}
|
||||
|
||||
// TryErrorRedirect tries to handle an error by redirecting a client.
|
||||
// If this attempt fails, an error is returned that must be returned
|
||||
// to the client instead.
|
||||
func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
|
||||
e := oidc.DefaultToServerError(parent, parent.Error())
|
||||
logger = logger.With("oidc_error", e)
|
||||
|
||||
if authReq == nil {
|
||||
logger.Log(ctx, e.LogLevel(), "auth request")
|
||||
return nil, AsStatusError(e, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
|
||||
logger = logger.With("auth_request", logAuthReq)
|
||||
}
|
||||
|
||||
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
|
||||
logger.Log(ctx, e.LogLevel(), "auth request: not redirecting")
|
||||
return nil, AsStatusError(e, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
e.State = authReq.GetState()
|
||||
var responseMode oidc.ResponseMode
|
||||
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
|
||||
responseMode = rm.GetResponseMode()
|
||||
}
|
||||
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, "auth response URL", "error", err)
|
||||
return nil, AsStatusError(err, http.StatusBadRequest)
|
||||
}
|
||||
logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url)
|
||||
return NewRedirect(url), nil
|
||||
}
|
||||
|
||||
// StatusError wraps an error with a HTTP status code.
|
||||
// The status code is passed to the handler's writer.
|
||||
type StatusError struct {
|
||||
parent error
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// NewStatusError sets the parent and statusCode to a new StatusError.
|
||||
// It is recommended for parent to be an [oidc.Error].
|
||||
//
|
||||
// Typically implementations should only use this to signal something
|
||||
// very specific, like an internal server error.
|
||||
// If a returned error is not a StatusError, the framework
|
||||
// will set a statusCode based on what the standard specifies,
|
||||
// which is [http.StatusBadRequest] for most of the time.
|
||||
// If the error encountered can described clearly with a [oidc.Error],
|
||||
// do not use this function, as it might break standard rules!
|
||||
func NewStatusError(parent error, statusCode int) StatusError {
|
||||
return StatusError{
|
||||
parent: parent,
|
||||
statusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// AsStatusError unwraps a StatusError from err
|
||||
// and returns it unmodified if found.
|
||||
// If no StatuError was found, a new one is returned
|
||||
// with statusCode set to it as a default.
|
||||
func AsStatusError(err error, statusCode int) (target StatusError) {
|
||||
if errors.As(err, &target) {
|
||||
return target
|
||||
}
|
||||
return NewStatusError(err, statusCode)
|
||||
}
|
||||
|
||||
func (e StatusError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error())
|
||||
}
|
||||
|
||||
func (e StatusError) Unwrap() error {
|
||||
return e.parent
|
||||
}
|
||||
|
||||
func (e StatusError) Is(err error) bool {
|
||||
var target StatusError
|
||||
if !errors.As(err, &target) {
|
||||
return false
|
||||
}
|
||||
return errors.Is(e.parent, target.parent) &&
|
||||
e.statusCode == target.statusCode
|
||||
}
|
||||
|
||||
// WriteError asserts for a StatusError containing an [oidc.Error].
|
||||
// If no StatusError is found, the status code will default to [http.StatusBadRequest].
|
||||
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
|
||||
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
|
||||
statusError := AsStatusError(err, http.StatusBadRequest)
|
||||
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
|
||||
|
||||
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
|
||||
httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode)
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -275,3 +278,400 @@ func TestRequestError(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryErrorRedirect(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
authReq ErrAuthRequest
|
||||
parent error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Redirect
|
||||
wantErr error
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "nil auth request",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: nil,
|
||||
parent: io.ErrClosedPipe,
|
||||
},
|
||||
wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest),
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"auth request",
|
||||
"time":"not",
|
||||
"oidc_error":{
|
||||
"description":"io: read/write on closed pipe",
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, no redirect URI",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest),
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request: not redirecting",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, redirect disabled",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
|
||||
},
|
||||
wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest),
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request: not redirecting",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"http://example.com/callback",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"oops",
|
||||
"type":"invalid_request",
|
||||
"redirect_disabled":true
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request, url parse error",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "can't parse this!\n",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
wantErr: func() error {
|
||||
//lint:ignore SA1007 just recreating the error for testing
|
||||
_, err := url.Parse("can't parse this!\n")
|
||||
err = oidc.ErrServerError().WithParent(err)
|
||||
return NewStatusError(err, http.StatusBadRequest)
|
||||
}(),
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"auth response URL",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"can't parse this!\n",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"error":{
|
||||
"type":"server_error",
|
||||
"parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "auth request redirect",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
authReq: &oidc.AuthRequest{
|
||||
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
|
||||
ResponseType: "responseType",
|
||||
ClientID: "123",
|
||||
RedirectURI: "http://example.com/callback",
|
||||
State: "state1",
|
||||
ResponseMode: oidc.ResponseModeQuery,
|
||||
},
|
||||
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
|
||||
},
|
||||
want: &Redirect{
|
||||
URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1",
|
||||
},
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"auth request redirect",
|
||||
"time":"not",
|
||||
"auth_request":{
|
||||
"client_id":"123",
|
||||
"redirect_uri":"http://example.com/callback",
|
||||
"response_type":"responseType",
|
||||
"scopes":"a b"
|
||||
},
|
||||
"oidc_error":{
|
||||
"description":"sign in",
|
||||
"type":"interaction_required"
|
||||
},
|
||||
"url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"
|
||||
}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOut := new(strings.Builder)
|
||||
logger := slog.New(
|
||||
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
|
||||
)
|
||||
encoder := schema.NewEncoder()
|
||||
|
||||
got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
|
||||
gotLog := logOut.String()
|
||||
t.Log(gotLog)
|
||||
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStatusError(t *testing.T) {
|
||||
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
|
||||
|
||||
want := "Internal Server Error: io: read/write on closed pipe"
|
||||
got := fmt.Sprint(err)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestAsStatusError(t *testing.T) {
|
||||
type args struct {
|
||||
err error
|
||||
statusCode int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "already status error",
|
||||
args: args{
|
||||
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
want: "Internal Server Error: io: read/write on closed pipe",
|
||||
},
|
||||
{
|
||||
name: "oidc error",
|
||||
args: args{
|
||||
err: oidc.ErrAcrInvalid,
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
want: "Bad Request: acr is invalid",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := AsStatusError(tt.args.err, tt.args.statusCode)
|
||||
got := fmt.Sprint(err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusError_Unwrap(t *testing.T) {
|
||||
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
|
||||
require.ErrorIs(t, err, io.ErrClosedPipe)
|
||||
}
|
||||
|
||||
func TestStatusError_Is(t *testing.T) {
|
||||
type args struct {
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
args: args{err: nil},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
args: args{err: io.EOF},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "other parent",
|
||||
args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "other status",
|
||||
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "same",
|
||||
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped",
|
||||
args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
|
||||
if got := e.Is(tt.args.err); got != tt.want {
|
||||
t.Errorf("StatusError.Is() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantStatus int
|
||||
wantBody string
|
||||
wantLog string
|
||||
}{
|
||||
{
|
||||
name: "not a status or oidc error",
|
||||
err: io.ErrClosedPipe,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{
|
||||
"error":"server_error",
|
||||
"error_description":"io: read/write on closed pipe"
|
||||
}`,
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"request error",
|
||||
"oidc_error":{
|
||||
"description":"io: read/write on closed pipe",
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
},
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "status error w/o oidc",
|
||||
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantBody: `{
|
||||
"error":"server_error",
|
||||
"error_description":"io: read/write on closed pipe"
|
||||
}`,
|
||||
wantLog: `{
|
||||
"level":"ERROR",
|
||||
"msg":"request error",
|
||||
"oidc_error":{
|
||||
"description":"io: read/write on closed pipe",
|
||||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
},
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "oidc error w/o status",
|
||||
err: oidc.ErrInvalidRequest().WithDescription("oops"),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{
|
||||
"error":"invalid_request",
|
||||
"error_description":"oops"
|
||||
}`,
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"request error",
|
||||
"oidc_error":{
|
||||
"description":"oops",
|
||||
"type":"invalid_request"
|
||||
},
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "status with oidc error",
|
||||
err: NewStatusError(
|
||||
oidc.ErrUnauthorizedClient().WithDescription("oops"),
|
||||
http.StatusUnauthorized,
|
||||
),
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantBody: `{
|
||||
"error":"unauthorized_client",
|
||||
"error_description":"oops"
|
||||
}`,
|
||||
wantLog: `{
|
||||
"level":"WARN",
|
||||
"msg":"request error",
|
||||
"oidc_error":{
|
||||
"description":"oops",
|
||||
"type":"unauthorized_client"
|
||||
},
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOut := new(strings.Builder)
|
||||
logger := slog.New(
|
||||
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
|
||||
)
|
||||
r := httptest.NewRequest("GET", "/target", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
WriteError(w, r, tt.err, logger)
|
||||
res := w.Result()
|
||||
assert.Equal(t, tt.wantStatus, res.StatusCode, "status code")
|
||||
gotBody, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, tt.wantBody, string(gotBody), "body")
|
||||
assert.JSONEq(t, tt.wantLog, logOut.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom
|
|||
}
|
||||
|
||||
// AuthorizationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
|
|||
}
|
||||
|
||||
// DeviceAuthorizationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C
|
|||
}
|
||||
|
||||
// EndSessionEndpoint mocks base method.
|
||||
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "EndSessionEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup
|
|||
}
|
||||
|
||||
// IntrospectionEndpoint mocks base method.
|
||||
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IntrospectionEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go
|
|||
}
|
||||
|
||||
// KeysEndpoint mocks base method.
|
||||
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeysEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor
|
|||
}
|
||||
|
||||
// RevocationEndpoint mocks base method.
|
||||
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RevocationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
|
|||
}
|
||||
|
||||
// TokenEndpoint mocks base method.
|
||||
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TokenEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
|
|||
}
|
||||
|
||||
// UserinfoEndpoint mocks base method.
|
||||
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
|
||||
func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserinfoEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
ret0, _ := ret[0].(*op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
|
67
pkg/op/op.go
67
pkg/op/op.go
|
@ -32,7 +32,7 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
DefaultEndpoints = &endpoints{
|
||||
DefaultEndpoints = &Endpoints{
|
||||
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
|
||||
Token: NewEndpoint(defaultTokenEndpoint),
|
||||
Introspection: NewEndpoint(defaultIntrospectEndpoint),
|
||||
|
@ -131,16 +131,17 @@ type Config struct {
|
|||
DeviceAuthorization DeviceAuthorizationConfig
|
||||
}
|
||||
|
||||
type endpoints struct {
|
||||
Authorization Endpoint
|
||||
Token Endpoint
|
||||
Introspection Endpoint
|
||||
Userinfo Endpoint
|
||||
Revocation Endpoint
|
||||
EndSession Endpoint
|
||||
CheckSessionIframe Endpoint
|
||||
JwksURI Endpoint
|
||||
DeviceAuthorization Endpoint
|
||||
// Endpoints defines endpoint routes.
|
||||
type Endpoints struct {
|
||||
Authorization *Endpoint
|
||||
Token *Endpoint
|
||||
Introspection *Endpoint
|
||||
Userinfo *Endpoint
|
||||
Revocation *Endpoint
|
||||
EndSession *Endpoint
|
||||
CheckSessionIframe *Endpoint
|
||||
JwksURI *Endpoint
|
||||
DeviceAuthorization *Endpoint
|
||||
}
|
||||
|
||||
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
|
||||
|
@ -212,7 +213,7 @@ type Provider struct {
|
|||
config *Config
|
||||
issuer IssuerFromRequest
|
||||
insecure bool
|
||||
endpoints *endpoints
|
||||
endpoints *Endpoints
|
||||
storage Storage
|
||||
keySet *openIDKeySet
|
||||
crypto Crypto
|
||||
|
@ -233,35 +234,35 @@ func (o *Provider) Insecure() bool {
|
|||
return o.insecure
|
||||
}
|
||||
|
||||
func (o *Provider) AuthorizationEndpoint() Endpoint {
|
||||
func (o *Provider) AuthorizationEndpoint() *Endpoint {
|
||||
return o.endpoints.Authorization
|
||||
}
|
||||
|
||||
func (o *Provider) TokenEndpoint() Endpoint {
|
||||
func (o *Provider) TokenEndpoint() *Endpoint {
|
||||
return o.endpoints.Token
|
||||
}
|
||||
|
||||
func (o *Provider) IntrospectionEndpoint() Endpoint {
|
||||
func (o *Provider) IntrospectionEndpoint() *Endpoint {
|
||||
return o.endpoints.Introspection
|
||||
}
|
||||
|
||||
func (o *Provider) UserinfoEndpoint() Endpoint {
|
||||
func (o *Provider) UserinfoEndpoint() *Endpoint {
|
||||
return o.endpoints.Userinfo
|
||||
}
|
||||
|
||||
func (o *Provider) RevocationEndpoint() Endpoint {
|
||||
func (o *Provider) RevocationEndpoint() *Endpoint {
|
||||
return o.endpoints.Revocation
|
||||
}
|
||||
|
||||
func (o *Provider) EndSessionEndpoint() Endpoint {
|
||||
func (o *Provider) EndSessionEndpoint() *Endpoint {
|
||||
return o.endpoints.EndSession
|
||||
}
|
||||
|
||||
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
|
||||
func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
|
||||
return o.endpoints.DeviceAuthorization
|
||||
}
|
||||
|
||||
func (o *Provider) KeysEndpoint() Endpoint {
|
||||
func (o *Provider) KeysEndpoint() *Endpoint {
|
||||
return o.endpoints.JwksURI
|
||||
}
|
||||
|
||||
|
@ -420,7 +421,7 @@ func WithAllowInsecure() Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -430,7 +431,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -440,7 +441,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -450,7 +451,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -460,7 +461,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -470,7 +471,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -480,7 +481,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -490,7 +491,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
|
||||
func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
|
@ -500,8 +501,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
|
||||
// WithCustomEndpoints sets multiple endpoints at once.
|
||||
// Non of the endpoints may be nil, or an error will
|
||||
// be returned when the Option used by the Provider.
|
||||
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
|
||||
return func(o *Provider) error {
|
||||
for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
|
||||
if err := e.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
o.endpoints.Authorization = auth
|
||||
o.endpoints.Token = token
|
||||
o.endpoints.Userinfo = userInfo
|
||||
|
|
|
@ -395,3 +395,54 @@ func TestRoutes(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithCustomEndpoints(t *testing.T) {
|
||||
type args struct {
|
||||
auth *op.Endpoint
|
||||
token *op.Endpoint
|
||||
userInfo *op.Endpoint
|
||||
revocation *op.Endpoint
|
||||
endSession *op.Endpoint
|
||||
keys *op.Endpoint
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "all nil",
|
||||
args: args{},
|
||||
wantErr: op.ErrNilEndpoint,
|
||||
},
|
||||
{
|
||||
name: "all set",
|
||||
args: args{
|
||||
auth: op.NewEndpoint("/authorize"),
|
||||
token: op.NewEndpoint("/oauth/token"),
|
||||
userInfo: op.NewEndpoint("/userinfo"),
|
||||
revocation: op.NewEndpoint("/revoke"),
|
||||
endSession: op.NewEndpoint("/end_session"),
|
||||
keys: op.NewEndpoint("/keys"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
|
||||
storage.NewStorage(storage.NewUserStore(testIssuer)),
|
||||
op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
|
||||
)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
if tt.wantErr != nil {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
|
||||
assert.Equal(t, tt.args.token, provider.TokenEndpoint())
|
||||
assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
|
||||
assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
|
||||
assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
|
||||
assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn {
|
|||
}
|
||||
|
||||
func ok(w http.ResponseWriter) {
|
||||
httphelper.MarshalJSON(w, status{"ok"})
|
||||
httphelper.MarshalJSON(w, Status{"ok"})
|
||||
}
|
||||
|
||||
type status struct {
|
||||
type Status struct {
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
|
346
pkg/op/server.go
Normal file
346
pkg/op/server.go
Normal file
|
@ -0,0 +1,346 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
// Server describes the interface that needs to be implemented to serve
|
||||
// OpenID Connect and Oauth2 standard requests.
|
||||
//
|
||||
// Methods are called after the HTTP route is resolved and
|
||||
// the request body is parsed into the Request's Data field.
|
||||
// When a method is called, it can be assumed that required fields,
|
||||
// as described in their relevant standard, are validated already.
|
||||
// The Response Data field may be of any type to allow flexibility
|
||||
// to extend responses with custom fields. There are however requirements
|
||||
// in the standards regarding the response models. Where applicable
|
||||
// the method documentation gives a recommended type which can be used
|
||||
// directly or extended upon.
|
||||
//
|
||||
// The addition of new methods is not considered a breaking change
|
||||
// as defined by semver rules.
|
||||
// Implementations MUST embed [UnimplementedServer] to maintain
|
||||
// forward compatibility.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Server interface {
|
||||
// Health returns a status of "ok" once the Server is listening.
|
||||
// The recommended Response Data type is [Status].
|
||||
Health(context.Context, *Request[struct{}]) (*Response, error)
|
||||
|
||||
// Ready returns a status of "ok" once all dependencies,
|
||||
// such as database storage, are ready.
|
||||
// An error can be returned to explain what is not ready.
|
||||
// The recommended Response Data type is [Status].
|
||||
Ready(context.Context, *Request[struct{}]) (*Response, error)
|
||||
|
||||
// Discovery returns the OpenID Provider Configuration Information for this server.
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
|
||||
// The recommended Response Data type is [oidc.DiscoveryConfiguration].
|
||||
Discovery(context.Context, *Request[struct{}]) (*Response, error)
|
||||
|
||||
// Keys serves the JWK set which the client can use verify signatures from the op.
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key.
|
||||
// The recommended Response Data type is [jose.JSONWebKeySet].
|
||||
Keys(context.Context, *Request[struct{}]) (*Response, error)
|
||||
|
||||
// VerifyAuthRequest verifies the Auth Request and
|
||||
// adds the Client to the request.
|
||||
//
|
||||
// When the `request` field is populated with a
|
||||
// "Request Object" JWT, it needs to be Validated
|
||||
// and its claims overwrite any fields in the AuthRequest.
|
||||
// If the implementation does not support "Request Object",
|
||||
// it MUST return an [oidc.ErrRequestNotSupported].
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RequestObject
|
||||
VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error)
|
||||
|
||||
// Authorize initiates the authorization flow and redirects to a login page.
|
||||
// See the various https://openid.net/specs/openid-connect-core-1_0.html
|
||||
// authorize endpoint sections (one for each type of flow).
|
||||
Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error)
|
||||
|
||||
// DeviceAuthorization initiates the device authorization flow.
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
|
||||
// The recommended Response Data type is [oidc.DeviceAuthorizationResponse].
|
||||
DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error)
|
||||
|
||||
// VerifyClient is called on most oauth/token handlers to authenticate,
|
||||
// using either a secret (POST, Basic) or assertion (JWT).
|
||||
// If no secrets are provided, the client must be public.
|
||||
// This method is called before each method that takes a
|
||||
// [ClientRequest] argument.
|
||||
VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error)
|
||||
|
||||
// CodeExchange returns Tokens after an authorization code
|
||||
// is obtained in a successful Authorize flow.
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value authorization_code
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error)
|
||||
|
||||
// RefreshToken returns new Tokens after verifying a Refresh token.
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value refresh_token
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error)
|
||||
|
||||
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer
|
||||
// https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error)
|
||||
|
||||
// TokenExchange handles the OAuth 2.0 token exchange grant
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange
|
||||
// https://datatracker.ietf.org/doc/html/rfc8693
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error)
|
||||
|
||||
// ClientCredentialsExchange handles the OAuth 2.0 client credentials grant
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value client_credentials
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error)
|
||||
|
||||
// DeviceToken handles the OAuth 2.0 Device Authorization Grant
|
||||
// It is called by the Token endpoint handler when
|
||||
// grant_type has the value urn:ietf:params:oauth:grant-type:device_code.
|
||||
// It is typically called in a polling fashion and appropriate errors
|
||||
// should be returned to signal authorization_pending or access_denied etc.
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4,
|
||||
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5.
|
||||
// The recommended Response Data type is [oidc.AccessTokenResponse].
|
||||
DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error)
|
||||
|
||||
// Introspect handles the OAuth 2.0 Token Introspection endpoint.
|
||||
// https://datatracker.ietf.org/doc/html/rfc7662
|
||||
// The recommended Response Data type is [oidc.IntrospectionResponse].
|
||||
Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error)
|
||||
|
||||
// UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
// The recommended Response Data type is [oidc.UserInfo].
|
||||
UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error)
|
||||
|
||||
// Revocation handles token revocation using an access or refresh token.
|
||||
// https://datatracker.ietf.org/doc/html/rfc7009
|
||||
// There are no response requirements. Data may remain empty.
|
||||
Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error)
|
||||
|
||||
// EndSession handles the OpenID Connect RP-Initiated Logout.
|
||||
// https://openid.net/specs/openid-connect-rpinitiated-1_0.html
|
||||
// There are no response requirements. Data may remain empty.
|
||||
EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error)
|
||||
|
||||
// mustImpl forces implementations to embed the UnimplementedServer for forward
|
||||
// compatibility with the interface.
|
||||
mustImpl()
|
||||
}
|
||||
|
||||
// Request contains the [http.Request] informational fields
|
||||
// and parsed Data from the request body (POST) or URL parameters (GET).
|
||||
// Data can be assumed to be validated according to the applicable
|
||||
// standard for the specific endpoints.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Request[T any] struct {
|
||||
Method string
|
||||
URL *url.URL
|
||||
Header http.Header
|
||||
Form url.Values
|
||||
PostForm url.Values
|
||||
Data *T
|
||||
}
|
||||
|
||||
func (r *Request[_]) path() string {
|
||||
return r.URL.Path
|
||||
}
|
||||
|
||||
func newRequest[T any](r *http.Request, data *T) *Request[T] {
|
||||
return &Request[T]{
|
||||
Method: r.Method,
|
||||
URL: r.URL,
|
||||
Header: r.Header,
|
||||
Form: r.Form,
|
||||
PostForm: r.PostForm,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ClientRequest is a Request with a verified client attached to it.
|
||||
// Methods that receive this argument may assume the client was authenticated,
|
||||
// or verified to be a public client.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type ClientRequest[T any] struct {
|
||||
*Request[T]
|
||||
Client Client
|
||||
}
|
||||
|
||||
func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] {
|
||||
return &ClientRequest[T]{
|
||||
Request: newRequest[T](r, data),
|
||||
Client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Response object for most [Server] methods.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Response struct {
|
||||
// Header map will be merged with the
|
||||
// header on the [http.ResponseWriter].
|
||||
Header http.Header
|
||||
|
||||
// Data will be JSON marshaled to
|
||||
// the response body.
|
||||
// We allow any type, so that implementations
|
||||
// can extend the standard types as they wish.
|
||||
// However, each method will recommend which
|
||||
// (base) type to use as model, in order to
|
||||
// be compliant with the standards.
|
||||
Data any
|
||||
}
|
||||
|
||||
// NewResponse creates a new response for data,
|
||||
// without custom headers.
|
||||
func NewResponse(data any) *Response {
|
||||
return &Response{
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (resp *Response) writeOut(w http.ResponseWriter) {
|
||||
gu.MapMerge(resp.Header, w.Header())
|
||||
httphelper.MarshalJSON(w, resp.Data)
|
||||
}
|
||||
|
||||
// Redirect is a special response type which will
|
||||
// initiate a [http.StatusFound] redirect.
|
||||
// The Params field will be encoded and set to the
|
||||
// URL's RawQuery field before building the URL.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type Redirect struct {
|
||||
// Header map will be merged with the
|
||||
// header on the [http.ResponseWriter].
|
||||
Header http.Header
|
||||
|
||||
URL string
|
||||
}
|
||||
|
||||
func NewRedirect(url string) *Redirect {
|
||||
return &Redirect{URL: url}
|
||||
}
|
||||
|
||||
func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) {
|
||||
gu.MapMerge(r.Header, w.Header())
|
||||
http.Redirect(w, r, red.URL, http.StatusFound)
|
||||
}
|
||||
|
||||
type UnimplementedServer struct{}
|
||||
|
||||
// UnimplementedStatusCode is the status code returned for methods
|
||||
// that are not yet implemented.
|
||||
// Note that this means methods in the sense of the Go interface,
|
||||
// and not http methods covered by "501 Not Implemented".
|
||||
var UnimplementedStatusCode = http.StatusNotFound
|
||||
|
||||
func unimplementedError(r interface{ path() string }) StatusError {
|
||||
err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path())
|
||||
return NewStatusError(err, UnimplementedStatusCode)
|
||||
}
|
||||
|
||||
func unimplementedGrantError(gt oidc.GrantType) StatusError {
|
||||
err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt)
|
||||
return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
}
|
||||
|
||||
func (UnimplementedServer) mustImpl() {}
|
||||
|
||||
func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
|
||||
if r.Data.RequestParam != "" {
|
||||
return nil, oidc.ErrRequestNotSupported()
|
||||
}
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeCode)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
480
pkg/op/server_http.go
Normal file
480
pkg/op/server_http.go
Normal file
|
@ -0,0 +1,480 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/rs/cors"
|
||||
"github.com/zitadel/logging"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
// RegisterServer registers an implementation of Server.
|
||||
// The resulting handler takes care of routing and request parsing,
|
||||
// with some basic validation of required fields.
|
||||
// The routes can be customized with [WithEndpoints].
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
|
||||
decoder := schema.NewDecoder()
|
||||
decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
ws := &webServer{
|
||||
server: server,
|
||||
endpoints: endpoints,
|
||||
decoder: decoder,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(ws)
|
||||
}
|
||||
|
||||
ws.createRouter()
|
||||
return ws
|
||||
}
|
||||
|
||||
type ServerOption func(s *webServer)
|
||||
|
||||
// WithHTTPMiddleware sets the passed middleware chain to the root of
|
||||
// the Server's router.
|
||||
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.middleware = m
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecoder overrides the default decoder,
|
||||
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
|
||||
func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.decoder = decoder
|
||||
}
|
||||
}
|
||||
|
||||
// WithFallbackLogger overrides the fallback logger, which
|
||||
// is used when no logger was found in the context.
|
||||
// Defaults to [slog.Default].
|
||||
func WithFallbackLogger(logger *slog.Logger) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
type webServer struct {
|
||||
http.Handler
|
||||
server Server
|
||||
middleware []func(http.Handler) http.Handler
|
||||
endpoints Endpoints
|
||||
decoder httphelper.Decoder
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
||||
if logger, ok := logging.FromContext(ctx); ok {
|
||||
return logger
|
||||
}
|
||||
return s.logger
|
||||
}
|
||||
|
||||
func (s *webServer) createRouter() {
|
||||
router := chi.NewRouter()
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
router.Use(s.middleware...)
|
||||
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
||||
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
||||
|
||||
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
|
||||
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
|
||||
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
|
||||
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
|
||||
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
|
||||
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
|
||||
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
|
||||
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
|
||||
s.Handler = router
|
||||
}
|
||||
|
||||
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
|
||||
if e != nil {
|
||||
router.HandleFunc(e.Relative(), hf)
|
||||
s.logger.Info("registered route", "endpoint", e.Relative())
|
||||
}
|
||||
}
|
||||
|
||||
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
|
||||
|
||||
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
client, err := s.verifyRequestClient(r)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
|
||||
if !ValidateGrantType(client, grantType) {
|
||||
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
}
|
||||
handler(w, r, client)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
|
||||
if err = r.ParseForm(); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||
}
|
||||
cc := new(ClientCredentials)
|
||||
if err = s.decoder.Decode(cc, r.Form); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
||||
}
|
||||
// Basic auth takes precedence, so if set it overwrites the form data.
|
||||
if clientID, clientSecret, ok := r.BasicAuth(); ok {
|
||||
cc.ClientID, err = url.QueryUnescape(clientID)
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
||||
}
|
||||
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
|
||||
}
|
||||
}
|
||||
if cc.ClientID == "" && cc.ClientAssertion == "" {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
|
||||
}
|
||||
if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType)
|
||||
}
|
||||
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
|
||||
Method: r.Method,
|
||||
URL: r.URL,
|
||||
Header: r.Header,
|
||||
Form: r.Form,
|
||||
Data: cc,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
redirect, err := s.authorize(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
redirect.writeOut(w, r)
|
||||
}
|
||||
|
||||
func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
|
||||
cr, err := s.server.VerifyAuthRequest(ctx, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := cr.Data
|
||||
if authReq.RedirectURI == "" {
|
||||
return nil, ErrAuthReqMissingRedirectURI
|
||||
}
|
||||
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.server.Authorize(ctx, cr)
|
||||
}
|
||||
|
||||
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
|
||||
switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType {
|
||||
case oidc.GrantTypeCode:
|
||||
s.withClient(s.codeExchangeHandler)(w, r)
|
||||
case oidc.GrantTypeRefreshToken:
|
||||
s.withClient(s.refreshTokenHandler)(w, r)
|
||||
case oidc.GrantTypeClientCredentials:
|
||||
s.withClient(s.clientCredentialsHandler)(w, r)
|
||||
case oidc.GrantTypeBearer:
|
||||
s.jwtProfileHandler(w, r)
|
||||
case oidc.GrantTypeTokenExchange:
|
||||
s.withClient(s.tokenExchangeHandler)(w, r)
|
||||
case oidc.GrantTypeDeviceCode:
|
||||
s.withClient(s.deviceTokenHandler)(w, r)
|
||||
case "":
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
|
||||
default:
|
||||
WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Assertion == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Code == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.RedirectURI == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.RefreshToken == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.SubjectToken == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.SubjectTokenType == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if !request.SubjectTokenType.IsSupported() {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
if client.AuthMethod() == oidc.AuthMethodNone {
|
||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
|
||||
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.DeviceCode == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
if client.AuthMethod() == oidc.AuthMethodNone {
|
||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Token == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if token, err := getAccessToken(r); err == nil {
|
||||
request.AccessToken = token
|
||||
}
|
||||
if request.AccessToken == "" {
|
||||
err = NewStatusError(
|
||||
oidc.ErrInvalidRequest().WithDescription("access token missing"),
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if request.Token == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w, r)
|
||||
}
|
||||
|
||||
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp.writeOut(w)
|
||||
}
|
||||
}
|
||||
|
||||
func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
|
||||
dst := new(R)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||
}
|
||||
form := r.Form
|
||||
if postOnly {
|
||||
form = r.PostForm
|
||||
}
|
||||
if err := decoder.Decode(dst, form); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
||||
}
|
||||
return dst, nil
|
||||
}
|
345
pkg/op/server_http_routes_test.go
Normal file
345
pkg/op/server_http_routes_test.go
Normal file
|
@ -0,0 +1,345 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func jwtProfile() (string, error) {
|
||||
keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer)
|
||||
}
|
||||
|
||||
func TestServerRoutes(t *testing.T) {
|
||||
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
|
||||
|
||||
storage := testProvider.Storage().(routesTestStorage)
|
||||
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
|
||||
|
||||
client, err := storage.GetClientByClientID(ctx, "web")
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcAuthReq := &oidc.AuthRequest{
|
||||
ClientID: client.GetID(),
|
||||
RedirectURI: "https://example.com",
|
||||
MaxAge: gu.Ptr[uint](300),
|
||||
Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
|
||||
ResponseType: oidc.ResponseTypeCode,
|
||||
}
|
||||
|
||||
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
|
||||
require.NoError(t, err)
|
||||
storage.AuthRequestDone(authReq.GetID())
|
||||
|
||||
accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
|
||||
require.NoError(t, err)
|
||||
accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
|
||||
require.NoError(t, err)
|
||||
idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
|
||||
require.NoError(t, err)
|
||||
jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
|
||||
require.NoError(t, err)
|
||||
jwtProfileToken, err := jwtProfile()
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcAuthReq.IDTokenHint = idToken
|
||||
|
||||
serverURL, err := url.Parse(testIssuer)
|
||||
require.NoError(t, err)
|
||||
|
||||
type basicAuth struct {
|
||||
username, password string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
basicAuth *basicAuth
|
||||
header map[string]string
|
||||
values map[string]string
|
||||
body map[string]string
|
||||
wantCode int
|
||||
headerContains map[string]string
|
||||
json string // test for exact json output
|
||||
contains []string // when the body output is not constant, we just check for snippets to be present in the response
|
||||
}{
|
||||
{
|
||||
name: "health",
|
||||
method: http.MethodGet,
|
||||
path: "/healthz",
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"status":"ok"}`,
|
||||
},
|
||||
{
|
||||
name: "ready",
|
||||
method: http.MethodGet,
|
||||
path: "/ready",
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"status":"ok"}`,
|
||||
},
|
||||
{
|
||||
name: "discovery",
|
||||
method: http.MethodGet,
|
||||
path: oidc.DiscoveryEndpoint,
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
|
||||
},
|
||||
{
|
||||
name: "authorization",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.AuthorizationEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"client_id": client.GetID(),
|
||||
"redirect_uri": "https://example.com",
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"response_type": string(oidc.ResponseTypeCode),
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
headerContains: map[string]string{"Location": "/login/username?authRequestID="},
|
||||
},
|
||||
{
|
||||
// This call will fail. A successfull test is already
|
||||
// part of client/integration_test.go
|
||||
name: "code exchange",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeCode),
|
||||
"client_id": client.GetID(),
|
||||
"client_secret": "secret",
|
||||
"redirect_uri": "https://example.com",
|
||||
"code": "123",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"invalid_grant", "error_description":"invalid code"}`,
|
||||
},
|
||||
{
|
||||
name: "JWT authorization",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeBearer),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"assertion": jwtProfileToken,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299}`},
|
||||
},
|
||||
{
|
||||
name: "Token exchange",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeTokenExchange),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
"subject_token": jwtToken,
|
||||
"subject_token_type": string(oidc.AccessTokenType),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"access_token":"`,
|
||||
`","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Client credentials exchange",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"sid1", "verysecret"},
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeClientCredentials),
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
|
||||
},
|
||||
{
|
||||
// This call will fail. A successfull test is already
|
||||
// part of device_test.go
|
||||
name: "device token",
|
||||
method: http.MethodPost,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
header: map[string]string{
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
body: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeDeviceCode),
|
||||
"device_code": "123",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
|
||||
},
|
||||
{
|
||||
name: "missing grant type",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
|
||||
},
|
||||
{
|
||||
name: "unsupported grant type",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": "foo",
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
|
||||
},
|
||||
{
|
||||
name: "introspection",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.IntrospectionEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"token": accessToken,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
|
||||
},
|
||||
{
|
||||
name: "user info",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.UserinfoEndpoint().Relative(),
|
||||
header: map[string]string{
|
||||
"authorization": "Bearer " + accessToken,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
|
||||
},
|
||||
{
|
||||
name: "refresh token",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.TokenEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"grant_type": string(oidc.GrantTypeRefreshToken),
|
||||
"refresh_token": refreshToken,
|
||||
"client_id": client.GetID(),
|
||||
"client_secret": "secret",
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"access_token":"`,
|
||||
`","token_type":"Bearer","refresh_token":"`,
|
||||
`","expires_in":299,"id_token":"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "revoke",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.RevocationEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"token": accessTokenRevoke,
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "end session",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.EndSessionEndpoint().Relative(),
|
||||
values: map[string]string{
|
||||
"id_token_hint": idToken,
|
||||
"client_id": "web",
|
||||
},
|
||||
wantCode: http.StatusFound,
|
||||
headerContains: map[string]string{"Location": "/logged-out"},
|
||||
contains: []string{`<a href="/logged-out">Found</a>.`},
|
||||
},
|
||||
{
|
||||
name: "keys",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.KeysEndpoint().Relative(),
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"keys":[{"use":"sig","kty":"RSA","kid":"`,
|
||||
`","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "device authorization",
|
||||
method: http.MethodGet,
|
||||
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
|
||||
basicAuth: &basicAuth{"web", "secret"},
|
||||
values: map[string]string{
|
||||
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
|
||||
},
|
||||
wantCode: http.StatusOK,
|
||||
contains: []string{
|
||||
`{"device_code":"`, `","user_code":"`,
|
||||
`","verification_uri":"https://localhost:9998/device"`,
|
||||
`"verification_uri_complete":"https://localhost:9998/device?user_code=`,
|
||||
`","expires_in":300,"interval":5}`,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
u := gu.PtrCopy(serverURL)
|
||||
u.Path = tt.path
|
||||
if tt.values != nil {
|
||||
u.RawQuery = mapAsValues(tt.values)
|
||||
}
|
||||
var body io.Reader
|
||||
if tt.body != nil {
|
||||
body = strings.NewReader(mapAsValues(tt.body))
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(tt.method, u.String(), body)
|
||||
for k, v := range tt.header {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
if tt.basicAuth != nil {
|
||||
req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
server.ServeHTTP(rec, req)
|
||||
|
||||
resp := rec.Result()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantCode, resp.StatusCode)
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
respBodyString := string(respBody)
|
||||
t.Log(respBodyString)
|
||||
t.Log(resp.Header)
|
||||
|
||||
if tt.json != "" {
|
||||
assert.JSONEq(t, tt.json, respBodyString)
|
||||
}
|
||||
for _, c := range tt.contains {
|
||||
assert.Contains(t, respBodyString, c)
|
||||
}
|
||||
for k, v := range tt.headerContains {
|
||||
assert.Contains(t, resp.Header.Get(k), v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1333
pkg/op/server_http_test.go
Normal file
1333
pkg/op/server_http_test.go
Normal file
File diff suppressed because it is too large
Load diff
344
pkg/op/server_legacy.go
Normal file
344
pkg/op/server_legacy.go
Normal file
|
@ -0,0 +1,344 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
// LegacyServer is an implementation of [Server[] that
|
||||
// simply wraps a [OpenIDProvider].
|
||||
// It can be used to transition from the former Provider/Storage
|
||||
// interfaces to the new Server interface.
|
||||
type LegacyServer struct {
|
||||
UnimplementedServer
|
||||
provider OpenIDProvider
|
||||
endpoints Endpoints
|
||||
}
|
||||
|
||||
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
|
||||
// the Server's router.
|
||||
//
|
||||
// Only non-nil endpoints will be registered on the router.
|
||||
// Nil endpoints are disabled.
|
||||
//
|
||||
// The passed endpoints is also set to the provider,
|
||||
// to be consistent with the discovery config.
|
||||
// Any `With*Endpoint()` option used on the provider is
|
||||
// therefore ineffective.
|
||||
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
|
||||
server := RegisterServer(&LegacyServer{
|
||||
provider: provider,
|
||||
endpoints: endpoints,
|
||||
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", server)
|
||||
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return NewResponse(Status{Status: "ok"}), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
for _, probe := range s.provider.Probes() {
|
||||
// shouldn't we run probes in Go routines?
|
||||
if err := probe(ctx); err != nil {
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
return NewResponse(Status{Status: "ok"}), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
return NewResponse(
|
||||
createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
|
||||
), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
keys, err := s.provider.Storage().KeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
return NewResponse(jsonWebKeySet(keys)), nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrAuthReqMissingClientID = errors.New("auth request is missing client_id")
|
||||
ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri")
|
||||
)
|
||||
|
||||
func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
|
||||
if r.Data.RequestParam != "" {
|
||||
if !s.provider.RequestObjectSupported() {
|
||||
return nil, oidc.ErrRequestNotSupported()
|
||||
}
|
||||
err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if r.Data.ClientID == "" {
|
||||
return nil, ErrAuthReqMissingClientID
|
||||
}
|
||||
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
||||
if err != nil {
|
||||
return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id")
|
||||
}
|
||||
|
||||
return &ClientRequest[oidc.AuthRequest]{
|
||||
Request: r,
|
||||
Client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) {
|
||||
userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID)
|
||||
if err != nil {
|
||||
return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger())
|
||||
}
|
||||
return NewRedirect(r.Client.LoginURL(req.GetID())), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
||||
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
|
||||
if err != nil {
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
|
||||
if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials {
|
||||
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
|
||||
if !ok {
|
||||
return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
|
||||
}
|
||||
return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret)
|
||||
}
|
||||
|
||||
if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
|
||||
jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
|
||||
if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() {
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
|
||||
}
|
||||
return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger)
|
||||
}
|
||||
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
||||
if err != nil {
|
||||
return nil, oidc.ErrInvalidClient().WithParent(err)
|
||||
}
|
||||
|
||||
switch client.AuthMethod() {
|
||||
case oidc.AuthMethodNone:
|
||||
return client, nil
|
||||
case oidc.AuthMethodPrivateKeyJWT:
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
|
||||
case oidc.AuthMethodPost:
|
||||
if !s.provider.AuthMethodPostSupported() {
|
||||
return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
|
||||
}
|
||||
}
|
||||
|
||||
err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
|
||||
authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.Client.AuthMethod() == oidc.AuthMethodNone {
|
||||
if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
|
||||
if !s.provider.GrantTypeRefreshTokenSupported() {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
|
||||
}
|
||||
request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.Client.GetID() != request.GetClientID() {
|
||||
return nil, oidc.ErrInvalidGrant()
|
||||
}
|
||||
if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
|
||||
exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
|
||||
if !ok {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
|
||||
}
|
||||
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
|
||||
if !s.provider.GrantTypeTokenExchangeSupported() {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
|
||||
}
|
||||
tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
|
||||
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
|
||||
if !ok {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
|
||||
}
|
||||
tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
|
||||
if !s.provider.GrantTypeClientCredentialsSupported() {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
|
||||
}
|
||||
// use a limited context timeout shorter as the default
|
||||
// poll interval of 5 seconds.
|
||||
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||
defer cancel()
|
||||
|
||||
state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenRequest := &deviceAccessTokenRequest{
|
||||
subject: state.Subject,
|
||||
audience: []string{r.Client.GetID()},
|
||||
scopes: state.Scopes,
|
||||
}
|
||||
resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
|
||||
response := new(oidc.IntrospectionResponse)
|
||||
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token)
|
||||
if !ok {
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID())
|
||||
if err != nil {
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
response.Active = true
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
|
||||
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken)
|
||||
if !ok {
|
||||
return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized)
|
||||
}
|
||||
info := new(oidc.UserInfo)
|
||||
err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin"))
|
||||
if err != nil {
|
||||
return nil, NewStatusError(err, http.StatusForbidden)
|
||||
}
|
||||
return NewResponse(info), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
|
||||
var subject string
|
||||
doDecrypt := true
|
||||
if r.Data.TokenTypeHint != "access_token" {
|
||||
userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token)
|
||||
if err != nil {
|
||||
// An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
|
||||
if !errors.Is(err, ErrInvalidRefreshToken) {
|
||||
return nil, RevocationError(oidc.ErrServerError().WithParent(err))
|
||||
}
|
||||
} else {
|
||||
r.Data.Token = tokenID
|
||||
subject = userID
|
||||
doDecrypt = false
|
||||
}
|
||||
}
|
||||
if doDecrypt {
|
||||
tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token)
|
||||
if ok {
|
||||
r.Data.Token = tokenID
|
||||
subject = userID
|
||||
}
|
||||
}
|
||||
if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil {
|
||||
return nil, RevocationError(err)
|
||||
}
|
||||
return NewResponse(nil), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
|
||||
session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRedirect(session.RedirectURI), nil
|
||||
}
|
5
pkg/op/server_test.go
Normal file
5
pkg/op/server_test.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package op
|
||||
|
||||
// implementation check
|
||||
var _ Server = &UnimplementedServer{}
|
||||
var _ Server = &LegacyServer{}
|
|
@ -88,7 +88,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
err = AuthorizeCodeChallenge(tokenReq, request.GetCodeChallenge())
|
||||
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge())
|
||||
return request, client, err
|
||||
}
|
||||
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
|
||||
|
|
|
@ -197,12 +197,6 @@ func ValidateTokenExchangeRequest(
|
|||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing")
|
||||
}
|
||||
|
||||
storage := exchanger.Storage()
|
||||
teStorage, ok := storage.(TokenExchangeStorage)
|
||||
if !ok {
|
||||
return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("token_exchange grant not supported")
|
||||
}
|
||||
|
||||
client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -220,10 +214,28 @@ func ValidateTokenExchangeRequest(
|
|||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported")
|
||||
}
|
||||
|
||||
req, err := CreateTokenExchangeRequest(ctx, oidcTokenExchangeRequest, client, exchanger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return req, client, nil
|
||||
}
|
||||
|
||||
func CreateTokenExchangeRequest(
|
||||
ctx context.Context,
|
||||
oidcTokenExchangeRequest *oidc.TokenExchangeRequest,
|
||||
client Client,
|
||||
exchanger Exchanger,
|
||||
) (TokenExchangeRequest, error) {
|
||||
teStorage, ok := exchanger.Storage().(TokenExchangeStorage)
|
||||
if !ok {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
|
||||
}
|
||||
|
||||
exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger,
|
||||
oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false)
|
||||
if !ok {
|
||||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid")
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid")
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -234,7 +246,7 @@ func ValidateTokenExchangeRequest(
|
|||
exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger,
|
||||
oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true)
|
||||
if !ok {
|
||||
return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid")
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -258,17 +270,17 @@ func ValidateTokenExchangeRequest(
|
|||
authTime: time.Now(),
|
||||
}
|
||||
|
||||
err = teStorage.ValidateTokenExchangeRequest(ctx, req)
|
||||
err := teStorage.ValidateTokenExchangeRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = teStorage.CreateTokenExchangeRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return req, client, nil
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func GetTokenIDAndSubjectFromToken(
|
||||
|
|
|
@ -117,11 +117,11 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string,
|
|||
|
||||
// AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent
|
||||
// code_challenge of the auth request (PKCE)
|
||||
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error {
|
||||
if tokenReq.CodeVerifier == "" {
|
||||
func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error {
|
||||
if codeVerifier == "" {
|
||||
return oidc.ErrInvalidRequest().WithDescription("code_challenge required")
|
||||
}
|
||||
if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) {
|
||||
if !oidc.VerifyCodeChallenge(challenge, codeVerifier) {
|
||||
return oidc.ErrInvalidGrant().WithDescription("invalid code challenge")
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -131,6 +131,11 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token
|
|||
}
|
||||
|
||||
func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
statusErr := RevocationError(err)
|
||||
httphelper.MarshalJSONWithStatus(w, statusErr.parent, statusErr.statusCode)
|
||||
}
|
||||
|
||||
func RevocationError(err error) StatusError {
|
||||
e := oidc.DefaultToServerError(err, err.Error())
|
||||
status := http.StatusBadRequest
|
||||
switch e.ErrorType {
|
||||
|
@ -139,7 +144,7 @@ func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) {
|
|||
case oidc.ServerError:
|
||||
status = 500
|
||||
}
|
||||
httphelper.MarshalJSONWithStatus(w, e, status)
|
||||
return NewStatusError(e, status)
|
||||
}
|
||||
|
||||
func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue