diff --git a/pkg/op/server.go b/pkg/op/server.go index 0862f32..a836703 100644 --- a/pkg/op/server.go +++ b/pkg/op/server.go @@ -43,10 +43,21 @@ type Server interface { // The recommended Response Data type is [jose.JSOMWebKeySet]. Keys(context.Context, *Request[struct{}]) (*Response, error) + // VerifyAuthRequest verifies the Auth Request and + // adds the Client to the request. + // + // When the `request` field is populated with a + // "Request Object" JWT, it needs to be Validated + // and its claims overwrtite any fields in the AuthRequest. + // If the implementation does not support "Request Object", + // it MUST return an [oidc.ErrRequestNotSupported]. + // https://openid.net/specs/openid-connect-core-1_0.html#RequestObject + VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) + // Authorize initiates the authorization flow and redirects to a login page. // See the various https://openid.net/specs/openid-connect-core-1_0.html // authorize endpoint sections (one for each type of flow). - Authorize(context.Context, *Request[oidc.AuthRequest]) (*Redirect, error) + Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error) // AuthorizeCallback? Do we still need it? @@ -259,7 +270,14 @@ func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Res return nil, unimplementedError(r) } -func (UnimplementedServer) Authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (*Redirect, error) { +func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if r.Data.RequestParam != "" { + return nil, oidc.ErrRequestNotSupported() + } + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) { return nil, unimplementedError(r) } diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index dedbfe5..cfd39ca 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -38,14 +38,14 @@ func (s *webServer) createRouter(interceptors ...func(http.Handler) http.Handler router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) - router.HandleFunc(s.endpoints.Authorization.Relative(), redirectHandler(s, s.server.Authorize)) + router.HandleFunc(s.endpoints.Authorization.Relative(), s.authorizeHandler) + router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.deviceAuthorizationHandler) router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler) - router.HandleFunc(s.endpoints.Introspection.Relative(), clientRequestHandler(s, s.server.Introspect)) - router.HandleFunc(s.endpoints.Userinfo.Relative(), requestHandler(s, s.server.UserInfo)) - router.HandleFunc(s.endpoints.Revocation.Relative(), clientRequestHandler(s, s.server.Revocation)) - router.HandleFunc(s.endpoints.EndSession.Relative(), redirectHandler(s, s.server.EndSession)) + router.HandleFunc(s.endpoints.Introspection.Relative(), s.introspectionHandler) + router.HandleFunc(s.endpoints.Userinfo.Relative(), s.userInfoHandler) + router.HandleFunc(s.endpoints.Revocation.Relative(), s.revokationHandler) + router.HandleFunc(s.endpoints.EndSession.Relative(), s.endSessionHandler) router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys)) - router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), clientRequestHandler(s, s.server.DeviceAuthorization)) s.Handler = router } @@ -76,10 +76,77 @@ func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) { }) } +func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + redirect, err := s.authorize(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + redirect.writeOut(w, r) +} + +func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) { + cr, err := s.server.VerifyAuthRequest(ctx, r) + if err != nil { + return nil, err + } + authReq := cr.Data + if authReq.RedirectURI == "" { + return nil, ErrAuthReqMissingRedirectURI + } + authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) + if err != nil { + return nil, err + } + authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes) + if err != nil { + return nil, err + } + if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil { + return nil, err + } + if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil { + return nil, err + } + return s.server.Authorize(ctx, cr) +} + +func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request) { + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, slog.Default()) + return + } + request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger) + return + } grantType := oidc.GrantType(r.Form.Get("grant_type")) + if grantType == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default()) + return + } if grantType == oidc.GrantTypeBearer { - callRequestMethod(s, w, r, s.server.JWTProfile) + s.jwtProfileHandler(w, r) return } @@ -88,28 +155,229 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { WriteError(w, r, err, slog.Default()) return } + if !ValidateGrantType(client, grantType) { + WriteError(w, r, oidc.ErrUnauthorizedClient(), s.logger) + return + } switch grantType { case oidc.GrantTypeCode: - callClientMethod(s, w, r, client, s.server.CodeExchange) + s.codeExchangeHandler(w, r, client) case oidc.GrantTypeRefreshToken: - callClientMethod(s, w, r, client, s.server.RefreshToken) + s.refreshTokenHandler(w, r, client) case oidc.GrantTypeTokenExchange: - callClientMethod(s, w, r, client, s.server.TokenExchange) + s.tokenExchangeHandler(w, r, client) case oidc.GrantTypeClientCredentials: - callClientMethod(s, w, r, client, s.server.ClientCredentialsExchange) + s.clientCredentialsHandler(w, r, client) case oidc.GrantTypeDeviceCode: - callClientMethod(s, w, r, client, s.server.DeviceToken) - case "": - WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default()) + s.deviceTokenHandler(w, r, client) default: - WriteError(w, r, unimplementedGrantError(grantType), slog.Default()) + WriteError(w, r, unimplementedGrantError(grantType), s.logger) } } -type requestMethod[T any] func(context.Context, *Request[T]) (*Response, error) +func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + if request.Assertion == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.logger) + return + } + resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} -func simpleHandler(s *webServer, method requestMethod[struct{}]) http.HandlerFunc { +func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + if request.Code == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.logger) + return + } + if request.RedirectURI == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.logger) + return + } + resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + if request.RefreshToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.logger) + return + } + resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + if request.SubjectToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) + return + } + if request.SubjectTokenType == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) + return + } + if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) + return + } + if !request.SubjectTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) + return + } + if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger) + return + } + resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + + // TODO: is a public client allowed here? + + resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + if request.DeviceCode == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.logger) + return + } + resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request) { + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, slog.Default()) + return + } + request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger) + return + } + resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false) + if err != nil || request.AccessToken == "" { + err = AsStatusError( + oidc.ErrInvalidRequest().WithParent(err).WithDescription("access token missing"), + http.StatusUnauthorized, + ) + WriteError(w, r, err, s.logger) + return + } + resp, err := s.server.UserInfo(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func (s *webServer) revokationHandler(w http.ResponseWriter, r *http.Request) { + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, slog.Default()) + return + } + request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w) +} + +func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp, err := s.server.EndSession(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.logger) + return + } + resp.writeOut(w, r) +} + +func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger) @@ -124,72 +392,8 @@ func simpleHandler(s *webServer, method requestMethod[struct{}]) http.HandlerFun } } -func requestHandler[T any](s *webServer, method requestMethod[T]) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - callRequestMethod(s, w, r, method) - } -} - -func callRequestMethod[T any](s *webServer, w http.ResponseWriter, r *http.Request, method requestMethod[T]) { - request, err := decodeRequest[T](s.decoder, r, false) - if err != nil { - WriteError(w, r, err, s.logger) - return - } - resp, err := method(r.Context(), newRequest[T](r, request)) - if err != nil { - WriteError(w, r, err, s.logger) - return - } - resp.writeOut(w) -} - -type redirectMethod[T any] func(context.Context, *Request[T]) (*Redirect, error) - -func redirectHandler[T any](s *webServer, method redirectMethod[T]) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - req, err := decodeRequest[T](s.decoder, r, false) - if err != nil { - WriteError(w, r, err, s.logger) - return - } - redirect, err := method(r.Context(), newRequest(r, req)) - if err != nil { - WriteError(w, r, err, s.logger) - return - } - redirect.writeOut(w, r) - } -} - -type clientMethod[T any] func(context.Context, *ClientRequest[T]) (*Response, error) - -func clientRequestHandler[T any](s *webServer, method clientMethod[T]) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - client, err := s.verifyRequestClient(r) - if err != nil { - WriteError(w, r, err, slog.Default()) - return - } - callClientMethod(s, w, r, client, method) - } -} - -func callClientMethod[T any](s *webServer, w http.ResponseWriter, r *http.Request, client Client, method clientMethod[T]) { - request, err := decodeRequest[T](s.decoder, r, false) - if err != nil { - WriteError(w, r, err, s.logger) - return - } - resp, err := method(r.Context(), newClientRequest[T](r, request, client)) - if err != nil { - WriteError(w, r, err, s.logger) - return - } - resp.writeOut(w) -} - func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) { + dst := new(R) if err := r.ParseForm(); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) } @@ -197,9 +401,8 @@ func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly if postOnly { form = r.PostForm } - request := new(R) - if err := decoder.Decode(request, form); err != nil { + if err := decoder.Decode(dst, form); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) } - return request, nil + return dst, nil } diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index 03caf33..6e6fa53 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -49,8 +49,32 @@ var ( ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri") ) -func (s *LegacyServer) Authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) { - userID, err := ValidateAuthRequestV2(ctx, r.Data, s.provider) +func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if r.Data.RequestParam != "" { + if !s.provider.RequestObjectSupported() { + return nil, oidc.ErrRequestNotSupported() + } + err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx)) + if err != nil { + return nil, err + } + } + if r.Data.ClientID == "" { + return nil, ErrAuthReqMissingClientID + } + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id") + } + + return &ClientRequest[oidc.AuthRequest]{ + Request: r, + Client: client, + }, nil +} + +func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) { + userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx)) if err != nil { return nil, err } @@ -126,9 +150,6 @@ func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.R if !s.provider.GrantTypeRefreshTokenSupported() { return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) } - if !ValidateGrantType(r.Client, oidc.GrantTypeRefreshToken) { - return nil, oidc.ErrUnauthorizedClient() - } request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken) if err != nil { return nil, err