diff --git a/pkg/op/server.go b/pkg/op/server.go index ab79a99..90bc664 100644 --- a/pkg/op/server.go +++ b/pkg/op/server.go @@ -254,13 +254,18 @@ type UnimplementedServer struct{} var UnimplementedStatusCode = http.StatusNotFound func unimplementedError[T any](r *Request[T]) StatusError { - err := oidc.ErrServerError().WithDescription(fmt.Sprintf("%s not implemented on this server", r.URL.Path)) + err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.URL.Path) return StatusError{ parent: err, statusCode: UnimplementedStatusCode, } } +func unimplementedGrantError(gt oidc.GrantType) StatusError { + err := oidc.ErrUnsupportedGrantType().WithDescription("%s grant not supported", gt) + return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 +} + func (UnimplementedServer) mustImpl() {} func (UnimplementedServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { @@ -288,27 +293,27 @@ func (UnimplementedServer) VerifyClient(_ context.Context, r *Request[ClientCred } func (UnimplementedServer) CodeExchange(_ context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) + return nil, unimplementedGrantError(oidc.GrantTypeCode) } func (UnimplementedServer) RefreshToken(_ context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) } func (UnimplementedServer) JWTProfile(_ context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { - return nil, unimplementedError(r) + return nil, unimplementedGrantError(oidc.GrantTypeBearer) } func (UnimplementedServer) TokenExchange(_ context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) } func (UnimplementedServer) ClientCredentialsExchange(_ context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) } func (UnimplementedServer) DeviceToken(_ context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) } func (UnimplementedServer) Introspect(_ context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index 1200270..1c810b4 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "time" "github.com/zitadel/oidc/v3/pkg/oidc" ) @@ -15,11 +16,13 @@ type LegacyServer struct { readyProbes []ProbesFn } -func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { +type none = struct{} + +func (s *LegacyServer) Health(_ context.Context, r *Request[none]) (*Response, error) { return NewResponse(Status{Status: "ok"}), nil } -func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { +func (s *LegacyServer) Ready(ctx context.Context, r *Request[none]) (*Response, error) { for _, probe := range s.readyProbes { // shouldn't we run probes in Go routines? if err := probe(ctx); err != nil { @@ -29,7 +32,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon return NewResponse(Status{Status: "ok"}), nil } -func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { +func (s *LegacyServer) Discovery(ctx context.Context, r *Request[none]) (*Response, error) { return NewResponse( CreateDiscoveryConfig(ctx, s.provider, s.provider.Storage()), ), nil @@ -134,6 +137,9 @@ func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.A } func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + if !s.provider.GrantTypeRefreshTokenSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) + } if !ValidateGrantType(r.Client, oidc.GrantTypeRefreshToken) { return nil, oidc.ErrUnauthorizedClient() } @@ -154,20 +160,81 @@ func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.R return NewResponse(resp), nil } -func (s *LegacyServer) JWTProfile(_ context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { - return nil, unimplementedError(r) +func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeBearer) + } + tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx)) + if err != nil { + return nil, err + } + + tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger) + if err != nil { + return nil, err + } + return NewResponse(resp), nil } -func (s *LegacyServer) TokenExchange(_ context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) +func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + if !s.provider.GrantTypeTokenExchangeSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider) + if err != nil { + return nil, err + } + resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider) + if err != nil { + return nil, err + } + return NewResponse(resp), nil } -func (s *LegacyServer) ClientCredentialsExchange(_ context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) +func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + storage, ok := s.provider.Storage().(ClientCredentialsStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) + } + tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil } -func (s *LegacyServer) DeviceToken(_ context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { - return nil, unimplementedError(r.Request) +func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + if !s.provider.GrantTypeClientCredentialsSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) + } + // use a limited context timeout shorter as the default + // poll interval of 5 seconds. + ctx, cancel := context.WithTimeout(ctx, 4*time.Second) + defer cancel() + + state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider) + if err != nil { + return nil, err + } + tokenRequest := &deviceAccessTokenRequest{ + subject: state.Subject, + audience: []string{r.Client.GetID()}, + scopes: state.Scopes, + } + resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil } func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 21db134..5156741 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -197,12 +197,6 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing") } - storage := exchanger.Storage() - teStorage, ok := storage.(TokenExchangeStorage) - if !ok { - return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("token_exchange grant not supported") - } - client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger) if err != nil { return nil, nil, err @@ -220,10 +214,28 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported") } + req, err := CreateTokenExchangeRequest(ctx, oidcTokenExchangeRequest, client, exchanger) + if err != nil { + return nil, nil, err + } + return req, client, nil +} + +func CreateTokenExchangeRequest( + ctx context.Context, + oidcTokenExchangeRequest *oidc.TokenExchangeRequest, + client Client, + exchanger Exchanger, +) (TokenExchangeRequest, error) { + teStorage, ok := exchanger.Storage().(TokenExchangeStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") } var ( @@ -234,7 +246,7 @@ func ValidateTokenExchangeRequest( exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") } } @@ -258,17 +270,17 @@ func ValidateTokenExchangeRequest( authTime: time.Now(), } - err = teStorage.ValidateTokenExchangeRequest(ctx, req) + err := teStorage.ValidateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } err = teStorage.CreateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } - return req, client, nil + return req, nil } func GetTokenIDAndSubjectFromToken(