From 2902a81161de6d7953e4696b12c8f6b54c8d55ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 8 Sep 2023 10:42:27 +0300 Subject: [PATCH] intermediate commit with some methods implemented --- pkg/op/device.go | 38 ++++++---- pkg/op/discovery.go | 8 +-- pkg/op/discovery_test.go | 8 +-- pkg/op/error.go | 32 +++++++++ pkg/op/server.go | 25 ++++--- pkg/op/server_legacy.go | 149 ++++++++++++++++++++++++++++++++++++--- 6 files changed, 214 insertions(+), 46 deletions(-) diff --git a/pkg/op/device.go b/pkg/op/device.go index 029bed8..55d3c57 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -63,41 +63,51 @@ func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *htt } func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error { - storage, err := assertDeviceStorage(o.Storage()) - if err != nil { - return err - } - req, err := ParseDeviceCodeRequest(r, o) if err != nil { return err } + response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o) + if err != nil { + return err + } + httphelper.MarshalJSON(w, response) + return nil +} + +func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) { + storage, err := assertDeviceStorage(o.Storage()) + if err != nil { + return nil, err + } config := o.DeviceAuthorization() deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } expires := time.Now().Add(config.Lifetime) - err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes) + err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes) if err != nil { - return err + return nil, NewStatusError(err, http.StatusInternalServerError) } var verification *url.URL if config.UserFormURL != "" { if verification, err = url.Parse(config.UserFormURL); err != nil { - return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + return nil, NewStatusError(err, http.StatusInternalServerError) } } else { - if verification, err = url.Parse(IssuerFromContext(r.Context())); err != nil { - return oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil { + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + return nil, NewStatusError(err, http.StatusInternalServerError) } verification.Path = config.UserFormPath } @@ -112,9 +122,7 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide verification.RawQuery = "user_code=" + userCode response.VerificationURIComplete = verification.String() - - httphelper.MarshalJSON(w, response) - return nil + return response, nil } func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) { diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 782a279..d376032 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{ func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - Discover(w, CreateDiscoveryConfig(r, c, s)) + Discover(w, CreateDiscoveryConfig(r.Context(), c, s)) } } @@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { httphelper.MarshalJSON(w, config) } -func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { - issuer := config.IssuerFromRequest(r) +func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) return &oidc.DiscoveryConfiguration{ Issuer: issuer, AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer), @@ -49,7 +49,7 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov ResponseTypesSupported: ResponseTypes(config), GrantTypesSupported: GrantTypes(config), SubjectTypesSupported: SubjectTypes(config), - IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 3e95ec3..84e1216 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -48,9 +48,9 @@ func TestDiscover(t *testing.T) { func TestCreateDiscoveryConfig(t *testing.T) { type args struct { - request *http.Request - c op.Configuration - s op.DiscoverStorage + ctx context.Context + c op.Configuration + s op.DiscoverStorage } tests := []struct { name string @@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s) + got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s) assert.Equal(t, tt.want, got) }) } diff --git a/pkg/op/error.go b/pkg/op/error.go index 9981fec..67278c6 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -1,6 +1,7 @@ package op import ( + "context" "net/http" httphelper "github.com/zitadel/oidc/v3/pkg/http" @@ -66,3 +67,34 @@ func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slo logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) httphelper.MarshalJSONWithStatus(w, e, status) } + +func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) { + e := oidc.DefaultToServerError(parent, parent.Error()) + logger = logger.With("oidc_error", e) + + if authReq == nil { + logger.Log(ctx, e.LogLevel(), "auth request") + return nil, NewStatusError(parent, http.StatusBadRequest) + } + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } + + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(ctx, e.LogLevel(), "auth request: not redirecting") + return nil, NewStatusError(parent, http.StatusBadRequest) + } + + e.State = authReq.GetState() + var responseMode oidc.ResponseMode + if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { + responseMode = rm.GetResponseMode() + } + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) + if err != nil { + logger.ErrorContext(ctx, "auth response URL", "error", err) + return nil, NewStatusError(err, http.StatusBadRequest) + } + return NewRedirect(url), nil +} diff --git a/pkg/op/server.go b/pkg/op/server.go index 4965975..ab79a99 100644 --- a/pkg/op/server.go +++ b/pkg/op/server.go @@ -79,7 +79,7 @@ type Server interface { // DeviceAuthorization initiates the device authorization flow. // https://datatracker.ietf.org/doc/html/rfc8628#section-3.1 // The recommended Response Data type is [oidc.DeviceAuthorizationResponse]. - DeviceAuthorization(context.Context, *Request[oidc.DeviceAuthorizationRequest]) (*Response, error) + DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) // VerifyClient is called on most oauth/token handlers to authenticate, // using either a secret (POST, Basic) or assertion (JWT). @@ -137,7 +137,7 @@ type Server interface { // Introspect handles the OAuth 2.0 Token Introspection endpoint. // https://datatracker.ietf.org/doc/html/rfc7662 // The recommended Response Data type is [oidc.IntrospectionResponse]. - Introspect(context.Context, *Request[oidc.IntrospectionRequest]) (*Response, error) + Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) // UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User. // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo @@ -231,19 +231,18 @@ func (resp *Response) writeOut(w http.ResponseWriter) { // Redirect is a special response type which will // initiate a [http.StatusFound] redirect. -// The Params field will be encoded and set to the +// The Params fielde will be encoded and set to the // URL's RawQuery field before building the URL. -// -// If the RawQuery contains values that need to persist, -// the implementation should parse them into Params and -// add request specific values after. type Redirect struct { // Header map will be merged with the // header on the [http.ResponseWriter]. Header http.Header - URL url.URL - Params url.Values + URL string +} + +func NewRedirect(url string) *Redirect { + return &Redirect{URL: url} } type UnimplementedServer struct{} @@ -280,8 +279,8 @@ func (UnimplementedServer) Authorize(_ context.Context, r *Request[oidc.AuthRequ return nil, unimplementedError(r) } -func (UnimplementedServer) DeviceAuthorization(_ context.Context, r *Request[oidc.DeviceAuthorizationRequest]) (*Response, error) { - return nil, unimplementedError(r) +func (UnimplementedServer) DeviceAuthorization(_ context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + return nil, unimplementedError(r.Request) } func (UnimplementedServer) VerifyClient(_ context.Context, r *Request[ClientCredentials]) (Client, error) { @@ -312,8 +311,8 @@ func (UnimplementedServer) DeviceToken(_ context.Context, r *ClientRequest[oidc. return nil, unimplementedError(r.Request) } -func (UnimplementedServer) Introspect(_ context.Context, r *Request[oidc.IntrospectionRequest]) (*Response, error) { - return nil, unimplementedError(r) +func (UnimplementedServer) Introspect(_ context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { + return nil, unimplementedError(r.Request) } func (UnimplementedServer) UserInfo(_ context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index 5b16d0b..1200270 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -2,23 +2,97 @@ package op import ( "context" + "errors" + "net/http" "github.com/zitadel/oidc/v3/pkg/oidc" ) type LegacyServer struct { UnimplementedServer - op *Provider + provider OpenIDProvider + + readyProbes []ProbesFn +} + +func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { + for _, probe := range s.readyProbes { + // shouldn't we run probes in Go routines? + if err := probe(ctx); err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + } + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return NewResponse( + CreateDiscoveryConfig(ctx, s.provider, s.provider.Storage()), + ), nil +} + +var ( + ErrAuthReqMissingClientID = errors.New("auth request is missing client_id") + ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri") +) + +func (s *LegacyServer) Authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) { + authReq := r.Data + if authReq.RequestParam != "" && s.provider.RequestObjectSupported() { + authReq, err = ParseRequestObject(ctx, authReq, s.provider.Storage(), IssuerFromContext(ctx)) + if err != nil { + return nil, NewStatusError(err, http.StatusBadRequest) + } + } + if authReq.ClientID == "" { + return TryErrorRedirect(ctx, authReq, ErrAuthReqMissingClientID, s.provider.Encoder(), s.provider.Logger()) + } + if authReq.RedirectURI == "" { + return TryErrorRedirect(ctx, authReq, ErrAuthReqMissingRedirectURI, s.provider.Encoder(), s.provider.Logger()) + } + validation := ValidateAuthRequest + if validater, ok := s.provider.(AuthorizeValidator); ok { + validation = validater.ValidateAuthRequest + } + userID, err := validation(ctx, authReq, s.provider.Storage(), s.provider.IDTokenHintVerifier(ctx)) + if err != nil { + return TryErrorRedirect(ctx, authReq, err, s.provider.Encoder(), s.provider.Logger()) + } + if authReq.RequestParam != "" { + return TryErrorRedirect(ctx, authReq, oidc.ErrRequestNotSupported(), s.provider.Encoder(), s.provider.Logger()) + } + req, err := s.provider.Storage().CreateAuthRequest(ctx, authReq, userID) + if err != nil { + return TryErrorRedirect(ctx, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger()) + } + client, err := s.provider.Storage().GetClientByClientID(ctx, req.GetClientID()) + if err != nil { + return TryErrorRedirect(ctx, authReq, oidc.DefaultToServerError(err, "unable to retrieve client by id"), s.provider.Encoder(), s.provider.Logger()) + } + return NewRedirect(client.LoginURL(req.GetID())), nil +} + +func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider) + if err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + return NewResponse(response), nil } func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { - if !s.op.AuthMethodPrivateKeyJWTSupported() { + jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() { return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported") } - return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, s.op) + return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger) } - client, err := s.op.Storage().GetClientByClientID(ctx, r.Data.ClientID) + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) if err != nil { return nil, oidc.ErrInvalidClient().WithParent(err) } @@ -29,12 +103,12 @@ func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCreden case oidc.AuthMethodPrivateKeyJWT: return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") case oidc.AuthMethodPost: - if !s.op.AuthMethodPostSupported() { + if !s.provider.AuthMethodPostSupported() { return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") } } - err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.op.storage) + err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage()) if err != nil { return nil, err } @@ -43,7 +117,7 @@ func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCreden } func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { - authReq, err := AuthRequestByCode(ctx, s.op.storage, r.Data.Code) + authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code) if err != nil { return nil, err } @@ -52,7 +126,7 @@ func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.A return nil, err } } - resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.op, true, r.Data.Code, "") + resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "") if err != nil { return nil, err } @@ -63,7 +137,7 @@ func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.R if !ValidateGrantType(r.Client, oidc.GrantTypeRefreshToken) { return nil, oidc.ErrUnauthorizedClient() } - request, err := RefreshTokenRequestByRefreshToken(ctx, s.op.storage, r.Data.RefreshToken) + request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken) if err != nil { return nil, err } @@ -73,9 +147,64 @@ func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.R if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil { return nil, err } - resp, err := CreateTokenResponse(ctx, request, r.Client, s.op, true, "", r.Data.RefreshToken) + resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken) if err != nil { return nil, err } return NewResponse(resp), nil } + +func (s *LegacyServer) JWTProfile(_ context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (s *LegacyServer) TokenExchange(_ context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + return nil, unimplementedError(r.Request) +} + +func (s *LegacyServer) ClientCredentialsExchange(_ context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + return nil, unimplementedError(r.Request) +} + +func (s *LegacyServer) DeviceToken(_ context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + return nil, unimplementedError(r.Request) +} + +func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { + response := new(oidc.IntrospectionResponse) + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token) + if !ok { + return NewResponse(response), nil + } + err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID()) + if err != nil { + return NewResponse(response), nil + } + response.Active = true + return NewResponse(response), nil +} + +func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken) + if !ok { + return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized) + } + info := new(oidc.UserInfo) + err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin")) + if err != nil { + return nil, NewStatusError(err, http.StatusForbidden) + } + return NewResponse(info), nil +} + +func (s *LegacyServer) Revocation(_ context.Context, r *Request[oidc.RevocationRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (s *LegacyServer) EndSession(_ context.Context, r *Request[oidc.EndSessionRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (s *LegacyServer) Keys(_ context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +}