diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index c5797db..96f9b45 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -116,7 +116,7 @@ func (s *AuthStorage) Health(ctx context.Context) error { return nil } -func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) { +func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest, _ string) (op.AuthRequest, error) { a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI} if authReq.CodeChallenge != "" { a.CodeChallenge = &oidc.CodeChallenge{ diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index 0e13df2..e01de51 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -11,6 +11,7 @@ import ( "github.com/gorilla/schema" "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" "github.com/caos/oidc/pkg/utils" ) @@ -19,13 +20,14 @@ type Authorizer interface { Decoder() *schema.Decoder Encoder() *schema.Encoder Signer() Signer + IDTokenVerifier() rp.Verifier Crypto() Crypto Issuer() string } type ValidationAuthorizer interface { Authorizer - ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error + ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, rp.Verifier) (string, error) } func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { @@ -44,11 +46,12 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { if validater, ok := authorizer.(ValidationAuthorizer); ok { validation = validater.ValidateAuthRequest } - if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil { + userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenVerifier()) + if err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } - req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq) + req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID) if err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return @@ -61,23 +64,17 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { RedirectToLogin(req.GetID(), client, w, r) } -func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error { +func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier rp.Verifier) (string, error) { if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { - return err + return "", err } if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil { - return err + return "", err } if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil { - return err + return "", err } - // if NeedsExistingSession(authReq) { - // session, err := storage.CheckSession(authReq.IDTokenHint) - // if err != nil { - // return err - // } - // } - return nil + return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier) } func ValidateAuthReqScopes(scopes []string) error { @@ -130,6 +127,17 @@ func ValidateAuthReqResponseType(responseType oidc.ResponseType) error { return nil } +func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier rp.Verifier) (string, error) { + if idTokenHint == "" { + return "", nil + } + claims, err := verifier.Verify(ctx, "", idTokenHint) + if err != nil { + return "", ErrInvalidRequest("id_token_hint invalid") + } + return claims.Subject, nil +} + func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) { login := client.LoginURL(authReqID) http.Redirect(w, r, login, http.StatusFound) diff --git a/pkg/op/storage.go b/pkg/op/storage.go index f213618..4655b88 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -10,7 +10,7 @@ import ( ) type AuthStorage interface { - CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error) + CreateAuthRequest(context.Context, *oidc.AuthRequest, string) (AuthRequest, error) AuthRequestByID(context.Context, string) (AuthRequest, error) DeleteAuthRequest(context.Context, string) error