simplify KeyProvider interface

This commit is contained in:
Livio Amstutz 2021-06-30 14:10:38 +02:00
parent 0b446618c7
commit 58e27e8073
3 changed files with 9 additions and 9 deletions

View file

@ -1,13 +1,16 @@
package op package op
import ( import (
"context"
"net/http" "net/http"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
type KeyProvider interface { type KeyProvider interface {
Storage() Storage GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
} }
func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) { func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
@ -17,7 +20,7 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
} }
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
keySet, err := k.Storage().GetKeySet(r.Context()) keySet, err := k.GetKeySet(r.Context())
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
utils.MarshalJSON(w, err) utils.MarshalJSON(w, err)

View file

@ -74,7 +74,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o))) router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o)))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o)) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
return router return router
} }
@ -281,7 +281,7 @@ func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSig
if !ok { if !ok {
return nil, errors.New("invalid kid") return nil, errors.New("invalid kid")
} }
return jws.Verify(key) return jws.Verify(&key)
} }
type Option func(o *openidProvider) error type Option func(o *openidProvider) error

View file

@ -122,13 +122,10 @@ type jwtProfileKeySet struct {
//VerifySignature implements oidc.KeySet by getting the public key from Storage implementation //VerifySignature implements oidc.KeySet by getting the public key from Storage implementation
func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
keyID, alg := oidc.GetKeyIDAndAlg(jws) keyID, _ := oidc.GetKeyIDAndAlg(jws)
key, err := k.storage.GetKeyByIDAndUserID(ctx, keyID, k.userID) key, err := k.storage.GetKeyByIDAndUserID(ctx, keyID, k.userID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error fetching keys: %w", err) return nil, fmt.Errorf("error fetching keys: %w", err)
} }
if key.Algorithm != alg { return jws.Verify(key)
}
return jws.Verify(&key)
} }