Merge branch 'main' into fixClientAssertion

This commit is contained in:
David Sharnoff 2024-08-22 16:04:42 -07:00 committed by GitHub
commit b8f584429d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 57 additions and 28 deletions

View file

@ -56,6 +56,7 @@ func main() {
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
rp.WithHTTPClient(client),
rp.WithLogger(logger),
rp.WithSigningAlgsFromDiscovery(),
}
if clientSecret == "" {
options = append(options, rp.WithPKCE(cookieHandler))

View file

@ -21,6 +21,14 @@ func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
return sha512.New384(), nil
case jose.RS512, jose.ES512, jose.PS512:
return sha512.New(), nil
// There is no published spec for this yet, but we have confirmation it will get published.
// There is consensus here: https://bitbucket.org/openid/connect/issues/1125/_hash-algorithm-for-eddsa-id-tokens
// Currently Go and go-jose only supports the ed25519 curve key for EdDSA, so we can safely assume sha512 here.
// It is unlikely ed448 will ever be supported: https://github.com/golang/go/issues/29390
case jose.EdDSA:
return sha512.New(), nil
default:
return nil, fmt.Errorf("%w: %q", ErrUnsupportedAlgorithm, sigAlgorithm)
}

View file

@ -6,6 +6,7 @@ import (
"crypto/ed25519"
"crypto/rsa"
"errors"
"strings"
jose "github.com/go-jose/go-jose/v4"
)
@ -92,17 +93,17 @@ func FindMatchingKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (k
}
func algToKeyType(key any, alg string) bool {
switch alg[0] {
case 'R', 'P':
if strings.HasPrefix(alg, "RS") || strings.HasPrefix(alg, "PS") {
_, ok := key.(*rsa.PublicKey)
return ok
case 'E':
}
if strings.HasPrefix(alg, "ES") {
_, ok := key.(*ecdsa.PublicKey)
return ok
case 'O':
_, ok := key.(*ed25519.PublicKey)
return ok
default:
return false
}
if alg == string(jose.EdDSA) {
_, ok := key.(ed25519.PublicKey)
return ok
}
return false
}

View file

@ -83,19 +83,27 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer)
AuthRequestError(w, r, nil, err, authorizer)
return
}
}
if authReq.ClientID == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer)
AuthRequestError(w, r, nil, fmt.Errorf("auth request is missing client_id"), authorizer)
return
}
if authReq.RedirectURI == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer)
AuthRequestError(w, r, nil, fmt.Errorf("auth request is missing redirect_uri"), authorizer)
return
}
validation := ValidateAuthRequest
var client Client
validation := func(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
client, err = authorizer.Storage().GetClientByClientID(ctx, authReq.ClientID)
if err != nil {
return "", oidc.ErrInvalidRequestRedirectURI().WithDescription("unable to retrieve client by id").WithParent(err)
}
return ValidateAuthRequestClient(ctx, authReq, client, verifier)
}
if validater, ok := authorizer.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest
}
@ -113,11 +121,6 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer)
return
}
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer)
return
}
RedirectToLogin(req.GetID(), client, w, r)
}
@ -212,26 +215,37 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi
authReq.RequestParam = ""
}
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed.
//
// Deprecated: Use [ValidateAuthRequestClient] to prevent querying for the Client twice.
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
ctx, span := tracer.Start(ctx, "ValidateAuthRequest")
defer span.End()
client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
if err != nil {
return "", oidc.ErrInvalidRequestRedirectURI().WithDescription("unable to retrieve client by id").WithParent(err)
}
return ValidateAuthRequestClient(ctx, authReq, client, verifier)
}
// ValidateAuthRequestClient validates the Auth request against the passed client.
// If id_token_hint is part of the request, the subject of the token is returned.
func ValidateAuthRequestClient(ctx context.Context, authReq *oidc.AuthRequest, client Client, verifier *IDTokenHintVerifier) (sub string, err error) {
ctx, span := tracer.Start(ctx, "ValidateAuthRequestClient")
defer span.End()
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
return "", err
}
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil {
return "", err
}
client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
if err != nil {
return "", oidc.DefaultToServerError(err, "unable to retrieve client by id")
}
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil {
return "", err
}
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
return "", err
}
if err := ValidateAuthReqResponseType(client, authReq.ResponseType); err != nil {
return "", err
}

View file

@ -428,7 +428,8 @@ func TestTryErrorRedirect(t *testing.T) {
parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
},
want: &Redirect{
URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1",
Header: make(http.Header),
URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1",
},
wantLog: `{
"level":"WARN",

View file

@ -218,7 +218,8 @@ type Response struct {
// without custom headers.
func NewResponse(data any) *Response {
return &Response{
Data: data,
Header: make(http.Header),
Data: data,
}
}
@ -242,7 +243,10 @@ type Redirect struct {
}
func NewRedirect(url string) *Redirect {
return &Redirect{URL: url}
return &Redirect{
Header: make(http.Header),
URL: url,
}
}
func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) {