feat(op): Server interface (#447)

* first draft of a new server interface

* allow any response type

* complete interface docs

* refelct the format from the proposal

* intermediate commit with some methods implemented

* implement remaining token grant type methods

* implement remaining server methods

* error handling

* rewrite auth request validation

* define handlers, routes

* input validation and concrete handlers

* check if client credential client is authenticated

* copy and modify the routes test for the legacy server

* run integration tests against both Server and Provider

* remove unuse ValidateAuthRequestV2 function

* unit tests for error handling

* cleanup tokenHandler

* move server routest test

* unit test authorize

* handle client credentials in VerifyClient

* change code exchange route test

* finish http unit tests

* review server interface docs and spelling

* add withClient unit test

* server options

* cleanup unused GrantType method

* resolve typo comments

* make endpoints pointers to enable/disable them

* jwt profile base work

* jwt: correct the test expect

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Tim Möhlmann 2023-09-28 17:30:08 +03:00 committed by GitHub
parent daf82a5e04
commit 0f8a0585bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 3654 additions and 126 deletions

View file

@ -40,7 +40,7 @@ var counter atomic.Int64
// SetupServer creates an OIDC server with Issuer=http://localhost:<port>
//
// Use one of the pre-made clients in storage/clients.go or register a new one.
func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router {
func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool) chi.Router {
// the OpenID Provider requires a 32-byte key for (token) encryption
// be sure to create a proper crypto random key and manage it securely!
key := sha256.Sum256([]byte("test"))
@ -77,12 +77,17 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger) chi.Router
registerDeviceAuth(storage, r)
})
handler := http.Handler(provider)
if wrapServer {
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
}
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
// is served on the correct path
//
// if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
// then you would have to set the path prefix (/custom/path/)
router.Mount("/", provider)
router.Mount("/", handler)
return router
}

View file

@ -27,7 +27,7 @@ func main() {
Level: slog.LevelDebug,
}),
)
router := exampleop.SetupServer(issuer, storage, logger)
router := exampleop.SetupServer(issuer, storage, logger, false)
server := &http.Server{
Addr: ":" + port,

View file

@ -185,7 +185,7 @@ func WebClient(id, secret string, redirectURIs ...string) *Client {
authMethod: oidc.AuthMethodBasic,
loginURL: defaultLoginURL,
responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken},
grantTypes: oidc.AllGrantTypes,
accessTokenType: op.AccessTokenTypeBearer,
devMode: false,
idTokenUserinfoClaimsAssertion: false,

View file

@ -3,6 +3,7 @@ package client_test
import (
"bytes"
"context"
"fmt"
"io"
"math/rand"
"net/http"
@ -50,6 +51,14 @@ func TestMain(m *testing.M) {
}
func TestRelyingPartySession(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testRelyingPartySession(t, wrapServer)
})
}
}
func testRelyingPartySession(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -57,7 +66,7 @@ func TestRelyingPartySession(t *testing.T) {
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
@ -101,6 +110,14 @@ func TestRelyingPartySession(t *testing.T) {
}
func TestResourceServerTokenExchange(t *testing.T) {
for _, wrapServer := range []bool{false, true} {
t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
testResourceServerTokenExchange(t, wrapServer)
})
}
}
func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
t.Log("------- start example OP ------")
targetURL := "http://local-site"
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
@ -108,7 +125,7 @@ func TestResourceServerTokenExchange(t *testing.T) {
opServer := httptest.NewServer(&dh)
defer opServer.Close()
t.Logf("auth server at %s", opServer.URL)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger)
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)

View file

@ -74,7 +74,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
}
ctx := r.Context()
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer)
return
@ -130,31 +130,31 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A
// ParseRequestObject parse the `request` parameter, validates the token including the signature
// and copies the token claims into the auth request
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) {
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error {
requestObject := new(oidc.RequestObject)
payload, err := oidc.ParseToken(authReq.RequestParam, requestObject)
if err != nil {
return nil, err
return err
}
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
if requestObject.Issuer != requestObject.ClientID {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
if !str.Contains(requestObject.Audience, issuer) {
return authReq, oidc.ErrInvalidRequest()
return oidc.ErrInvalidRequest()
}
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
return authReq, err
return err
}
CopyRequestObjectToAuthRequest(authReq, requestObject)
return authReq, nil
return nil
}
// CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request

View file

@ -180,3 +180,10 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
}
return data.ClientID, false, nil
}
type ClientCredentials struct {
ClientID string `schema:"client_id"`
ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body
ClientAssertion string `schema:"client_assertion"` // JWT
ClientAssertionType string `schema:"client_assertion_type"`
}

View file

@ -20,14 +20,14 @@ var (
type Configuration interface {
IssuerFromRequest(r *http.Request) string
Insecure() bool
AuthorizationEndpoint() Endpoint
TokenEndpoint() Endpoint
IntrospectionEndpoint() Endpoint
UserinfoEndpoint() Endpoint
RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint
DeviceAuthorizationEndpoint() Endpoint
AuthorizationEndpoint() *Endpoint
TokenEndpoint() *Endpoint
IntrospectionEndpoint() *Endpoint
UserinfoEndpoint() *Endpoint
RevocationEndpoint() *Endpoint
EndSessionEndpoint() *Endpoint
KeysEndpoint() *Endpoint
DeviceAuthorizationEndpoint() *Endpoint
AuthMethodPostSupported() bool
CodeMethodS256Supported() bool

View file

@ -63,41 +63,51 @@ func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *htt
}
func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
storage, err := assertDeviceStorage(o.Storage())
if err != nil {
return err
}
req, err := ParseDeviceCodeRequest(r, o)
if err != nil {
return err
}
response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o)
if err != nil {
return err
}
httphelper.MarshalJSON(w, response)
return nil
}
func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) {
storage, err := assertDeviceStorage(o.Storage())
if err != nil {
return nil, err
}
config := o.DeviceAuthorization()
deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes)
if err != nil {
return err
return nil, NewStatusError(err, http.StatusInternalServerError)
}
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
if err != nil {
return err
return nil, NewStatusError(err, http.StatusInternalServerError)
}
expires := time.Now().Add(config.Lifetime)
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes)
if err != nil {
return err
return nil, NewStatusError(err, http.StatusInternalServerError)
}
var verification *url.URL
if config.UserFormURL != "" {
if verification, err = url.Parse(config.UserFormURL); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
return nil, NewStatusError(err, http.StatusInternalServerError)
}
} else {
if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil {
return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil {
err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
return nil, NewStatusError(err, http.StatusInternalServerError)
}
verification.Path = config.UserFormPath
}
@ -112,9 +122,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
verification.RawQuery = "user_code=" + userCode
response.VerificationURIComplete = verification.String()
httphelper.MarshalJSON(w, response)
return nil
return response, nil
}
func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {

View file

@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{
func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Discover(w, CreateDiscoveryConfig(r, c, s))
Discover(w, CreateDiscoveryConfig(r.Context(), c, s))
}
}
@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
httphelper.MarshalJSON(w, config)
}
func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
issuer := config.IssuerFromRequest(r)
func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
issuer := IssuerFromContext(ctx)
return &oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer),
@ -49,7 +49,38 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage),
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
ClaimsSupported: SupportedClaims(config),
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
UILocalesSupported: config.SupportedUILocales(),
RequestParameterSupported: config.RequestObjectSupported(),
}
}
func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
issuer := IssuerFromContext(ctx)
return &oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
TokenEndpoint: endpoints.Token.Absolute(issuer),
IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
JwksURI: endpoints.JwksURI.Absolute(issuer),
DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),

View file

@ -48,9 +48,9 @@ func TestDiscover(t *testing.T) {
func TestCreateDiscoveryConfig(t *testing.T) {
type args struct {
request *http.Request
c op.Configuration
s op.DiscoverStorage
ctx context.Context
c op.Configuration
s op.DiscoverStorage
}
tests := []struct {
name string
@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s)
got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s)
assert.Equal(t, tt.want, got)
})
}

View file

@ -1,32 +1,46 @@
package op
import "strings"
import (
"errors"
"strings"
)
type Endpoint struct {
path string
url string
}
func NewEndpoint(path string) Endpoint {
return Endpoint{path: path}
func NewEndpoint(path string) *Endpoint {
return &Endpoint{path: path}
}
func NewEndpointWithURL(path, url string) Endpoint {
return Endpoint{path: path, url: url}
func NewEndpointWithURL(path, url string) *Endpoint {
return &Endpoint{path: path, url: url}
}
func (e Endpoint) Relative() string {
func (e *Endpoint) Relative() string {
if e == nil {
return ""
}
return relativeEndpoint(e.path)
}
func (e Endpoint) Absolute(host string) string {
func (e *Endpoint) Absolute(host string) string {
if e == nil {
return ""
}
if e.url != "" {
return e.url
}
return absoluteEndpoint(host, e.path)
}
func (e Endpoint) Validate() error {
var ErrNilEndpoint = errors.New("nil endpoint")
func (e *Endpoint) Validate() error {
if e == nil {
return ErrNilEndpoint
}
return nil // TODO:
}

View file

@ -3,13 +3,14 @@ package op_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/op"
)
func TestEndpoint_Path(t *testing.T) {
tests := []struct {
name string
e op.Endpoint
e *op.Endpoint
want string
}{
{
@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
op.NewEndpointWithURL("/test", "http://test.com/test"),
"/test",
},
{
"nil",
nil,
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
}
tests := []struct {
name string
e op.Endpoint
e *op.Endpoint
args args
want string
}{
@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
args{"https://host"},
"https://test.com/test",
},
{
"nil",
nil,
args{"https://host"},
"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) {
func TestEndpoint_Validate(t *testing.T) {
tests := []struct {
name string
e op.Endpoint
wantErr bool
e *op.Endpoint
wantErr error
}{
// TODO: Add test cases.
{
"nil",
nil,
op.ErrNilEndpoint,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
err := tt.e.Validate()
require.ErrorIs(t, err, tt.wantErr)
})
}
}

View file

@ -1,6 +1,9 @@
package op
import (
"context"
"errors"
"fmt"
"net/http"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
@ -66,3 +69,101 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slo
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
httphelper.MarshalJSONWithStatus(w, e, status)
}
// TryErrorRedirect tries to handle an error by redirecting a client.
// If this attempt fails, an error is returned that must be returned
// to the client instead.
func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
e := oidc.DefaultToServerError(parent, parent.Error())
logger = logger.With("oidc_error", e)
if authReq == nil {
logger.Log(ctx, e.LogLevel(), "auth request")
return nil, AsStatusError(e, http.StatusBadRequest)
}
if logAuthReq, ok := authReq.(LogAuthRequest); ok {
logger = logger.With("auth_request", logAuthReq)
}
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
logger.Log(ctx, e.LogLevel(), "auth request: not redirecting")
return nil, AsStatusError(e, http.StatusBadRequest)
}
e.State = authReq.GetState()
var responseMode oidc.ResponseMode
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode()
}
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
if err != nil {
logger.ErrorContext(ctx, "auth response URL", "error", err)
return nil, AsStatusError(err, http.StatusBadRequest)
}
logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url)
return NewRedirect(url), nil
}
// StatusError wraps an error with a HTTP status code.
// The status code is passed to the handler's writer.
type StatusError struct {
parent error
statusCode int
}
// NewStatusError sets the parent and statusCode to a new StatusError.
// It is recommended for parent to be an [oidc.Error].
//
// Typically implementations should only use this to signal something
// very specific, like an internal server error.
// If a returned error is not a StatusError, the framework
// will set a statusCode based on what the standard specifies,
// which is [http.StatusBadRequest] for most of the time.
// If the error encountered can described clearly with a [oidc.Error],
// do not use this function, as it might break standard rules!
func NewStatusError(parent error, statusCode int) StatusError {
return StatusError{
parent: parent,
statusCode: statusCode,
}
}
// AsStatusError unwraps a StatusError from err
// and returns it unmodified if found.
// If no StatuError was found, a new one is returned
// with statusCode set to it as a default.
func AsStatusError(err error, statusCode int) (target StatusError) {
if errors.As(err, &target) {
return target
}
return NewStatusError(err, statusCode)
}
func (e StatusError) Error() string {
return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error())
}
func (e StatusError) Unwrap() error {
return e.parent
}
func (e StatusError) Is(err error) bool {
var target StatusError
if !errors.As(err, &target) {
return false
}
return errors.Is(e.parent, target.parent) &&
e.statusCode == target.statusCode
}
// WriteError asserts for a StatusError containing an [oidc.Error].
// If no StatusError is found, the status code will default to [http.StatusBadRequest].
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
statusError := AsStatusError(err, http.StatusBadRequest)
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode)
}

View file

@ -1,9 +1,12 @@
package op
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
@ -275,3 +278,400 @@ func TestRequestError(t *testing.T) {
})
}
}
func TestTryErrorRedirect(t *testing.T) {
type args struct {
ctx context.Context
authReq ErrAuthRequest
parent error
}
tests := []struct {
name string
args args
want *Redirect
wantErr error
wantLog string
}{
{
name: "nil auth request",
args: args{
ctx: context.Background(),
authReq: nil,
parent: io.ErrClosedPipe,
},
wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest),
wantLog: `{
"level":"ERROR",
"msg":"auth request",
"time":"not",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
}
}`,
},
{
name: "auth request, no redirect URI",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest),
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request, redirect disabled",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
},
wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest),
wantLog: `{
"level":"WARN",
"msg":"auth request: not redirecting",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"oops",
"type":"invalid_request",
"redirect_disabled":true
}
}`,
},
{
name: "auth request, url parse error",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "can't parse this!\n",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
wantErr: func() error {
//lint:ignore SA1007 just recreating the error for testing
_, err := url.Parse("can't parse this!\n")
err = oidc.ErrServerError().WithParent(err)
return NewStatusError(err, http.StatusBadRequest)
}(),
wantLog: `{
"level":"ERROR",
"msg":"auth response URL",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"can't parse this!\n",
"response_type":"responseType",
"scopes":"a b"
},
"error":{
"type":"server_error",
"parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
}
}`,
},
{
name: "auth request redirect",
args: args{
ctx: context.Background(),
authReq: &oidc.AuthRequest{
Scopes: oidc.SpaceDelimitedArray{"a", "b"},
ResponseType: "responseType",
ClientID: "123",
RedirectURI: "http://example.com/callback",
State: "state1",
ResponseMode: oidc.ResponseModeQuery,
},
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
want: &Redirect{
URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1",
},
wantLog: `{
"level":"WARN",
"msg":"auth request redirect",
"time":"not",
"auth_request":{
"client_id":"123",
"redirect_uri":"http://example.com/callback",
"response_type":"responseType",
"scopes":"a b"
},
"oidc_error":{
"description":"sign in",
"type":"interaction_required"
},
"url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
logger := slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
)
encoder := schema.NewEncoder()
got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
gotLog := logOut.String()
t.Log(gotLog)
assert.JSONEq(t, tt.wantLog, gotLog, "log output")
})
}
}
func TestNewStatusError(t *testing.T) {
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
want := "Internal Server Error: io: read/write on closed pipe"
got := fmt.Sprint(err)
assert.Equal(t, want, got)
}
func TestAsStatusError(t *testing.T) {
type args struct {
err error
statusCode int
}
tests := []struct {
name string
args args
want string
}{
{
name: "already status error",
args: args{
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
statusCode: http.StatusBadRequest,
},
want: "Internal Server Error: io: read/write on closed pipe",
},
{
name: "oidc error",
args: args{
err: oidc.ErrAcrInvalid,
statusCode: http.StatusBadRequest,
},
want: "Bad Request: acr is invalid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := AsStatusError(tt.args.err, tt.args.statusCode)
got := fmt.Sprint(err)
assert.Equal(t, tt.want, got)
})
}
}
func TestStatusError_Unwrap(t *testing.T) {
err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
require.ErrorIs(t, err, io.ErrClosedPipe)
}
func TestStatusError_Is(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
want bool
}{
{
name: "nil error",
args: args{err: nil},
want: false,
},
{
name: "other error",
args: args{err: io.EOF},
want: false,
},
{
name: "other parent",
args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)},
want: false,
},
{
name: "other status",
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)},
want: false,
},
{
name: "same",
args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)},
want: true,
},
{
name: "wrapped",
args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
if got := e.Is(tt.args.err); got != tt.want {
t.Errorf("StatusError.Is() = %v, want %v", got, tt.want)
}
})
}
}
func TestWriteError(t *testing.T) {
tests := []struct {
name string
err error
wantStatus int
wantBody string
wantLog string
}{
{
name: "not a status or oidc error",
err: io.ErrClosedPipe,
wantStatus: http.StatusBadRequest,
wantBody: `{
"error":"server_error",
"error_description":"io: read/write on closed pipe"
}`,
wantLog: `{
"level":"ERROR",
"msg":"request error",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
},
"time":"not"
}`,
},
{
name: "status error w/o oidc",
err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
wantStatus: http.StatusInternalServerError,
wantBody: `{
"error":"server_error",
"error_description":"io: read/write on closed pipe"
}`,
wantLog: `{
"level":"ERROR",
"msg":"request error",
"oidc_error":{
"description":"io: read/write on closed pipe",
"parent":"io: read/write on closed pipe",
"type":"server_error"
},
"time":"not"
}`,
},
{
name: "oidc error w/o status",
err: oidc.ErrInvalidRequest().WithDescription("oops"),
wantStatus: http.StatusBadRequest,
wantBody: `{
"error":"invalid_request",
"error_description":"oops"
}`,
wantLog: `{
"level":"WARN",
"msg":"request error",
"oidc_error":{
"description":"oops",
"type":"invalid_request"
},
"time":"not"
}`,
},
{
name: "status with oidc error",
err: NewStatusError(
oidc.ErrUnauthorizedClient().WithDescription("oops"),
http.StatusUnauthorized,
),
wantStatus: http.StatusUnauthorized,
wantBody: `{
"error":"unauthorized_client",
"error_description":"oops"
}`,
wantLog: `{
"level":"WARN",
"msg":"request error",
"oidc_error":{
"description":"oops",
"type":"unauthorized_client"
},
"time":"not"
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logOut := new(strings.Builder)
logger := slog.New(
slog.NewJSONHandler(logOut, &slog.HandlerOptions{
Level: slog.LevelInfo,
}).WithAttrs([]slog.Attr{slog.String("time", "not")}),
)
r := httptest.NewRequest("GET", "/target", nil)
w := httptest.NewRecorder()
WriteError(w, r, tt.err, logger)
res := w.Result()
assert.Equal(t, tt.wantStatus, res.StatusCode, "status code")
gotBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.JSONEq(t, tt.wantBody, string(gotBody), "body")
assert.JSONEq(t, tt.wantLog, logOut.String())
})
}
}

View file

@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom
}
// AuthorizationEndpoint mocks base method.
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -107,10 +107,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
}
// DeviceAuthorizationEndpoint mocks base method.
func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint {
func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -121,10 +121,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C
}
// EndSessionEndpoint mocks base method.
func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EndSessionEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -233,10 +233,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup
}
// IntrospectionEndpoint mocks base method.
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint {
func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntrospectionEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -275,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go
}
// KeysEndpoint mocks base method.
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeysEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -331,10 +331,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor
}
// RevocationEndpoint mocks base method.
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint {
func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -373,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
}
// TokenEndpoint mocks base method.
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TokenEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
@ -401,10 +401,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported
}
// UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserinfoEndpoint")
ret0, _ := ret[0].(op.Endpoint)
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}

View file

@ -32,7 +32,7 @@ const (
)
var (
DefaultEndpoints = &endpoints{
DefaultEndpoints = &Endpoints{
Authorization: NewEndpoint(defaultAuthorizationEndpoint),
Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint),
@ -131,16 +131,17 @@ type Config struct {
DeviceAuthorization DeviceAuthorizationConfig
}
type endpoints struct {
Authorization Endpoint
Token Endpoint
Introspection Endpoint
Userinfo Endpoint
Revocation Endpoint
EndSession Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
DeviceAuthorization Endpoint
// Endpoints defines endpoint routes.
type Endpoints struct {
Authorization *Endpoint
Token *Endpoint
Introspection *Endpoint
Userinfo *Endpoint
Revocation *Endpoint
EndSession *Endpoint
CheckSessionIframe *Endpoint
JwksURI *Endpoint
DeviceAuthorization *Endpoint
}
// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
@ -212,7 +213,7 @@ type Provider struct {
config *Config
issuer IssuerFromRequest
insecure bool
endpoints *endpoints
endpoints *Endpoints
storage Storage
keySet *openIDKeySet
crypto Crypto
@ -233,35 +234,35 @@ func (o *Provider) Insecure() bool {
return o.insecure
}
func (o *Provider) AuthorizationEndpoint() Endpoint {
func (o *Provider) AuthorizationEndpoint() *Endpoint {
return o.endpoints.Authorization
}
func (o *Provider) TokenEndpoint() Endpoint {
func (o *Provider) TokenEndpoint() *Endpoint {
return o.endpoints.Token
}
func (o *Provider) IntrospectionEndpoint() Endpoint {
func (o *Provider) IntrospectionEndpoint() *Endpoint {
return o.endpoints.Introspection
}
func (o *Provider) UserinfoEndpoint() Endpoint {
func (o *Provider) UserinfoEndpoint() *Endpoint {
return o.endpoints.Userinfo
}
func (o *Provider) RevocationEndpoint() Endpoint {
func (o *Provider) RevocationEndpoint() *Endpoint {
return o.endpoints.Revocation
}
func (o *Provider) EndSessionEndpoint() Endpoint {
func (o *Provider) EndSessionEndpoint() *Endpoint {
return o.endpoints.EndSession
}
func (o *Provider) DeviceAuthorizationEndpoint() Endpoint {
func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
return o.endpoints.DeviceAuthorization
}
func (o *Provider) KeysEndpoint() Endpoint {
func (o *Provider) KeysEndpoint() *Endpoint {
return o.endpoints.JwksURI
}
@ -420,7 +421,7 @@ func WithAllowInsecure() Option {
}
}
func WithCustomAuthEndpoint(endpoint Endpoint) Option {
func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -430,7 +431,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomTokenEndpoint(endpoint Endpoint) Option {
func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -440,7 +441,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -450,7 +451,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -460,7 +461,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -470,7 +471,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -480,7 +481,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomKeysEndpoint(endpoint Endpoint) Option {
func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -490,7 +491,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
return func(o *Provider) error {
if err := endpoint.Validate(); err != nil {
return err
@ -500,8 +501,16 @@ func WithCustomDeviceAuthorizationEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
// WithCustomEndpoints sets multiple endpoints at once.
// Non of the endpoints may be nil, or an error will
// be returned when the Option used by the Provider.
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
return func(o *Provider) error {
for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
if err := e.Validate(); err != nil {
return err
}
}
o.endpoints.Authorization = auth
o.endpoints.Token = token
o.endpoints.Userinfo = userInfo

View file

@ -395,3 +395,54 @@ func TestRoutes(t *testing.T) {
})
}
}
func TestWithCustomEndpoints(t *testing.T) {
type args struct {
auth *op.Endpoint
token *op.Endpoint
userInfo *op.Endpoint
revocation *op.Endpoint
endSession *op.Endpoint
keys *op.Endpoint
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "all nil",
args: args{},
wantErr: op.ErrNilEndpoint,
},
{
name: "all set",
args: args{
auth: op.NewEndpoint("/authorize"),
token: op.NewEndpoint("/oauth/token"),
userInfo: op.NewEndpoint("/userinfo"),
revocation: op.NewEndpoint("/revoke"),
endSession: op.NewEndpoint("/end_session"),
keys: op.NewEndpoint("/keys"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
storage.NewStorage(storage.NewUserStore(testIssuer)),
op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
)
require.ErrorIs(t, err, tt.wantErr)
if tt.wantErr != nil {
return
}
assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
assert.Equal(t, tt.args.token, provider.TokenEndpoint())
assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
})
}
}

View file

@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn {
}
func ok(w http.ResponseWriter) {
httphelper.MarshalJSON(w, status{"ok"})
httphelper.MarshalJSON(w, Status{"ok"})
}
type status struct {
type Status struct {
Status string `json:"status,omitempty"`
}

346
pkg/op/server.go Normal file
View file

@ -0,0 +1,346 @@
package op
import (
"context"
"net/http"
"net/url"
"github.com/muhlemmer/gu"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// Server describes the interface that needs to be implemented to serve
// OpenID Connect and Oauth2 standard requests.
//
// Methods are called after the HTTP route is resolved and
// the request body is parsed into the Request's Data field.
// When a method is called, it can be assumed that required fields,
// as described in their relevant standard, are validated already.
// The Response Data field may be of any type to allow flexibility
// to extend responses with custom fields. There are however requirements
// in the standards regarding the response models. Where applicable
// the method documentation gives a recommended type which can be used
// directly or extended upon.
//
// The addition of new methods is not considered a breaking change
// as defined by semver rules.
// Implementations MUST embed [UnimplementedServer] to maintain
// forward compatibility.
//
// EXPERIMENTAL: may change until v4
type Server interface {
// Health returns a status of "ok" once the Server is listening.
// The recommended Response Data type is [Status].
Health(context.Context, *Request[struct{}]) (*Response, error)
// Ready returns a status of "ok" once all dependencies,
// such as database storage, are ready.
// An error can be returned to explain what is not ready.
// The recommended Response Data type is [Status].
Ready(context.Context, *Request[struct{}]) (*Response, error)
// Discovery returns the OpenID Provider Configuration Information for this server.
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
// The recommended Response Data type is [oidc.DiscoveryConfiguration].
Discovery(context.Context, *Request[struct{}]) (*Response, error)
// Keys serves the JWK set which the client can use verify signatures from the op.
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key.
// The recommended Response Data type is [jose.JSONWebKeySet].
Keys(context.Context, *Request[struct{}]) (*Response, error)
// VerifyAuthRequest verifies the Auth Request and
// adds the Client to the request.
//
// When the `request` field is populated with a
// "Request Object" JWT, it needs to be Validated
// and its claims overwrite any fields in the AuthRequest.
// If the implementation does not support "Request Object",
// it MUST return an [oidc.ErrRequestNotSupported].
// https://openid.net/specs/openid-connect-core-1_0.html#RequestObject
VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error)
// Authorize initiates the authorization flow and redirects to a login page.
// See the various https://openid.net/specs/openid-connect-core-1_0.html
// authorize endpoint sections (one for each type of flow).
Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error)
// DeviceAuthorization initiates the device authorization flow.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
// The recommended Response Data type is [oidc.DeviceAuthorizationResponse].
DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error)
// VerifyClient is called on most oauth/token handlers to authenticate,
// using either a secret (POST, Basic) or assertion (JWT).
// If no secrets are provided, the client must be public.
// This method is called before each method that takes a
// [ClientRequest] argument.
VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error)
// CodeExchange returns Tokens after an authorization code
// is obtained in a successful Authorize flow.
// It is called by the Token endpoint handler when
// grant_type has the value authorization_code
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
// The recommended Response Data type is [oidc.AccessTokenResponse].
CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error)
// RefreshToken returns new Tokens after verifying a Refresh token.
// It is called by the Token endpoint handler when
// grant_type has the value refresh_token
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
// The recommended Response Data type is [oidc.AccessTokenResponse].
RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error)
// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant
// It is called by the Token endpoint handler when
// grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer
// https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
// The recommended Response Data type is [oidc.AccessTokenResponse].
JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error)
// TokenExchange handles the OAuth 2.0 token exchange grant
// It is called by the Token endpoint handler when
// grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange
// https://datatracker.ietf.org/doc/html/rfc8693
// The recommended Response Data type is [oidc.AccessTokenResponse].
TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error)
// ClientCredentialsExchange handles the OAuth 2.0 client credentials grant
// It is called by the Token endpoint handler when
// grant_type has the value client_credentials
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
// The recommended Response Data type is [oidc.AccessTokenResponse].
ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error)
// DeviceToken handles the OAuth 2.0 Device Authorization Grant
// It is called by the Token endpoint handler when
// grant_type has the value urn:ietf:params:oauth:grant-type:device_code.
// It is typically called in a polling fashion and appropriate errors
// should be returned to signal authorization_pending or access_denied etc.
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4,
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5.
// The recommended Response Data type is [oidc.AccessTokenResponse].
DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error)
// Introspect handles the OAuth 2.0 Token Introspection endpoint.
// https://datatracker.ietf.org/doc/html/rfc7662
// The recommended Response Data type is [oidc.IntrospectionResponse].
Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error)
// UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User.
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
// The recommended Response Data type is [oidc.UserInfo].
UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error)
// Revocation handles token revocation using an access or refresh token.
// https://datatracker.ietf.org/doc/html/rfc7009
// There are no response requirements. Data may remain empty.
Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error)
// EndSession handles the OpenID Connect RP-Initiated Logout.
// https://openid.net/specs/openid-connect-rpinitiated-1_0.html
// There are no response requirements. Data may remain empty.
EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error)
// mustImpl forces implementations to embed the UnimplementedServer for forward
// compatibility with the interface.
mustImpl()
}
// Request contains the [http.Request] informational fields
// and parsed Data from the request body (POST) or URL parameters (GET).
// Data can be assumed to be validated according to the applicable
// standard for the specific endpoints.
//
// EXPERIMENTAL: may change until v4
type Request[T any] struct {
Method string
URL *url.URL
Header http.Header
Form url.Values
PostForm url.Values
Data *T
}
func (r *Request[_]) path() string {
return r.URL.Path
}
func newRequest[T any](r *http.Request, data *T) *Request[T] {
return &Request[T]{
Method: r.Method,
URL: r.URL,
Header: r.Header,
Form: r.Form,
PostForm: r.PostForm,
Data: data,
}
}
// ClientRequest is a Request with a verified client attached to it.
// Methods that receive this argument may assume the client was authenticated,
// or verified to be a public client.
//
// EXPERIMENTAL: may change until v4
type ClientRequest[T any] struct {
*Request[T]
Client Client
}
func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] {
return &ClientRequest[T]{
Request: newRequest[T](r, data),
Client: client,
}
}
// Response object for most [Server] methods.
//
// EXPERIMENTAL: may change until v4
type Response struct {
// Header map will be merged with the
// header on the [http.ResponseWriter].
Header http.Header
// Data will be JSON marshaled to
// the response body.
// We allow any type, so that implementations
// can extend the standard types as they wish.
// However, each method will recommend which
// (base) type to use as model, in order to
// be compliant with the standards.
Data any
}
// NewResponse creates a new response for data,
// without custom headers.
func NewResponse(data any) *Response {
return &Response{
Data: data,
}
}
func (resp *Response) writeOut(w http.ResponseWriter) {
gu.MapMerge(resp.Header, w.Header())
httphelper.MarshalJSON(w, resp.Data)
}
// Redirect is a special response type which will
// initiate a [http.StatusFound] redirect.
// The Params field will be encoded and set to the
// URL's RawQuery field before building the URL.
//
// EXPERIMENTAL: may change until v4
type Redirect struct {
// Header map will be merged with the
// header on the [http.ResponseWriter].
Header http.Header
URL string
}
func NewRedirect(url string) *Redirect {
return &Redirect{URL: url}
}
func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) {
gu.MapMerge(r.Header, w.Header())
http.Redirect(w, r, red.URL, http.StatusFound)
}
type UnimplementedServer struct{}
// UnimplementedStatusCode is the status code returned for methods
// that are not yet implemented.
// Note that this means methods in the sense of the Go interface,
// and not http methods covered by "501 Not Implemented".
var UnimplementedStatusCode = http.StatusNotFound
func unimplementedError(r interface{ path() string }) StatusError {
err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path())
return NewStatusError(err, UnimplementedStatusCode)
}
func unimplementedGrantError(gt oidc.GrantType) StatusError {
err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt)
return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
}
func (UnimplementedServer) mustImpl() {}
func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
if r.Data.RequestParam != "" {
return nil, oidc.ErrRequestNotSupported()
}
return nil, unimplementedError(r)
}
func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeCode)
}
func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
}
func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
}
func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
}
func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
}
func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
}
func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
return nil, unimplementedError(r)
}
func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
return nil, unimplementedError(r)
}

480
pkg/op/server_http.go Normal file
View file

@ -0,0 +1,480 @@
package op
import (
"context"
"net/http"
"net/url"
"github.com/go-chi/chi"
"github.com/rs/cors"
"github.com/zitadel/logging"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/schema"
"golang.org/x/exp/slog"
)
// RegisterServer registers an implementation of Server.
// The resulting handler takes care of routing and request parsing,
// with some basic validation of required fields.
// The routes can be customized with [WithEndpoints].
//
// EXPERIMENTAL: may change until v4
func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(true)
ws := &webServer{
server: server,
endpoints: endpoints,
decoder: decoder,
logger: slog.Default(),
}
for _, option := range options {
option(ws)
}
ws.createRouter()
return ws
}
type ServerOption func(s *webServer)
// WithHTTPMiddleware sets the passed middleware chain to the root of
// the Server's router.
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
return func(s *webServer) {
s.middleware = m
}
}
// WithDecoder overrides the default decoder,
// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
func WithDecoder(decoder httphelper.Decoder) ServerOption {
return func(s *webServer) {
s.decoder = decoder
}
}
// WithFallbackLogger overrides the fallback logger, which
// is used when no logger was found in the context.
// Defaults to [slog.Default].
func WithFallbackLogger(logger *slog.Logger) ServerOption {
return func(s *webServer) {
s.logger = logger
}
}
type webServer struct {
http.Handler
server Server
middleware []func(http.Handler) http.Handler
endpoints Endpoints
decoder httphelper.Decoder
logger *slog.Logger
}
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
if logger, ok := logging.FromContext(ctx); ok {
return logger
}
return s.logger
}
func (s *webServer) createRouter() {
router := chi.NewRouter()
router.Use(cors.New(defaultCORSOptions).Handler)
router.Use(s.middleware...)
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
s.Handler = router
}
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
if e != nil {
router.HandleFunc(e.Relative(), hf)
s.logger.Info("registered route", "endpoint", e.Relative())
}
}
type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
client, err := s.verifyRequestClient(r)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
if !ValidateGrantType(client, grantType) {
WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
return
}
}
handler(w, r, client)
}
}
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
if err = r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
cc := new(ClientCredentials)
if err = s.decoder.Decode(cc, r.Form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
// Basic auth takes precedence, so if set it overwrites the form data.
if clientID, clientSecret, ok := r.BasicAuth(); ok {
cc.ClientID, err = url.QueryUnescape(clientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
cc.ClientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
}
if cc.ClientID == "" && cc.ClientAssertion == "" {
return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
}
if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion {
return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType)
}
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
Method: r.Method,
URL: r.URL,
Header: r.Header,
Form: r.Form,
Data: cc,
})
}
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
redirect, err := s.authorize(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
redirect.writeOut(w, r)
}
func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
cr, err := s.server.VerifyAuthRequest(ctx, r)
if err != nil {
return nil, err
}
authReq := cr.Data
if authReq.RedirectURI == "" {
return nil, ErrAuthReqMissingRedirectURI
}
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil {
return nil, err
}
authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes)
if err != nil {
return nil, err
}
if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil {
return nil, err
}
if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil {
return nil, err
}
return s.server.Authorize(ctx, cr)
}
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
return
}
switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType {
case oidc.GrantTypeCode:
s.withClient(s.codeExchangeHandler)(w, r)
case oidc.GrantTypeRefreshToken:
s.withClient(s.refreshTokenHandler)(w, r)
case oidc.GrantTypeClientCredentials:
s.withClient(s.clientCredentialsHandler)(w, r)
case oidc.GrantTypeBearer:
s.jwtProfileHandler(w, r)
case oidc.GrantTypeTokenExchange:
s.withClient(s.tokenExchangeHandler)(w, r)
case oidc.GrantTypeDeviceCode:
s.withClient(s.deviceTokenHandler)(w, r)
case "":
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
default:
WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
}
}
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Assertion == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Code == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
return
}
if request.RedirectURI == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.RefreshToken == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.SubjectToken == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
return
}
if request.SubjectTokenType == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
return
}
if !request.SubjectTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
return
}
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
return
}
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
return
}
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone {
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
return
}
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.DeviceCode == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
if client.AuthMethod() == oidc.AuthMethodNone {
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
return
}
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Token == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if token, err := getAccessToken(r); err == nil {
request.AccessToken = token
}
if request.AccessToken == "" {
err = NewStatusError(
oidc.ErrInvalidRequest().WithDescription("access token missing"),
http.StatusUnauthorized,
)
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
if request.Token == "" {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
return
}
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w, r)
}
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
return
}
resp, err := method(r.Context(), newRequest(r, &struct{}{}))
if err != nil {
WriteError(w, r, err, s.getLogger(r.Context()))
return
}
resp.writeOut(w)
}
}
func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
dst := new(R)
if err := r.ParseForm(); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
form := r.Form
if postOnly {
form = r.PostForm
}
if err := decoder.Decode(dst, form); err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
return dst, nil
}

View file

@ -0,0 +1,345 @@
package op_test
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
)
func jwtProfile() (string, error) {
keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json")
if err != nil {
return "", err
}
signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID)
if err != nil {
return "", err
}
return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer)
}
func TestServerRoutes(t *testing.T) {
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
storage := testProvider.Storage().(routesTestStorage)
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
client, err := storage.GetClientByClientID(ctx, "web")
require.NoError(t, err)
oidcAuthReq := &oidc.AuthRequest{
ClientID: client.GetID(),
RedirectURI: "https://example.com",
MaxAge: gu.Ptr[uint](300),
Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
ResponseType: oidc.ResponseTypeCode,
}
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
require.NoError(t, err)
storage.AuthRequestDone(authReq.GetID())
accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
require.NoError(t, err)
jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
require.NoError(t, err)
jwtProfileToken, err := jwtProfile()
require.NoError(t, err)
oidcAuthReq.IDTokenHint = idToken
serverURL, err := url.Parse(testIssuer)
require.NoError(t, err)
type basicAuth struct {
username, password string
}
tests := []struct {
name string
method string
path string
basicAuth *basicAuth
header map[string]string
values map[string]string
body map[string]string
wantCode int
headerContains map[string]string
json string // test for exact json output
contains []string // when the body output is not constant, we just check for snippets to be present in the response
}{
{
name: "health",
method: http.MethodGet,
path: "/healthz",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "ready",
method: http.MethodGet,
path: "/ready",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "discovery",
method: http.MethodGet,
path: oidc.DiscoveryEndpoint,
wantCode: http.StatusOK,
json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
},
{
name: "authorization",
method: http.MethodGet,
path: testProvider.AuthorizationEndpoint().Relative(),
values: map[string]string{
"client_id": client.GetID(),
"redirect_uri": "https://example.com",
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"response_type": string(oidc.ResponseTypeCode),
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/login/username?authRequestID="},
},
{
// This call will fail. A successfull test is already
// part of client/integration_test.go
name: "code exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeCode),
"client_id": client.GetID(),
"client_secret": "secret",
"redirect_uri": "https://example.com",
"code": "123",
},
wantCode: http.StatusBadRequest,
json: `{"error":"invalid_grant", "error_description":"invalid code"}`,
},
{
name: "JWT authorization",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeBearer),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"assertion": jwtProfileToken,
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299}`},
},
{
name: "Token exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeTokenExchange),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
"subject_token": jwtToken,
"subject_token_type": string(oidc.AccessTokenType),
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
},
},
{
name: "Client credentials exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"sid1", "verysecret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeClientCredentials),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
},
{
// This call will fail. A successfull test is already
// part of device_test.go
name: "device token",
method: http.MethodPost,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
header: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
body: map[string]string{
"grant_type": string(oidc.GrantTypeDeviceCode),
"device_code": "123",
},
wantCode: http.StatusBadRequest,
json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
},
{
name: "missing grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
wantCode: http.StatusBadRequest,
json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
},
{
name: "unsupported grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": "foo",
},
wantCode: http.StatusBadRequest,
json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
},
{
name: "introspection",
method: http.MethodGet,
path: testProvider.IntrospectionEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessToken,
},
wantCode: http.StatusOK,
json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "user info",
method: http.MethodGet,
path: testProvider.UserinfoEndpoint().Relative(),
header: map[string]string{
"authorization": "Bearer " + accessToken,
},
wantCode: http.StatusOK,
json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "refresh token",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeRefreshToken),
"refresh_token": refreshToken,
"client_id": client.GetID(),
"client_secret": "secret",
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","token_type":"Bearer","refresh_token":"`,
`","expires_in":299,"id_token":"`,
},
},
{
name: "revoke",
method: http.MethodGet,
path: testProvider.RevocationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessTokenRevoke,
},
wantCode: http.StatusOK,
},
{
name: "end session",
method: http.MethodGet,
path: testProvider.EndSessionEndpoint().Relative(),
values: map[string]string{
"id_token_hint": idToken,
"client_id": "web",
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/logged-out"},
contains: []string{`<a href="/logged-out">Found</a>.`},
},
{
name: "keys",
method: http.MethodGet,
path: testProvider.KeysEndpoint().Relative(),
wantCode: http.StatusOK,
contains: []string{
`{"keys":[{"use":"sig","kty":"RSA","kid":"`,
`","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
},
},
{
name: "device authorization",
method: http.MethodGet,
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{
`{"device_code":"`, `","user_code":"`,
`","verification_uri":"https://localhost:9998/device"`,
`"verification_uri_complete":"https://localhost:9998/device?user_code=`,
`","expires_in":300,"interval":5}`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u := gu.PtrCopy(serverURL)
u.Path = tt.path
if tt.values != nil {
u.RawQuery = mapAsValues(tt.values)
}
var body io.Reader
if tt.body != nil {
body = strings.NewReader(mapAsValues(tt.body))
}
req := httptest.NewRequest(tt.method, u.String(), body)
for k, v := range tt.header {
req.Header.Set(k, v)
}
if tt.basicAuth != nil {
req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
}
rec := httptest.NewRecorder()
server.ServeHTTP(rec, req)
resp := rec.Result()
require.NoError(t, err)
assert.Equal(t, tt.wantCode, resp.StatusCode)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
respBodyString := string(respBody)
t.Log(respBodyString)
t.Log(resp.Header)
if tt.json != "" {
assert.JSONEq(t, tt.json, respBodyString)
}
for _, c := range tt.contains {
assert.Contains(t, respBodyString, c)
}
for k, v := range tt.headerContains {
assert.Contains(t, resp.Header.Get(k), v)
}
})
}
}

1333
pkg/op/server_http_test.go Normal file

File diff suppressed because it is too large Load diff

344
pkg/op/server_legacy.go Normal file
View file

@ -0,0 +1,344 @@
package op
import (
"context"
"errors"
"net/http"
"time"
"github.com/go-chi/chi"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// LegacyServer is an implementation of [Server[] that
// simply wraps a [OpenIDProvider].
// It can be used to transition from the former Provider/Storage
// interfaces to the new Server interface.
type LegacyServer struct {
UnimplementedServer
provider OpenIDProvider
endpoints Endpoints
}
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
// the Server's router.
//
// Only non-nil endpoints will be registered on the router.
// Nil endpoints are disabled.
//
// The passed endpoints is also set to the provider,
// to be consistent with the discovery config.
// Any `With*Endpoint()` option used on the provider is
// therefore ineffective.
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
server := RegisterServer(&LegacyServer{
provider: provider,
endpoints: endpoints,
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
router := chi.NewRouter()
router.Mount("/", server)
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
return router
}
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
return NewResponse(Status{Status: "ok"}), nil
}
func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
for _, probe := range s.provider.Probes() {
// shouldn't we run probes in Go routines?
if err := probe(ctx); err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
}
return NewResponse(Status{Status: "ok"}), nil
}
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
return NewResponse(
createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
), nil
}
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
keys, err := s.provider.Storage().KeySet(ctx)
if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
return NewResponse(jsonWebKeySet(keys)), nil
}
var (
ErrAuthReqMissingClientID = errors.New("auth request is missing client_id")
ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri")
)
func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
if r.Data.RequestParam != "" {
if !s.provider.RequestObjectSupported() {
return nil, oidc.ErrRequestNotSupported()
}
err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx))
if err != nil {
return nil, err
}
}
if r.Data.ClientID == "" {
return nil, ErrAuthReqMissingClientID
}
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
if err != nil {
return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id")
}
return &ClientRequest[oidc.AuthRequest]{
Request: r,
Client: client,
}, nil
}
func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) {
userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx))
if err != nil {
return nil, err
}
req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID)
if err != nil {
return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger())
}
return NewRedirect(r.Client.LoginURL(req.GetID())), nil
}
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
if err != nil {
return nil, NewStatusError(err, http.StatusInternalServerError)
}
return NewResponse(response), nil
}
func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials {
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
if !ok {
return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
}
return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret)
}
if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() {
return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
}
return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger)
}
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
if err != nil {
return nil, oidc.ErrInvalidClient().WithParent(err)
}
switch client.AuthMethod() {
case oidc.AuthMethodNone:
return client, nil
case oidc.AuthMethodPrivateKeyJWT:
return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
case oidc.AuthMethodPost:
if !s.provider.AuthMethodPostSupported() {
return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
}
}
err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage())
if err != nil {
return nil, err
}
return client, nil
}
func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code)
if err != nil {
return nil, err
}
if r.Client.AuthMethod() == oidc.AuthMethodNone {
if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil {
return nil, err
}
}
resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "")
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
if !s.provider.GrantTypeRefreshTokenSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
}
request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken)
if err != nil {
return nil, err
}
if r.Client.GetID() != request.GetClientID() {
return nil, oidc.ErrInvalidGrant()
}
if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil {
return nil, err
}
resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
if !ok {
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
}
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
if err != nil {
return nil, err
}
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
if err != nil {
return nil, err
}
resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
if !s.provider.GrantTypeTokenExchangeSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
}
tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider)
if err != nil {
return nil, err
}
resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
if !ok {
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
}
tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope)
if err != nil {
return nil, err
}
resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
if !s.provider.GrantTypeClientCredentialsSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
}
// use a limited context timeout shorter as the default
// poll interval of 5 seconds.
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
if err != nil {
return nil, err
}
tokenRequest := &deviceAccessTokenRequest{
subject: state.Subject,
audience: []string{r.Client.GetID()},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client)
if err != nil {
return nil, err
}
return NewResponse(resp), nil
}
func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
response := new(oidc.IntrospectionResponse)
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token)
if !ok {
return NewResponse(response), nil
}
err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID())
if err != nil {
return NewResponse(response), nil
}
response.Active = true
return NewResponse(response), nil
}
func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken)
if !ok {
return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized)
}
info := new(oidc.UserInfo)
err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin"))
if err != nil {
return nil, NewStatusError(err, http.StatusForbidden)
}
return NewResponse(info), nil
}
func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
var subject string
doDecrypt := true
if r.Data.TokenTypeHint != "access_token" {
userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token)
if err != nil {
// An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
if !errors.Is(err, ErrInvalidRefreshToken) {
return nil, RevocationError(oidc.ErrServerError().WithParent(err))
}
} else {
r.Data.Token = tokenID
subject = userID
doDecrypt = false
}
}
if doDecrypt {
tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token)
if ok {
r.Data.Token = tokenID
subject = userID
}
}
if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil {
return nil, RevocationError(err)
}
return NewResponse(nil), nil
}
func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider)
if err != nil {
return nil, err
}
err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID)
if err != nil {
return nil, err
}
return NewRedirect(session.RedirectURI), nil
}

5
pkg/op/server_test.go Normal file
View file

@ -0,0 +1,5 @@
package op
// implementation check
var _ Server = &UnimplementedServer{}
var _ Server = &LegacyServer{}

View file

@ -88,7 +88,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
if err != nil {
return nil, nil, err
}
err = AuthorizeCodeChallenge(tokenReq, request.GetCodeChallenge())
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge())
return request, client, err
}
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {

View file

@ -197,12 +197,6 @@ func ValidateTokenExchangeRequest(
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing")
}
storage := exchanger.Storage()
teStorage, ok := storage.(TokenExchangeStorage)
if !ok {
return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("token_exchange grant not supported")
}
client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger)
if err != nil {
return nil, nil, err
@ -220,10 +214,28 @@ func ValidateTokenExchangeRequest(
return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported")
}
req, err := CreateTokenExchangeRequest(ctx, oidcTokenExchangeRequest, client, exchanger)
if err != nil {
return nil, nil, err
}
return req, client, nil
}
func CreateTokenExchangeRequest(
ctx context.Context,
oidcTokenExchangeRequest *oidc.TokenExchangeRequest,
client Client,
exchanger Exchanger,
) (TokenExchangeRequest, error) {
teStorage, ok := exchanger.Storage().(TokenExchangeStorage)
if !ok {
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
}
exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger,
oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false)
if !ok {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid")
return nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid")
}
var (
@ -234,7 +246,7 @@ func ValidateTokenExchangeRequest(
exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger,
oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true)
if !ok {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid")
return nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid")
}
}
@ -258,17 +270,17 @@ func ValidateTokenExchangeRequest(
authTime: time.Now(),
}
err = teStorage.ValidateTokenExchangeRequest(ctx, req)
err := teStorage.ValidateTokenExchangeRequest(ctx, req)
if err != nil {
return nil, nil, err
return nil, err
}
err = teStorage.CreateTokenExchangeRequest(ctx, req)
if err != nil {
return nil, nil, err
return nil, err
}
return req, client, nil
return req, nil
}
func GetTokenIDAndSubjectFromToken(

View file

@ -117,11 +117,11 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string,
// AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent
// code_challenge of the auth request (PKCE)
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error {
if tokenReq.CodeVerifier == "" {
func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error {
if codeVerifier == "" {
return oidc.ErrInvalidRequest().WithDescription("code_challenge required")
}
if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) {
if !oidc.VerifyCodeChallenge(challenge, codeVerifier) {
return oidc.ErrInvalidGrant().WithDescription("invalid code challenge")
}
return nil

View file

@ -131,6 +131,11 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token
}
func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) {
statusErr := RevocationError(err)
httphelper.MarshalJSONWithStatus(w, statusErr.parent, statusErr.statusCode)
}
func RevocationError(err error) StatusError {
e := oidc.DefaultToServerError(err, err.Error())
status := http.StatusBadRequest
switch e.ErrorType {
@ -139,7 +144,7 @@ func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) {
case oidc.ServerError:
status = 500
}
httphelper.MarshalJSONWithStatus(w, e, status)
return NewStatusError(e, status)
}
func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {