457 lines
16 KiB
Go
457 lines
16 KiB
Go
package op
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"time"
|
|
|
|
"git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
|
|
"github.com/go-chi/chi/v5"
|
|
)
|
|
|
|
// ExtendedLegacyServer allows embedding [LegacyServer] in a struct,
|
|
// so that its methods can be individually overridden.
|
|
//
|
|
// EXPERIMENTAL: may change until v4
|
|
type ExtendedLegacyServer interface {
|
|
Server
|
|
Provider() OpenIDProvider
|
|
Endpoints() Endpoints
|
|
AuthCallbackURL() func(context.Context, string) string
|
|
}
|
|
|
|
// RegisterLegacyServer registers a [LegacyServer] or an extension thereof.
|
|
// It takes care of registering the IssuerFromRequest middleware.
|
|
// The authorizeCallbackHandler is registered on `/callback` under the authorization endpoint.
|
|
// Neither are part of the bare [Server] interface.
|
|
//
|
|
// EXPERIMENTAL: may change until v4
|
|
func RegisterLegacyServer(s ExtendedLegacyServer, authorizeCallbackHandler http.HandlerFunc, options ...ServerOption) http.Handler {
|
|
options = append(options,
|
|
WithHTTPMiddleware(intercept(s.Provider().IssuerFromRequest)),
|
|
WithSetRouter(func(r chi.Router) {
|
|
r.HandleFunc(s.Endpoints().Authorization.Relative()+authCallbackPathSuffix, authorizeCallbackHandler)
|
|
}),
|
|
)
|
|
return RegisterServer(s, s.Endpoints(), options...)
|
|
}
|
|
|
|
// LegacyServer is an implementation of [Server] that
|
|
// simply wraps an [OpenIDProvider].
|
|
// It can be used to transition from the former Provider/Storage
|
|
// interfaces to the new Server interface.
|
|
//
|
|
// EXPERIMENTAL: may change until v4
|
|
type LegacyServer struct {
|
|
UnimplementedServer
|
|
provider OpenIDProvider
|
|
endpoints Endpoints
|
|
}
|
|
|
|
// NewLegacyServer wraps provider in a `Server` implementation
|
|
//
|
|
// Only non-nil endpoints will be registered on the router.
|
|
// Nil endpoints are disabled.
|
|
//
|
|
// The passed endpoints is also used for the discovery config,
|
|
// and endpoints already set to the provider are ignored.
|
|
// Any `With*Endpoint()` option used on the provider is
|
|
// therefore ineffective.
|
|
//
|
|
// EXPERIMENTAL: may change until v4
|
|
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) *LegacyServer {
|
|
return &LegacyServer{
|
|
provider: provider,
|
|
endpoints: endpoints,
|
|
}
|
|
}
|
|
|
|
func (s *LegacyServer) Provider() OpenIDProvider {
|
|
return s.provider
|
|
}
|
|
|
|
func (s *LegacyServer) Endpoints() Endpoints {
|
|
return s.endpoints
|
|
}
|
|
|
|
// AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login
|
|
func (s *LegacyServer) AuthCallbackURL() func(context.Context, string) string {
|
|
return func(ctx context.Context, requestID string) string {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.AuthCallbackURL")
|
|
defer span.End()
|
|
|
|
return s.endpoints.Authorization.Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID
|
|
}
|
|
}
|
|
|
|
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
|
|
return NewResponse(Status{Status: "ok"}), nil
|
|
}
|
|
|
|
func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
|
for _, probe := range s.provider.Probes() {
|
|
// shouldn't we run probes in Go routines?
|
|
if err := probe(ctx); err != nil {
|
|
return nil, AsStatusError(err, http.StatusInternalServerError)
|
|
}
|
|
}
|
|
return NewResponse(Status{Status: "ok"}), nil
|
|
}
|
|
|
|
func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.Discovery")
|
|
defer span.End()
|
|
|
|
return NewResponse(
|
|
createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
|
|
), nil
|
|
}
|
|
|
|
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.Keys")
|
|
defer span.End()
|
|
|
|
keys, err := s.provider.Storage().KeySet(ctx)
|
|
if err != nil {
|
|
return nil, AsStatusError(err, http.StatusInternalServerError)
|
|
}
|
|
return NewResponse(jsonWebKeySet(keys)), nil
|
|
}
|
|
|
|
var (
|
|
ErrAuthReqMissingClientID = errors.New("auth request is missing client_id")
|
|
ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri")
|
|
)
|
|
|
|
func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.VerifyAuthRequest")
|
|
defer span.End()
|
|
|
|
if r.Data.RequestParam != "" {
|
|
if !s.provider.RequestObjectSupported() {
|
|
return nil, oidc.ErrRequestNotSupported()
|
|
}
|
|
err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if r.Data.ClientID == "" {
|
|
return nil, oidc.ErrInvalidRequest().WithParent(ErrAuthReqMissingClientID).WithDescription(ErrAuthReqMissingClientID.Error())
|
|
}
|
|
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
|
if err != nil {
|
|
return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id")
|
|
}
|
|
|
|
return &ClientRequest[oidc.AuthRequest]{
|
|
Request: r,
|
|
Client: client,
|
|
}, nil
|
|
}
|
|
|
|
func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.Authorize")
|
|
defer span.End()
|
|
|
|
userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID)
|
|
if err != nil {
|
|
return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger())
|
|
}
|
|
return NewRedirect(r.Client.LoginURL(req.GetID())), nil
|
|
}
|
|
|
|
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.DeviceAuthorization")
|
|
defer span.End()
|
|
|
|
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
|
|
if err != nil {
|
|
return nil, AsStatusError(err, http.StatusInternalServerError)
|
|
}
|
|
return NewResponse(response), nil
|
|
}
|
|
|
|
func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.VerifyClient")
|
|
defer span.End()
|
|
|
|
if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials {
|
|
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
|
|
if !ok {
|
|
return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
|
|
}
|
|
return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret)
|
|
}
|
|
|
|
if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
|
|
jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
|
|
if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() {
|
|
return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
|
|
}
|
|
return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger)
|
|
}
|
|
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
|
if err != nil {
|
|
return nil, oidc.ErrInvalidClient().WithParent(err)
|
|
}
|
|
|
|
switch client.AuthMethod() {
|
|
case oidc.AuthMethodNone:
|
|
return client, nil
|
|
case oidc.AuthMethodPrivateKeyJWT:
|
|
return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
|
|
case oidc.AuthMethodPost:
|
|
if !s.provider.AuthMethodPostSupported() {
|
|
return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
|
|
}
|
|
}
|
|
|
|
err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.CodeExchange")
|
|
defer span.End()
|
|
|
|
authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if r.Client.AuthMethod() == oidc.AuthMethodNone || r.Data.CodeVerifier != "" {
|
|
if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if r.Data.RedirectURI != authReq.GetRedirectURI() {
|
|
return nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
|
|
}
|
|
resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewResponse(resp), nil
|
|
}
|
|
|
|
func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.RefreshToken")
|
|
defer span.End()
|
|
|
|
if !s.provider.GrantTypeRefreshTokenSupported() {
|
|
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
|
|
}
|
|
request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if r.Client.GetID() != request.GetClientID() {
|
|
return nil, oidc.ErrInvalidGrant()
|
|
}
|
|
if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewResponse(resp), nil
|
|
}
|
|
|
|
func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.JWTProfile")
|
|
defer span.End()
|
|
|
|
exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
|
|
if !ok {
|
|
return nil, unimplementedGrantError(oidc.GrantTypeBearer)
|
|
}
|
|
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
|
|
if err != nil {
|
|
return nil, oidc.ErrInvalidRequest().WithParent(err).WithDescription("assertion invalid")
|
|
}
|
|
|
|
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewResponse(resp), nil
|
|
}
|
|
|
|
func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.TokenExchange")
|
|
defer span.End()
|
|
|
|
if !s.provider.GrantTypeTokenExchangeSupported() {
|
|
return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
|
|
}
|
|
tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewResponse(resp), nil
|
|
}
|
|
|
|
func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.ClientCredentialsExchange")
|
|
defer span.End()
|
|
|
|
storage, ok := s.provider.Storage().(ClientCredentialsStorage)
|
|
if !ok {
|
|
return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
|
|
}
|
|
tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewResponse(resp), nil
|
|
}
|
|
|
|
func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.DeviceToken")
|
|
defer span.End()
|
|
|
|
if !s.provider.GrantTypeDeviceCodeSupported() {
|
|
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
|
|
}
|
|
// use a limited context timeout shorter as the default
|
|
// poll interval of 5 seconds.
|
|
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
|
defer cancel()
|
|
|
|
tokenRequest, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewResponse(resp), nil
|
|
}
|
|
|
|
func (s *LegacyServer) authenticateResourceClient(ctx context.Context, cc *ClientCredentials) (string, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.authenticateResourceClient")
|
|
defer span.End()
|
|
|
|
if cc.ClientAssertion != "" {
|
|
if jp, ok := s.provider.(ClientJWTProfile); ok {
|
|
return ClientJWTAuth(ctx, oidc.ClientAssertionParams{ClientAssertion: cc.ClientAssertion}, jp)
|
|
}
|
|
return "", oidc.ErrInvalidClient().WithDescription("client_assertion not supported")
|
|
}
|
|
if err := s.provider.Storage().AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
|
|
return "", oidc.ErrUnauthorizedClient().WithParent(err)
|
|
}
|
|
return cc.ClientID, nil
|
|
}
|
|
|
|
func (s *LegacyServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.Introspect")
|
|
defer span.End()
|
|
|
|
clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
response := new(oidc.IntrospectionResponse)
|
|
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token)
|
|
if !ok {
|
|
return NewResponse(response), nil
|
|
}
|
|
err = s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID)
|
|
if err != nil {
|
|
return NewResponse(response), nil
|
|
}
|
|
response.Active = true
|
|
return NewResponse(response), nil
|
|
}
|
|
|
|
func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.UserInfo")
|
|
defer span.End()
|
|
|
|
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken)
|
|
if !ok {
|
|
return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized)
|
|
}
|
|
info := new(oidc.UserInfo)
|
|
err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin"))
|
|
if err != nil {
|
|
return nil, NewStatusError(err, http.StatusForbidden)
|
|
}
|
|
return NewResponse(info), nil
|
|
}
|
|
|
|
func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.Revocation")
|
|
defer span.End()
|
|
|
|
var subject string
|
|
doDecrypt := true
|
|
if r.Data.TokenTypeHint != "access_token" {
|
|
userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token)
|
|
if err != nil {
|
|
// An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
|
|
if !errors.Is(err, ErrInvalidRefreshToken) {
|
|
return nil, RevocationError(oidc.ErrServerError().WithParent(err))
|
|
}
|
|
} else {
|
|
r.Data.Token = tokenID
|
|
subject = userID
|
|
doDecrypt = false
|
|
}
|
|
}
|
|
if doDecrypt {
|
|
tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token)
|
|
if ok {
|
|
r.Data.Token = tokenID
|
|
subject = userID
|
|
}
|
|
}
|
|
if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil {
|
|
return nil, RevocationError(err)
|
|
}
|
|
return NewResponse(nil), nil
|
|
}
|
|
|
|
func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
|
|
ctx, span := tracer.Start(ctx, "LegacyServer.EndSession")
|
|
defer span.End()
|
|
|
|
session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
redirect := session.RedirectURI
|
|
if fromRequest, ok := s.provider.Storage().(CanTerminateSessionFromRequest); ok {
|
|
redirect, err = fromRequest.TerminateSessionFromRequest(ctx, session)
|
|
} else {
|
|
err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewRedirect(redirect), nil
|
|
}
|