simplify KeyProvider interface
This commit is contained in:
parent
0b446618c7
commit
58e27e8073
3 changed files with 9 additions and 9 deletions
|
@ -1,13 +1,16 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/utils"
|
"github.com/caos/oidc/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyProvider interface {
|
type KeyProvider interface {
|
||||||
Storage() Storage
|
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
|
func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
|
||||||
|
@ -17,7 +20,7 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
|
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
|
||||||
keySet, err := k.Storage().GetKeySet(r.Context())
|
keySet, err := k.GetKeySet(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
utils.MarshalJSON(w, err)
|
utils.MarshalJSON(w, err)
|
||||||
|
|
|
@ -74,7 +74,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
|
||||||
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
|
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
|
||||||
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
|
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
|
||||||
router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o)))
|
router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o)))
|
||||||
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o))
|
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -281,7 +281,7 @@ func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSig
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("invalid kid")
|
return nil, errors.New("invalid kid")
|
||||||
}
|
}
|
||||||
return jws.Verify(key)
|
return jws.Verify(&key)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Option func(o *openidProvider) error
|
type Option func(o *openidProvider) error
|
||||||
|
|
|
@ -122,13 +122,10 @@ type jwtProfileKeySet struct {
|
||||||
|
|
||||||
//VerifySignature implements oidc.KeySet by getting the public key from Storage implementation
|
//VerifySignature implements oidc.KeySet by getting the public key from Storage implementation
|
||||||
func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
|
func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
|
||||||
keyID, alg := oidc.GetKeyIDAndAlg(jws)
|
keyID, _ := oidc.GetKeyIDAndAlg(jws)
|
||||||
key, err := k.storage.GetKeyByIDAndUserID(ctx, keyID, k.userID)
|
key, err := k.storage.GetKeyByIDAndUserID(ctx, keyID, k.userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error fetching keys: %w", err)
|
return nil, fmt.Errorf("error fetching keys: %w", err)
|
||||||
}
|
}
|
||||||
if key.Algorithm != alg {
|
return jws.Verify(key)
|
||||||
|
|
||||||
}
|
|
||||||
return jws.Verify(&key)
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue