From d6a9c0bbb92d0561e3666aab992396ba57f17823 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Mon, 4 Sep 2023 23:33:51 +0300 Subject: [PATCH] first draft of a new server interface --- pkg/op/client.go | 8 ++ pkg/op/op.go | 6 +- pkg/op/probes.go | 4 +- pkg/op/server.go | 269 ++++++++++++++++++++++++++++++++++++++++ pkg/op/server_http.go | 109 ++++++++++++++++ pkg/op/server_legacy.go | 81 ++++++++++++ pkg/op/server_test.go | 5 + pkg/op/token_code.go | 2 +- pkg/op/token_request.go | 6 +- 9 files changed, 481 insertions(+), 9 deletions(-) create mode 100644 pkg/op/server.go create mode 100644 pkg/op/server_http.go create mode 100644 pkg/op/server_legacy.go create mode 100644 pkg/op/server_test.go diff --git a/pkg/op/client.go b/pkg/op/client.go index d01845f..525af8d 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -180,3 +180,11 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au } return data.ClientID, false, nil } + +type ClientCredentials struct { + ClientID string `schema:"client_id"` + ClientSecret string `schema:"client_secret"` // Client secret from request body + ClientSecretBasic string `schema:"-"` // Obtained from http request + ClientAssertion string `schema:"client_assertion"` // JWT + ClientAssertionType string `schema:"client_assertion_type"` +} diff --git a/pkg/op/op.go b/pkg/op/op.go index 0175d7f..5b318e3 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -32,7 +32,7 @@ const ( ) var ( - DefaultEndpoints = &endpoints{ + DefaultEndpoints = &Endpoints{ Authorization: NewEndpoint(defaultAuthorizationEndpoint), Token: NewEndpoint(defaultTokenEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint), @@ -131,7 +131,7 @@ type Config struct { DeviceAuthorization DeviceAuthorizationConfig } -type endpoints struct { +type Endpoints struct { Authorization Endpoint Token Endpoint Introspection Endpoint @@ -212,7 +212,7 @@ type Provider struct { config *Config issuer IssuerFromRequest insecure bool - endpoints *endpoints + endpoints *Endpoints storage Storage keySet *openIDKeySet crypto Crypto diff --git a/pkg/op/probes.go b/pkg/op/probes.go index 9ef5bb5..cb3853d 100644 --- a/pkg/op/probes.go +++ b/pkg/op/probes.go @@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn { } func ok(w http.ResponseWriter) { - httphelper.MarshalJSON(w, status{"ok"}) + httphelper.MarshalJSON(w, Status{"ok"}) } -type status struct { +type Status struct { Status string `json:"status,omitempty"` } diff --git a/pkg/op/server.go b/pkg/op/server.go new file mode 100644 index 0000000..218d338 --- /dev/null +++ b/pkg/op/server.go @@ -0,0 +1,269 @@ +package op + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + + jose "github.com/go-jose/go-jose/v3" + "github.com/muhlemmer/gu" + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +type StatusError struct { + parent error + statusCode int +} + +func NewStatusError(parent error, statusCode int) StatusError { + return StatusError{ + parent: parent, + statusCode: statusCode, + } +} + +func (e StatusError) Error() string { + return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error()) +} + +func (e StatusError) Unwrap() error { + return e.parent +} + +func (e StatusError) Is(err error) bool { + var target StatusError + if !errors.As(err, &target) { + return false + } + return errors.Is(e.parent, target.parent) && + e.statusCode == target.statusCode +} + +type Request[T any] struct { + Method string + URL *url.URL + Header http.Header + Form url.Values + Data *T +} + +func newRequest[T any](r *http.Request, data *T) *Request[T] { + return &Request[T]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + Data: data, + } +} + +type ClientRequest[T any] struct { + *Request[T] + Client Client +} + +func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] { + return &ClientRequest[T]{ + Request: newRequest[T](r, data), + Client: client, + } +} + +type Response[T any] struct { + Header http.Header + Data *T +} + +func NewResponse[T any](data *T) *Response[T] { + return &Response[T]{ + Data: data, + } +} + +func (resp *Response[T]) writeOut(w http.ResponseWriter) { + gu.MapMerge(resp.Header, w.Header()) + json.NewEncoder(w).Encode(resp.Data) +} + +type Server interface { + // Health should return a status of "ok" once the Server is listining. + Health(context.Context, *Request[struct{}]) (*Response[Status], error) + + // Ready should return a status of "ok" once all dependecies, + // such as database storage are ready. + // An error can be returned to explain what is not ready. + Ready(context.Context, *Request[struct{}]) (*Response[Status], error) + + // Discovery return the OpenID Provider Configuration Information for this server. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig + Discovery(context.Context, *Request[struct{}]) (*Response[oidc.DiscoveryConfiguration], error) + + // Authorize initiates the authorization flow and redirects to a login page. + // See the various https://openid.net/specs/openid-connect-core-1_0.html + // authorize endpoint sections (one for each type of flow). + Authorize(context.Context, *Request[oidc.AuthRequest]) (*Response[url.URL], error) + + // AuthorizeCallback? Do we still need it? + + // DeviceAuthorization initiates the device authorization flow. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.1 + DeviceAuthorization(context.Context, *Request[oidc.DeviceAuthorizationRequest]) (*Response[oidc.DeviceAuthorizationResponse], error) + + // VerifyClient is called on most oauth/token handlers to authenticate, + // using either a secret (POST, Basic) or assertion (JWT). + // If no secrets are provided, the client must be public. + // This method is called before each method that takes a + // [ClientRequest] argument. + VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error) + + // CodeExchange returns Tokens after an authorization code + // is obtained in a succesfull Authorize flow. + // It is called by the Token endpoint handler when + // grant_type has the value authorization_code + // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response[oidc.AccessTokenResponse], error) + + // RefreshToken returns new Tokens after verifying a Refresh token. + // It is called by the Token endpoint handler when + // grant_type has the value refresh_token + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens + RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response[oidc.AccessTokenResponse], error) + + // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer + // https://datatracker.ietf.org/doc/html/rfc7523#section-2.1 + JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response[oidc.AccessTokenResponse], error) + + // TokenExchange handles the OAuth 2.0 token exchange grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange + // https://datatracker.ietf.org/doc/html/rfc8693 + TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response[oidc.AccessTokenResponse], error) + + // ClientCredentialsExchange handles the OAuth 2.0 client credentials grant + // It is called by the Token endpoint handler when + // grant_type has the value client_credentials + // https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 + ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response[oidc.AccessTokenResponse], error) + + // DeviceToken handles the OAuth 2.0 Device Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:device_code. + // It is typically called in a polling fashion and appropiate errors + // should be returned to signal authorization_pending or access_denied etc. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.4, + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5. + DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response[oidc.AccessTokenResponse], error) + + // Introspect handles the OAuth 2.0 Token Introspection endpoint. + // https://datatracker.ietf.org/doc/html/rfc7662 + Introspect(context.Context, *Request[oidc.IntrospectionRequest]) (*Response[oidc.IntrospectionResponse], error) + + // UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User. + // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo + UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response[oidc.UserInfo], error) + + // Revocation handles token revocation using an access or refresh token. + // https://datatracker.ietf.org/doc/html/rfc7009 + Revocation(context.Context, *Request[oidc.RevocationRequest]) (*Response[struct{}], error) + + // EndSession handles the OpenID Connect RP-Initiated Logout. + // https://openid.net/specs/openid-connect-rpinitiated-1_0.html + EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Response[struct{}], error) + + // Keys serves the JWK set which the client can use verify signatures from the op. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key. + Keys(context.Context, *Request[struct{}]) (*Response[jose.JSONWebKeySet], error) + + mustImpl() +} + +type UnimplementedServer struct{} + +// UnimplementedStatusCode is the statuscode returned for methods +// that are not yet implemented. +// Note that this means methods in the sense of the Go interface, +// and not http methods covered by "501 Not Implemented". +var UnimplementedStatusCode = http.StatusNotFound + +func unimplementedError[T any](r *Request[T]) StatusError { + err := oidc.ErrServerError().WithDescription(fmt.Sprintf("%s not implemented on this server", r.URL.Path)) + return StatusError{ + parent: err, + statusCode: UnimplementedStatusCode, + } +} + +func (UnimplementedServer) mustImpl() {} + +func (UnimplementedServer) Health(_ context.Context, r *Request[struct{}]) (*Response[Status], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Ready(_ context.Context, r *Request[struct{}]) (*Response[Status], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Discovery(_ context.Context, r *Request[struct{}]) (*Response[oidc.DiscoveryConfiguration], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Authorize(_ context.Context, r *Request[oidc.AuthRequest]) (*Response[url.URL], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) DeviceAuthorization(_ context.Context, r *Request[oidc.DeviceAuthorizationRequest]) (*Response[oidc.DeviceAuthorizationResponse], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) VerifyClient(_ context.Context, r *Request[ClientCredentials]) (Client, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) CodeExchange(_ context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response[oidc.AccessTokenResponse], error) { + return nil, unimplementedError(r.Request) +} + +func (UnimplementedServer) RefreshToken(_ context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response[oidc.AccessTokenResponse], error) { + return nil, unimplementedError(r.Request) +} + +func (UnimplementedServer) JWTProfile(_ context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response[oidc.AccessTokenResponse], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) TokenExchange(_ context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response[oidc.AccessTokenResponse], error) { + return nil, unimplementedError(r.Request) +} + +func (UnimplementedServer) ClientCredentialsExchange(_ context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response[oidc.AccessTokenResponse], error) { + return nil, unimplementedError(r.Request) +} + +func (UnimplementedServer) DeviceToken(_ context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response[oidc.AccessTokenResponse], error) { + return nil, unimplementedError(r.Request) +} + +func (UnimplementedServer) Introspect(_ context.Context, r *Request[oidc.IntrospectionRequest]) (*Response[oidc.IntrospectionResponse], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) UserInfo(_ context.Context, r *Request[oidc.UserInfoRequest]) (*Response[oidc.UserInfo], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Revocation(_ context.Context, r *Request[oidc.RevocationRequest]) (*Response[struct{}], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) EndSession(_ context.Context, r *Request[oidc.EndSessionRequest]) (*Response[struct{}], error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Keys(_ context.Context, r *Request[struct{}]) (*Response[jose.JSONWebKeySet], error) { + return nil, unimplementedError(r) +} diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go new file mode 100644 index 0000000..e727626 --- /dev/null +++ b/pkg/op/server_http.go @@ -0,0 +1,109 @@ +package op + +import ( + "net/http" + + "github.com/go-chi/chi" + "github.com/rs/cors" + httphelper "github.com/zitadel/oidc/v3/pkg/http" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/exp/slog" +) + +type webServer struct { + http.Handler + decoder httphelper.Decoder + server Server + logger *slog.Logger +} + +func (s *webServer) createRouter(endpoints *Endpoints, interceptors ...func(http.Handler) http.Handler) chi.Router { + router := chi.NewRouter() + router.Use(cors.New(defaultCORSOptions).Handler) + router.Use(interceptors...) + router.HandleFunc(healthEndpoint, healthHandler) + //router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) + //router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage())) + //router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o)) + //router.HandleFunc(authCallbackPath(o), authorizeCallbackHandler(o)) + router.HandleFunc(endpoints.Token.Relative(), s.handleToken) + //router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) + //router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) + //router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o)) + //router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o)) + //router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) + //router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o)) + return router +} + +func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) { + if err := r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + clientCredentials := new(ClientCredentials) + if err := s.decoder.Decode(clientCredentials, r.Form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + // Basic auth takes precedence, so if set it overwrites the form data. + if clientID, clientSecret, ok := r.BasicAuth(); ok { + clientCredentials.ClientID, clientCredentials.ClientSecret = clientID, clientSecret + } + + return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + Data: clientCredentials, + }) +} + +func (s *webServer) handleToken(w http.ResponseWriter, r *http.Request) { + client, err := s.verifyRequestClient(r) + if err != nil { + RequestError(w, r, err, slog.Default()) + return + } + + grantType := oidc.GrantType(r.Form.Get("grant_type")) + var handle func(w http.ResponseWriter, r *http.Request, client Client) + switch grantType { + case oidc.GrantTypeCode: + handle = s.handleCodeExchange + case oidc.GrantTypeRefreshToken: + handle = s.handleRefreshToken + case "": + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default()) + return + default: + RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), slog.Default()) + return + } + + handle(w, r, client) +} + +func (s *webServer) handleCodeExchange(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[*oidc.AccessTokenRequest](s.decoder, r.Form) + if err != nil { + RequestError(w, r, err, s.logger) + return + } + resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + RequestError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func (s *webServer) handleRefreshToken(w http.ResponseWriter, r *http.Request, client Client) { + +} + +func decodeRequest[R any](decoder httphelper.Decoder, form map[string][]string) (request R, err error) { + if err := decoder.Decode(&request, form); err != nil { + return request, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + return request, nil +} diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go new file mode 100644 index 0000000..684ceed --- /dev/null +++ b/pkg/op/server_legacy.go @@ -0,0 +1,81 @@ +package op + +import ( + "context" + + "github.com/zitadel/oidc/v3/pkg/oidc" +) + +type LegacyServer struct { + UnimplementedServer + op *Provider +} + +func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { + if !s.op.AuthMethodPrivateKeyJWTSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported") + } + return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, s.op) + } + client, err := s.op.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithParent(err) + } + + switch client.AuthMethod() { + case oidc.AuthMethodNone: + return client, nil + case oidc.AuthMethodPrivateKeyJWT: + return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") + case oidc.AuthMethodPost: + if !s.op.AuthMethodPostSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") + } + } + + err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.op.storage) + if err != nil { + return nil, err + } + + return client, nil +} + +func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response[oidc.AccessTokenResponse], error) { + authReq, err := AuthRequestByCode(ctx, s.op.storage, r.Data.Code) + if err != nil { + return nil, err + } + if r.Client.AuthMethod() == oidc.AuthMethodNone { + if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil { + return nil, err + } + } + resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.op, true, r.Data.Code, "") + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response[oidc.AccessTokenResponse], error) { + if !ValidateGrantType(r.Client, oidc.GrantTypeRefreshToken) { + return nil, oidc.ErrUnauthorizedClient() + } + request, err := RefreshTokenRequestByRefreshToken(ctx, s.op.storage, r.Data.RefreshToken) + if err != nil { + return nil, err + } + if r.Client.GetID() != request.GetClientID() { + return nil, oidc.ErrInvalidGrant() + } + if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil { + return nil, err + } + resp, err := CreateTokenResponse(ctx, request, r.Client, s.op, true, "", r.Data.RefreshToken) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} diff --git a/pkg/op/server_test.go b/pkg/op/server_test.go new file mode 100644 index 0000000..0cad8fd --- /dev/null +++ b/pkg/op/server_test.go @@ -0,0 +1,5 @@ +package op + +// implementation check +var _ Server = &UnimplementedServer{} +var _ Server = &LegacyServer{} diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index baf377b..371e1d4 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -88,7 +88,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, if err != nil { return nil, nil, err } - err = AuthorizeCodeChallenge(tokenReq, request.GetCodeChallenge()) + err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) return request, client, err } if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index 0df2fce..b810633 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -117,11 +117,11 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, // AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent // code_challenge of the auth request (PKCE) -func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error { - if tokenReq.CodeVerifier == "" { +func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error { + if codeVerifier == "" { return oidc.ErrInvalidRequest().WithDescription("code_challenge required") } - if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) { + if !oidc.VerifyCodeChallenge(challenge, codeVerifier) { return oidc.ErrInvalidGrant().WithDescription("invalid code challenge") } return nil