input validation and concrete handlers
This commit is contained in:
parent
81d42b061d
commit
aae3492f7b
3 changed files with 334 additions and 92 deletions
|
@ -43,10 +43,21 @@ type Server interface {
|
||||||
// The recommended Response Data type is [jose.JSOMWebKeySet].
|
// The recommended Response Data type is [jose.JSOMWebKeySet].
|
||||||
Keys(context.Context, *Request[struct{}]) (*Response, error)
|
Keys(context.Context, *Request[struct{}]) (*Response, error)
|
||||||
|
|
||||||
|
// VerifyAuthRequest verifies the Auth Request and
|
||||||
|
// adds the Client to the request.
|
||||||
|
//
|
||||||
|
// When the `request` field is populated with a
|
||||||
|
// "Request Object" JWT, it needs to be Validated
|
||||||
|
// and its claims overwrtite any fields in the AuthRequest.
|
||||||
|
// If the implementation does not support "Request Object",
|
||||||
|
// it MUST return an [oidc.ErrRequestNotSupported].
|
||||||
|
// https://openid.net/specs/openid-connect-core-1_0.html#RequestObject
|
||||||
|
VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error)
|
||||||
|
|
||||||
// Authorize initiates the authorization flow and redirects to a login page.
|
// Authorize initiates the authorization flow and redirects to a login page.
|
||||||
// See the various https://openid.net/specs/openid-connect-core-1_0.html
|
// See the various https://openid.net/specs/openid-connect-core-1_0.html
|
||||||
// authorize endpoint sections (one for each type of flow).
|
// authorize endpoint sections (one for each type of flow).
|
||||||
Authorize(context.Context, *Request[oidc.AuthRequest]) (*Redirect, error)
|
Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error)
|
||||||
|
|
||||||
// AuthorizeCallback? Do we still need it?
|
// AuthorizeCallback? Do we still need it?
|
||||||
|
|
||||||
|
@ -259,7 +270,14 @@ func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Res
|
||||||
return nil, unimplementedError(r)
|
return nil, unimplementedError(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UnimplementedServer) Authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (*Redirect, error) {
|
func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
|
||||||
|
if r.Data.RequestParam != "" {
|
||||||
|
return nil, oidc.ErrRequestNotSupported()
|
||||||
|
}
|
||||||
|
return nil, unimplementedError(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) {
|
||||||
return nil, unimplementedError(r)
|
return nil, unimplementedError(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,14 +38,14 @@ func (s *webServer) createRouter(interceptors ...func(http.Handler) http.Handler
|
||||||
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
||||||
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
||||||
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
||||||
router.HandleFunc(s.endpoints.Authorization.Relative(), redirectHandler(s, s.server.Authorize))
|
router.HandleFunc(s.endpoints.Authorization.Relative(), s.authorizeHandler)
|
||||||
|
router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), s.deviceAuthorizationHandler)
|
||||||
router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler)
|
router.HandleFunc(s.endpoints.Token.Relative(), s.tokensHandler)
|
||||||
router.HandleFunc(s.endpoints.Introspection.Relative(), clientRequestHandler(s, s.server.Introspect))
|
router.HandleFunc(s.endpoints.Introspection.Relative(), s.introspectionHandler)
|
||||||
router.HandleFunc(s.endpoints.Userinfo.Relative(), requestHandler(s, s.server.UserInfo))
|
router.HandleFunc(s.endpoints.Userinfo.Relative(), s.userInfoHandler)
|
||||||
router.HandleFunc(s.endpoints.Revocation.Relative(), clientRequestHandler(s, s.server.Revocation))
|
router.HandleFunc(s.endpoints.Revocation.Relative(), s.revokationHandler)
|
||||||
router.HandleFunc(s.endpoints.EndSession.Relative(), redirectHandler(s, s.server.EndSession))
|
router.HandleFunc(s.endpoints.EndSession.Relative(), s.endSessionHandler)
|
||||||
router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys))
|
router.HandleFunc(s.endpoints.JwksURI.Relative(), simpleHandler(s, s.server.Keys))
|
||||||
router.HandleFunc(s.endpoints.DeviceAuthorization.Relative(), clientRequestHandler(s, s.server.DeviceAuthorization))
|
|
||||||
s.Handler = router
|
s.Handler = router
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,10 +76,77 @@ func (s *webServer) verifyRequestClient(r *http.Request) (Client, error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirect, err := s.authorize(r.Context(), newRequest(r, request))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirect.writeOut(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
|
||||||
|
cr, err := s.server.VerifyAuthRequest(ctx, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
authReq := cr.Data
|
||||||
|
if authReq.RedirectURI == "" {
|
||||||
|
return nil, ErrAuthReqMissingRedirectURI
|
||||||
|
}
|
||||||
|
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.server.Authorize(ctx, cr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
client, err := s.verifyRequestClient(r)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, slog.Default())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
grantType := oidc.GrantType(r.Form.Get("grant_type"))
|
grantType := oidc.GrantType(r.Form.Get("grant_type"))
|
||||||
|
if grantType == "" {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default())
|
||||||
|
return
|
||||||
|
}
|
||||||
if grantType == oidc.GrantTypeBearer {
|
if grantType == oidc.GrantTypeBearer {
|
||||||
callRequestMethod(s, w, r, s.server.JWTProfile)
|
s.jwtProfileHandler(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,28 +155,229 @@ func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
WriteError(w, r, err, slog.Default())
|
WriteError(w, r, err, slog.Default())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if !ValidateGrantType(client, grantType) {
|
||||||
|
WriteError(w, r, oidc.ErrUnauthorizedClient(), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch grantType {
|
switch grantType {
|
||||||
case oidc.GrantTypeCode:
|
case oidc.GrantTypeCode:
|
||||||
callClientMethod(s, w, r, client, s.server.CodeExchange)
|
s.codeExchangeHandler(w, r, client)
|
||||||
case oidc.GrantTypeRefreshToken:
|
case oidc.GrantTypeRefreshToken:
|
||||||
callClientMethod(s, w, r, client, s.server.RefreshToken)
|
s.refreshTokenHandler(w, r, client)
|
||||||
case oidc.GrantTypeTokenExchange:
|
case oidc.GrantTypeTokenExchange:
|
||||||
callClientMethod(s, w, r, client, s.server.TokenExchange)
|
s.tokenExchangeHandler(w, r, client)
|
||||||
case oidc.GrantTypeClientCredentials:
|
case oidc.GrantTypeClientCredentials:
|
||||||
callClientMethod(s, w, r, client, s.server.ClientCredentialsExchange)
|
s.clientCredentialsHandler(w, r, client)
|
||||||
case oidc.GrantTypeDeviceCode:
|
case oidc.GrantTypeDeviceCode:
|
||||||
callClientMethod(s, w, r, client, s.server.DeviceToken)
|
s.deviceTokenHandler(w, r, client)
|
||||||
case "":
|
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), slog.Default())
|
|
||||||
default:
|
default:
|
||||||
WriteError(w, r, unimplementedGrantError(grantType), slog.Default())
|
WriteError(w, r, unimplementedGrantError(grantType), s.logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type requestMethod[T any] func(context.Context, *Request[T]) (*Response, error)
|
func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.Assertion == "" {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
func simpleHandler(s *webServer, method requestMethod[struct{}]) http.HandlerFunc {
|
func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
|
request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.Code == "" {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.RedirectURI == "" {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
|
request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.RefreshToken == "" {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
|
request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.SubjectToken == "" {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.SubjectTokenType == "" {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !request.SubjectTokenType.IsSupported() {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
|
request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: is a public client allowed here?
|
||||||
|
|
||||||
|
resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||||
|
request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.DeviceCode == "" {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
client, err := s.verifyRequestClient(r)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, slog.Default())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.Token == "" {
|
||||||
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
|
||||||
|
if err != nil || request.AccessToken == "" {
|
||||||
|
err = AsStatusError(
|
||||||
|
oidc.ErrInvalidRequest().WithParent(err).WithDescription("access token missing"),
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
)
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) revokationHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
client, err := s.verifyRequestClient(r)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, slog.Default())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
|
||||||
|
if err != nil {
|
||||||
|
WriteError(w, r, err, s.logger)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.writeOut(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger)
|
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.logger)
|
||||||
|
@ -124,72 +392,8 @@ func simpleHandler(s *webServer, method requestMethod[struct{}]) http.HandlerFun
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestHandler[T any](s *webServer, method requestMethod[T]) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callRequestMethod(s, w, r, method)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func callRequestMethod[T any](s *webServer, w http.ResponseWriter, r *http.Request, method requestMethod[T]) {
|
|
||||||
request, err := decodeRequest[T](s.decoder, r, false)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, r, err, s.logger)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp, err := method(r.Context(), newRequest[T](r, request))
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, r, err, s.logger)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp.writeOut(w)
|
|
||||||
}
|
|
||||||
|
|
||||||
type redirectMethod[T any] func(context.Context, *Request[T]) (*Redirect, error)
|
|
||||||
|
|
||||||
func redirectHandler[T any](s *webServer, method redirectMethod[T]) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
req, err := decodeRequest[T](s.decoder, r, false)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, r, err, s.logger)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
redirect, err := method(r.Context(), newRequest(r, req))
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, r, err, s.logger)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
redirect.writeOut(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientMethod[T any] func(context.Context, *ClientRequest[T]) (*Response, error)
|
|
||||||
|
|
||||||
func clientRequestHandler[T any](s *webServer, method clientMethod[T]) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
client, err := s.verifyRequestClient(r)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, r, err, slog.Default())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
callClientMethod(s, w, r, client, method)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func callClientMethod[T any](s *webServer, w http.ResponseWriter, r *http.Request, client Client, method clientMethod[T]) {
|
|
||||||
request, err := decodeRequest[T](s.decoder, r, false)
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, r, err, s.logger)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp, err := method(r.Context(), newClientRequest[T](r, request, client))
|
|
||||||
if err != nil {
|
|
||||||
WriteError(w, r, err, s.logger)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp.writeOut(w)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
|
func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
|
||||||
|
dst := new(R)
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||||
}
|
}
|
||||||
|
@ -197,9 +401,8 @@ func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly
|
||||||
if postOnly {
|
if postOnly {
|
||||||
form = r.PostForm
|
form = r.PostForm
|
||||||
}
|
}
|
||||||
request := new(R)
|
if err := decoder.Decode(dst, form); err != nil {
|
||||||
if err := decoder.Decode(request, form); err != nil {
|
|
||||||
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
|
||||||
}
|
}
|
||||||
return request, nil
|
return dst, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,8 +49,32 @@ var (
|
||||||
ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri")
|
ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *LegacyServer) Authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
|
func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
|
||||||
userID, err := ValidateAuthRequestV2(ctx, r.Data, s.provider)
|
if r.Data.RequestParam != "" {
|
||||||
|
if !s.provider.RequestObjectSupported() {
|
||||||
|
return nil, oidc.ErrRequestNotSupported()
|
||||||
|
}
|
||||||
|
err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Data.ClientID == "" {
|
||||||
|
return nil, ErrAuthReqMissingClientID
|
||||||
|
}
|
||||||
|
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ClientRequest[oidc.AuthRequest]{
|
||||||
|
Request: r,
|
||||||
|
Client: client,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) {
|
||||||
|
userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -126,9 +150,6 @@ func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.R
|
||||||
if !s.provider.GrantTypeRefreshTokenSupported() {
|
if !s.provider.GrantTypeRefreshTokenSupported() {
|
||||||
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
|
return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
|
||||||
}
|
}
|
||||||
if !ValidateGrantType(r.Client, oidc.GrantTypeRefreshToken) {
|
|
||||||
return nil, oidc.ErrUnauthorizedClient()
|
|
||||||
}
|
|
||||||
request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken)
|
request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue