feat(op): dynamic issuer depending on request / host

BREAKING CHANGE: The OpenID Provider package is now able to handle multiple issuers with a single storage implementation. The issuer will be selected from the host of the request and passed into the context, where every function can read it from if necessary. This results in some fundamental changes:
 - `Configuration` interface:
   - `Issuer() string` has been changed to `IssuerFromRequest(r *http.Request) string`
   - `Insecure() bool` has been added
 - OpenIDProvider interface and dependants:
   - `Issuer` has been removed from Config struct
   - `NewOpenIDProvider` now takes an additional parameter `issuer` and returns a pointer to the public/default implementation and not an OpenIDProvider interface:
     `NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opOpts ...Option) (OpenIDProvider, error)` changed to `NewOpenIDProvider(ctx context.Context, issuer string, config *Config, storage Storage, opOpts ...Option) (*Provider, error)`
   - therefore the parameter type Option changed to the public type as well: `Option func(o *Provider) error`
   - `AuthCallbackURL(o OpenIDProvider) func(string) string` has been changed to `AuthCallbackURL(o OpenIDProvider) func(context.Context, string) string`
   - `IDTokenHintVerifier() IDTokenHintVerifier` (Authorizer, OpenIDProvider, SessionEnder interfaces), `AccessTokenVerifier() AccessTokenVerifier` (Introspector, OpenIDProvider, Revoker, UserinfoProvider interfaces) and `JWTProfileVerifier() JWTProfileVerifier` (IntrospectorJWTProfile, JWTAuthorizationGrantExchanger, OpenIDProvider, RevokerJWTProfile interfaces) now take a context.Context parameter `IDTokenHintVerifier(context.Context) IDTokenHintVerifier`, `AccessTokenVerifier(context.Context) AccessTokenVerifier` and `JWTProfileVerifier(context.Context) JWTProfileVerifier`
   - `OidcDevMode` (CAOS_OIDC_DEV) environment variable check has been removed, use `WithAllowInsecure()` Option
 - Signing: the signer is not kept in memory anymore, but created on request from the loaded key:
   - `Signer` interface and func `NewSigner` have been removed
   - `ReadySigner(s Signer) ProbesFn` has been removed
   - `CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration` has been changed to `CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration`
   - `Storage` interface:
     - `GetSigningKey(context.Context, chan<- jose.SigningKey)` has been changed to `SigningKey(context.Context) (SigningKey, error)`
     - `KeySet(context.Context) ([]Key, error)` has been added
     - `GetKeySet(context.Context) (*jose.JSONWebKeySet, error)` has been changed to `KeySet(context.Context) ([]Key, error)`
   - `SigAlgorithms(s Signer) []string` has been changed to `SigAlgorithms(ctx context.Context, storage DiscoverStorage) []string`
   - KeyProvider interface: `GetKeySet(context.Context) (*jose.JSONWebKeySet, error)` has been changed to `KeySet(context.Context) ([]Key, error)`
   - `CreateIDToken`: the Signer parameter has been removed
This commit is contained in:
Livio Amstutz 2022-04-19 14:33:07 +02:00
parent 885fe0d45c
commit a27ba09872
No known key found for this signature in database
GPG key ID: 7AB5FDFBCA448635
33 changed files with 1504 additions and 657 deletions

View file

@ -43,9 +43,41 @@ type storage struct {
} }
type signingKey struct { type signingKey struct {
ID string id string
Algorithm string algorithm jose.SignatureAlgorithm
Key *rsa.PrivateKey key *rsa.PrivateKey
}
func (s *signingKey) SignatureAlgorithm() jose.SignatureAlgorithm {
return s.algorithm
}
func (s *signingKey) Key() interface{} {
return s.key
}
func (s *signingKey) ID() string {
return s.id
}
type publicKey struct {
signingKey
}
func (s *publicKey) ID() string {
return s.id
}
func (s *publicKey) Algorithm() jose.SignatureAlgorithm {
return s.algorithm
}
func (s *publicKey) Use() string {
return "sig"
}
func (s *publicKey) Key() interface{} {
return &s.key.PublicKey
} }
func NewStorage() *storage { func NewStorage() *storage {
@ -78,9 +110,9 @@ func NewStorage() *storage {
}, },
}, },
signingKey: signingKey{ signingKey: signingKey{
ID: "id", id: uuid.NewString(),
Algorithm: "RS256", algorithm: jose.RS256,
Key: key, key: key,
}, },
} }
} }
@ -113,6 +145,8 @@ func (s *storage) CheckUsernamePassword(username, password, id string) error {
//CreateAuthRequest implements the op.Storage interface //CreateAuthRequest implements the op.Storage interface
//it will be called after parsing and validation of the authentication request //it will be called after parsing and validation of the authentication request
func (s *storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) { func (s *storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
headers := op.IssuerFromContext(ctx)
_ = headers
//typically, you'll fill your internal / storage model with the information of the passed object //typically, you'll fill your internal / storage model with the information of the passed object
request := authRequestToInternal(authReq, userID) request := authRequestToInternal(authReq, userID)
@ -278,39 +312,29 @@ func (s *storage) RevokeToken(ctx context.Context, token string, userID string,
return nil return nil
} }
//GetSigningKey implements the op.Storage interface //SigningKey implements the op.Storage interface
//it will be called when creating the OpenID Provider //it will be called when creating the OpenID Provider
func (s *storage) GetSigningKey(ctx context.Context, keyCh chan<- jose.SigningKey) { func (s *storage) SigningKey(ctx context.Context) (op.SigningKey, error) {
//in this example the signing key is a static rsa.PrivateKey and the algorithm used is RS256 //in this example the signing key is a static rsa.PrivateKey and the algorithm used is RS256
//you would obviously have a more complex implementation and store / retrieve the key from your database as well //you would obviously have a more complex implementation and store / retrieve the key from your database as well
// return &s.signingKey, nil
//the idea of the signing key channel is, that you can (with what ever mechanism) rotate your signing key and
//switch the key of the signer via this channel
keyCh <- jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(s.signingKey.Algorithm), //always tell the signer with algorithm to use
Key: jose.JSONWebKey{
KeyID: s.signingKey.ID, //always give the key an id so, that it will include it in the token header as `kid` claim
Key: s.signingKey.Key,
},
}
} }
//GetKeySet implements the op.Storage interface //SignatureAlgorithms implements the op.Storage interface
//it will be called to get the sign
func (s *storage) SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error) {
return []jose.SignatureAlgorithm{s.signingKey.algorithm}, nil
}
//KeySet implements the op.Storage interface
//it will be called to get the current (public) keys, among others for the keys_endpoint or for validating access_tokens on the userinfo_endpoint, ... //it will be called to get the current (public) keys, among others for the keys_endpoint or for validating access_tokens on the userinfo_endpoint, ...
func (s *storage) GetKeySet(ctx context.Context) (*jose.JSONWebKeySet, error) { func (s *storage) KeySet(ctx context.Context) ([]op.Key, error) {
//as mentioned above, this example only has a single signing key without key rotation, //as mentioned above, this example only has a single signing key without key rotation,
//so it will directly use its public key //so it will directly use its public key
// //
//when using key rotation you typically would store the public keys alongside the private keys in your database //when using key rotation you typically would store the public keys alongside the private keys in your database
//and give both of them an expiration date, with the public key having a longer lifetime (e.g. rotate private key every //and give both of them an expiration date, with the public key having a longer lifetime
return &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ return []op.Key{&publicKey{s.signingKey}}, nil
{
KeyID: s.signingKey.ID,
Algorithm: s.signingKey.Algorithm,
Use: oidc.KeyUseSignature,
Key: &s.signingKey.Key.PublicKey,
}},
}, nil
} }
//GetClientByClientID implements the op.Storage interface //GetClientByClientID implements the op.Storage interface

View file

@ -1,11 +1,14 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/caos/oidc/pkg/op"
) )
const ( const (
@ -46,22 +49,22 @@ var (
type login struct { type login struct {
authenticate authenticate authenticate authenticate
router *mux.Router router *mux.Router
callback func(string) string callback func(context.Context, string) string
} }
func NewLogin(authenticate authenticate, callback func(string) string) *login { func NewLogin(authenticate authenticate, callback func(context.Context, string) string, issuerInterceptor *op.IssuerInterceptor) *login {
l := &login{ l := &login{
authenticate: authenticate, authenticate: authenticate,
callback: callback, callback: callback,
} }
l.createRouter() l.createRouter(issuerInterceptor)
return l return l
} }
func (l *login) createRouter() { func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) {
l.router = mux.NewRouter() l.router = mux.NewRouter()
l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler)
l.router.Path("/username").Methods("POST").HandlerFunc(l.checkLoginHandler) l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler))
} }
type authenticate interface { type authenticate interface {
@ -111,5 +114,5 @@ func (l *login) checkLoginHandler(w http.ResponseWriter, r *http.Request) {
renderLogin(w, id, err) renderLogin(w, id, err)
return return
} }
http.Redirect(w, r, l.callback(id), http.StatusFound) http.Redirect(w, r, l.callback(r.Context(), id), http.StatusFound)
} }

View file

@ -3,10 +3,8 @@ package main
import ( import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"fmt"
"log" "log"
"net/http" "net/http"
"os"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"golang.org/x/text/language" "golang.org/x/text/language"
@ -30,9 +28,6 @@ func init() {
func main() { func main() {
ctx := context.Background() ctx := context.Background()
//this will allow us to use an issuer with http:// instead of https://
os.Setenv(op.OidcDevMode, "true")
port := "9998" port := "9998"
//the OpenID Provider requires a 32-byte key for (token) encryption //the OpenID Provider requires a 32-byte key for (token) encryption
@ -62,7 +57,7 @@ func main() {
//the provider will only take care of the OpenID Protocol, so there must be some sort of UI for the login process //the provider will only take care of the OpenID Protocol, so there must be some sort of UI for the login process
//for the simplicity of the example this means a simple page with username and password field //for the simplicity of the example this means a simple page with username and password field
l := NewLogin(storage, op.AuthCallbackURL(provider)) l := NewLogin(storage, op.AuthCallbackURL(provider), op.NewIssuerInterceptor(provider.IssuerFromRequest))
//regardless of how many pages / steps there are in the process, the UI must be registered in the router, //regardless of how many pages / steps there are in the process, the UI must be registered in the router,
//so we will direct all calls to /login to the login UI //so we will direct all calls to /login to the login UI
@ -72,7 +67,8 @@ func main() {
//is served on the correct path //is served on the correct path
// //
//if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), //if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
//then you would have to set the path prefix (/custom/path/) //then you would have to set the path prefix (/custom/path/):
//router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler()))
router.PathPrefix("/").Handler(provider.HttpHandler()) router.PathPrefix("/").Handler(provider.HttpHandler())
server := &http.Server{ server := &http.Server{
@ -89,9 +85,8 @@ func main() {
//newOP will create an OpenID Provider for localhost on a specified port with a given encryption key //newOP will create an OpenID Provider for localhost on a specified port with a given encryption key
//and a predefined default logout uri //and a predefined default logout uri
//it will enable all options (see descriptions) //it will enable all options (see descriptions)
func newOP(ctx context.Context, storage op.Storage, port string, key [32]byte) (op.OpenIDProvider, error) { func newOP(ctx context.Context, storage op.Storage, port string, key [32]byte) (*op.Provider, error) {
config := &op.Config{ config := &op.Config{
Issuer: fmt.Sprintf("http://localhost:%s/", port),
CryptoKey: key, CryptoKey: key,
//will be used if the end_session endpoint is called without a post_logout_redirect_uri //will be used if the end_session endpoint is called without a post_logout_redirect_uri
@ -115,7 +110,10 @@ func newOP(ctx context.Context, storage op.Storage, port string, key [32]byte) (
//this example has only static texts (in English), so we'll set the here accordingly //this example has only static texts (in English), so we'll set the here accordingly
SupportedUILocales: []language.Tag{language.English}, SupportedUILocales: []language.Tag{language.English},
} }
handler, err := op.NewOpenIDProvider(ctx, config, storage, //handler, err := op.NewOpenIDProvider(ctx, fmt.Sprintf("http://localhost:%s/", port), config, storage,
handler, err := op.NewDynamicOpenIDProvider(ctx, "/", config, storage,
//we must explicitly allow the use of the http issuer
op.WithAllowInsecure(),
//as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth //as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth
op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), op.WithCustomAuthEndpoint(op.NewEndpoint("auth")),
) )

3
go.mod
View file

@ -3,7 +3,6 @@ module github.com/caos/oidc
go 1.15 go 1.15
require ( require (
github.com/caos/logging v0.3.1
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.2 // indirect github.com/google/go-cmp v0.5.2 // indirect
github.com/google/go-github/v31 v31.0.0 github.com/google/go-github/v31 v31.0.0
@ -16,7 +15,9 @@ require (
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.7.1 github.com/stretchr/testify v1.7.1
golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43
golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect
golang.org/x/text v0.3.7 golang.org/x/text v0.3.7
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect
gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/square/go-jose.v2 v2.6.0
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
) )

5
go.sum
View file

@ -33,8 +33,6 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/caos/logging v0.3.1 h1:892AMeHs09D0e3ZcGB+QDRsZ5+2xtPAsAhOy8eKfztc=
github.com/caos/logging v0.3.1/go.mod h1:B8QNS0WDmR2Keac52Fw+XN4ZJkzLDGrcRIPB2Ux4uRo=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
@ -138,7 +136,6 @@ github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@ -402,8 +399,6 @@ gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -37,10 +37,8 @@ type Authorizer interface {
Storage() Storage Storage() Storage
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Encoder() httphelper.Encoder Encoder() httphelper.Encoder
Signer() Signer IDTokenHintVerifier(context.Context) IDTokenHintVerifier
IDTokenHintVerifier() IDTokenHintVerifier
Crypto() Crypto Crypto() Crypto
Issuer() string
RequestObjectSupported() bool RequestObjectSupported() bool
} }
@ -71,8 +69,9 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
ctx := r.Context()
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(r.Context(), authReq, authorizer.Storage(), authorizer.Issuer()) authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
@ -82,7 +81,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
if validater, ok := authorizer.(AuthorizeValidator); ok { if validater, ok := authorizer.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest validation = validater.ValidateAuthRequest
} }
userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier()) userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
@ -91,12 +90,12 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder()) AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder())
return return
} }
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID) req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID)
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder()) AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
return return
} }
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID()) client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil { if err != nil {
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder()) AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
return return

View file

@ -2,16 +2,24 @@ package op
import ( import (
"errors" "errors"
"net/http"
"net/url" "net/url"
"os" "strings"
"golang.org/x/text/language" "golang.org/x/text/language"
) )
const OidcDevMode = "CAOS_OIDC_DEV" var (
ErrInvalidIssuerPath = errors.New("no fragments or query allowed for issuer")
ErrInvalidIssuerNoIssuer = errors.New("missing issuer")
ErrInvalidIssuerURL = errors.New("invalid url for issuer")
ErrInvalidIssuerMissingHost = errors.New("host for issuer missing")
ErrInvalidIssuerHTTPS = errors.New("scheme for issuer must be `https`")
)
type Configuration interface { type Configuration interface {
Issuer() string IssuerFromRequest(r *http.Request) string
Insecure() bool
AuthorizationEndpoint() Endpoint AuthorizationEndpoint() Endpoint
TokenEndpoint() Endpoint TokenEndpoint() Endpoint
IntrospectionEndpoint() Endpoint IntrospectionEndpoint() Endpoint
@ -37,32 +45,74 @@ type Configuration interface {
SupportedUILocales() []language.Tag SupportedUILocales() []language.Tag
} }
func ValidateIssuer(issuer string) error { type IssuerFromRequest func(r *http.Request) string
func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) {
return func(allowInsecure bool) (IssuerFromRequest, error) {
issuerPath, err := url.Parse(path)
if err != nil {
return nil, ErrInvalidIssuerURL
}
if err := ValidateIssuerPath(issuerPath); err != nil {
return nil, err
}
return func(r *http.Request) string {
return dynamicIssuer(r.Host, path, allowInsecure)
}, nil
}
}
func StaticIssuer(issuer string) func(bool) (IssuerFromRequest, error) {
return func(allowInsecure bool) (IssuerFromRequest, error) {
if err := ValidateIssuer(issuer, allowInsecure); err != nil {
return nil, err
}
return func(_ *http.Request) string {
return issuer
}, nil
}
}
func ValidateIssuer(issuer string, allowInsecure bool) error {
if issuer == "" { if issuer == "" {
return errors.New("missing issuer") return ErrInvalidIssuerNoIssuer
} }
u, err := url.Parse(issuer) u, err := url.Parse(issuer)
if err != nil { if err != nil {
return errors.New("invalid url for issuer") return ErrInvalidIssuerURL
} }
if u.Host == "" { if u.Host == "" {
return errors.New("host for issuer missing") return ErrInvalidIssuerMissingHost
} }
if u.Scheme != "https" { if u.Scheme != "https" {
if !devLocalAllowed(u) { if !devLocalAllowed(u, allowInsecure) {
return errors.New("scheme for issuer must be `https`") return ErrInvalidIssuerHTTPS
} }
} }
if u.Fragment != "" || len(u.Query()) > 0 { return ValidateIssuerPath(u)
return errors.New("no fragments or query allowed for issuer") }
func ValidateIssuerPath(issuer *url.URL) error {
if issuer.Fragment != "" || len(issuer.Query()) > 0 {
return ErrInvalidIssuerPath
} }
return nil return nil
} }
func devLocalAllowed(url *url.URL) bool { func devLocalAllowed(url *url.URL, allowInsecure bool) bool {
_, b := os.LookupEnv(OidcDevMode) if !allowInsecure {
if !b { return false
return b
} }
return url.Scheme == "http" return url.Scheme == "http"
} }
func dynamicIssuer(issuer, path string, allowInsecure bool) string {
schema := "https"
if allowInsecure {
schema = "http"
}
if len(path) > 0 && !strings.HasPrefix(path, "/") {
path = "/" + path
}
return schema + "://" + issuer + path
}

View file

@ -1,13 +1,17 @@
package op package op
import ( import (
"os" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestValidateIssuer(t *testing.T) { func TestValidateIssuer(t *testing.T) {
type args struct { type args struct {
issuer string issuer string
allowInsecure bool
} }
tests := []struct { tests := []struct {
name string name string
@ -16,65 +20,97 @@ func TestValidateIssuer(t *testing.T) {
}{ }{
{ {
"missing issuer fails", "missing issuer fails",
args{""}, args{
issuer: "",
},
true, true,
}, },
{ {
"invalid url for issuer fails", "invalid url for issuer fails",
args{":issuer"}, args{
true, issuer: ":issuer",
}, },
{
"invalid url for issuer fails",
args{":issuer"},
true, true,
}, },
{ {
"host for issuer missing fails", "host for issuer missing fails",
args{"https:///issuer"}, args{
true, issuer: "https:///issuer",
}, },
{
"host for not https fails",
args{"http://issuer.com"},
true, true,
}, },
{ {
"host with fragment fails", "host with fragment fails",
args{"https://issuer.com/#issuer"}, args{
issuer: "https://issuer.com/#issuer",
},
true, true,
}, },
{ {
"host with query fails", "host with query fails",
args{"https://issuer.com?issuer=me"}, args{
issuer: "https://issuer.com?issuer=me",
},
true,
},
{
"host with http fails",
args{
issuer: "http://issuer.com",
},
true, true,
}, },
{ {
"host with https ok", "host with https ok",
args{"https://issuer.com"}, args{
issuer: "https://issuer.com",
},
false, false,
}, },
{ {
"localhost with http fails", "custom scheme fails",
args{"http://localhost:9999"}, args{
issuer: "custom://localhost:9999",
},
true,
},
{
"http with allowInsecure ok",
args{
issuer: "http://localhost:9999",
allowInsecure: true,
},
false,
},
{
"https with allowInsecure ok",
args{
issuer: "https://localhost:9999",
allowInsecure: true,
},
false,
},
{
"custom scheme with allowInsecure fails",
args{
issuer: "custom://localhost:9999",
allowInsecure: true,
},
true, true,
}, },
} }
//ensure env is not set
//nolint:errcheck
os.Unsetenv(OidcDevMode)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { if err := ValidateIssuer(tt.args.issuer, tt.args.allowInsecure); (err != nil) != tt.wantErr {
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
} }
} }
func TestValidateIssuerDevLocalAllowed(t *testing.T) { func TestValidateIssuerPath(t *testing.T) {
type args struct { type args struct {
issuer string issuerPath *url.URL
} }
tests := []struct { tests := []struct {
name string name string
@ -82,17 +118,217 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
"localhost with http with dev ok", "empty ok",
args{"http://localhost:9999"}, args{func() *url.URL {
u, _ := url.Parse("")
return u
}()},
false, false,
}, },
{
"custom ok",
args{func() *url.URL {
u, _ := url.Parse("/custom")
return u
}()},
false,
},
{
"fragment fails",
args{func() *url.URL {
u, _ := url.Parse("#fragment")
return u
}()},
true,
},
{
"query fails",
args{func() *url.URL {
u, _ := url.Parse("?query=value")
return u
}()},
true,
},
} }
//nolint:errcheck
os.Setenv(OidcDevMode, "true")
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { if err := ValidateIssuerPath(tt.args.issuerPath); (err != nil) != tt.wantErr {
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ValidateIssuerPath() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestIssuerFromHost(t *testing.T) {
type args struct {
path string
allowInsecure bool
target string
}
type res struct {
issuer string
err error
}
tests := []struct {
name string
args args
res res
}{
{
"invalid issuer path",
args{
path: "/#fragment",
allowInsecure: false,
},
res{
issuer: "",
err: ErrInvalidIssuerPath,
},
},
{
"empty path secure",
args{
path: "",
allowInsecure: false,
target: "https://issuer.com",
},
res{
issuer: "https://issuer.com",
err: nil,
},
},
{
"custom path secure",
args{
path: "/custom/",
allowInsecure: false,
target: "https://issuer.com",
},
res{
issuer: "https://issuer.com/custom/",
err: nil,
},
},
{
"custom path no leading slash",
args{
path: "custom/",
allowInsecure: false,
target: "https://issuer.com",
},
res{
issuer: "https://issuer.com/custom/",
err: nil,
},
},
{
"empty path unsecure",
args{
path: "",
allowInsecure: true,
target: "http://issuer.com",
},
res{
issuer: "http://issuer.com",
err: nil,
},
},
{
"custom path unsecure",
args{
path: "/custom/",
allowInsecure: true,
target: "http://issuer.com",
},
res{
issuer: "http://issuer.com/custom/",
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
issuer, err := IssuerFromHost(tt.args.path)(tt.args.allowInsecure)
if tt.res.err == nil {
assert.NoError(t, err)
req := httptest.NewRequest("", tt.args.target, nil)
assert.Equal(t, tt.res.issuer, issuer(req))
}
if tt.res.err != nil {
assert.ErrorIs(t, err, tt.res.err)
}
})
}
}
func TestStaticIssuer(t *testing.T) {
type args struct {
issuer string
allowInsecure bool
}
type res struct {
issuer string
err error
}
tests := []struct {
name string
args args
res res
}{
{
"invalid issuer",
args{
issuer: "",
allowInsecure: false,
},
res{
issuer: "",
err: ErrInvalidIssuerNoIssuer,
},
},
{
"empty path secure",
args{
issuer: "https://issuer.com",
allowInsecure: false,
},
res{
issuer: "https://issuer.com",
err: nil,
},
},
{
"custom path secure",
args{
issuer: "https://issuer.com/custom/",
allowInsecure: false,
},
res{
issuer: "https://issuer.com/custom/",
err: nil,
},
},
{
"unsecure",
args{
issuer: "http://issuer.com",
allowInsecure: true,
},
res{
issuer: "http://issuer.com",
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
issuer, err := StaticIssuer(tt.args.issuer)(tt.args.allowInsecure)
if tt.res.err == nil {
assert.NoError(t, err)
assert.Equal(t, tt.res.issuer, issuer(nil))
}
if tt.res.err != nil {
assert.ErrorIs(t, err, tt.res.err)
} }
}) })
} }

49
pkg/op/context.go Normal file
View file

@ -0,0 +1,49 @@
package op
import (
"context"
"net/http"
)
type key int
var (
issuer key = 0
)
type IssuerInterceptor struct {
issuerFromRequest IssuerFromRequest
}
//NewIssuerInterceptor will set the issuer into the context
//by the provided IssuerFromRequest (e.g. returned from StaticIssuer or IssuerFromHost)
func NewIssuerInterceptor(issuerFromRequest IssuerFromRequest) *IssuerInterceptor {
return &IssuerInterceptor{
issuerFromRequest: issuerFromRequest,
}
}
func (i *IssuerInterceptor) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
i.setIssuerCtx(w, r, next)
})
}
func (i *IssuerInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
i.setIssuerCtx(w, r, next)
}
}
//IssuerFromContext reads the issuer from the context (set by an IssuerInterceptor)
//it will return an empty string if not found
func IssuerFromContext(ctx context.Context) string {
ctxIssuer, _ := ctx.Value(issuer).(string)
return ctxIssuer
}
func (i *IssuerInterceptor) setIssuerCtx(w http.ResponseWriter, r *http.Request, next http.Handler) {
ctx := context.WithValue(r.Context(), issuer, i.issuerFromRequest(r))
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
}

76
pkg/op/context_test.go Normal file
View file

@ -0,0 +1,76 @@
package op
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIssuerInterceptor(t *testing.T) {
type fields struct {
issuerFromRequest IssuerFromRequest
}
type args struct {
r *http.Request
next http.Handler
}
type res struct {
issuer string
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"empty",
fields{
func(r *http.Request) string {
return ""
},
},
args{},
res{
issuer: "",
},
},
{
"static",
fields{
func(r *http.Request) string {
return "static"
},
},
args{},
res{
issuer: "static",
},
},
{
"host",
fields{
func(r *http.Request) string {
return r.Host
},
},
args{},
res{
issuer: "issuer.com",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := NewIssuerInterceptor(tt.fields.issuerFromRequest)
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
assert.Equal(t, tt.res.issuer, IssuerFromContext(r.Context()))
})
req := httptest.NewRequest("", "https://issuer.com", nil)
i.Handler(next).ServeHTTP(nil, req)
i.HandlerFunc(next).ServeHTTP(nil, req)
})
}
}

View file

@ -1,49 +1,17 @@
package op package op
import ( import (
"context"
"net/http" "net/http"
"gopkg.in/square/go-jose.v2"
httphelper "github.com/caos/oidc/pkg/http" httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
) )
func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http.Request) { type DiscoverStorage interface {
return func(w http.ResponseWriter, r *http.Request) { SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error)
Discover(w, CreateDiscoveryConfig(c, s))
}
}
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
httphelper.MarshalJSON(w, config)
}
func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration {
return &oidc.DiscoveryConfiguration{
Issuer: c.Issuer(),
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()),
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
RevocationEndpoint: c.RevocationEndpoint().Absolute(c.Issuer()),
EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()),
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
ScopesSupported: Scopes(c),
ResponseTypesSupported: ResponseTypes(c),
GrantTypesSupported: GrantTypes(c),
SubjectTypesSupported: SubjectTypes(c),
IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(c),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(c),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(c),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(c),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(c),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(c),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(c),
ClaimsSupported: SupportedClaims(c),
CodeChallengeMethodsSupported: CodeChallengeMethods(c),
UILocalesSupported: c.SupportedUILocales(),
RequestParameterSupported: c.RequestObjectSupported(),
}
} }
var DefaultSupportedScopes = []string{ var DefaultSupportedScopes = []string{
@ -55,6 +23,46 @@ var DefaultSupportedScopes = []string{
oidc.ScopeOfflineAccess, oidc.ScopeOfflineAccess,
} }
func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Discover(w, CreateDiscoveryConfig(r, c, s))
}
}
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
httphelper.MarshalJSON(w, config)
}
func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
issuer := config.IssuerFromRequest(r)
return &oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer),
TokenEndpoint: config.TokenEndpoint().Absolute(issuer),
IntrospectionEndpoint: config.IntrospectionEndpoint().Absolute(issuer),
UserinfoEndpoint: config.UserinfoEndpoint().Absolute(issuer),
RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
JwksURI: config.KeysEndpoint().Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
SubjectTypesSupported: SubjectTypes(config),
IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
ClaimsSupported: SupportedClaims(config),
CodeChallengeMethodsSupported: CodeChallengeMethods(config),
UILocalesSupported: config.SupportedUILocales(),
RequestParameterSupported: config.RequestObjectSupported(),
}
}
func Scopes(c Configuration) []string { func Scopes(c Configuration) []string {
return DefaultSupportedScopes //TODO: config return DefaultSupportedScopes //TODO: config
} }
@ -84,6 +92,88 @@ func GrantTypes(c Configuration) []oidc.GrantType {
return grantTypes return grantTypes
} }
func SubjectTypes(c Configuration) []string {
return []string{"public"} //TODO: config
}
func SigAlgorithms(ctx context.Context, storage DiscoverStorage) []string {
algorithms, err := storage.SignatureAlgorithms(ctx)
if err != nil {
return nil
}
algs := make([]string, len(algorithms))
for i, algorithm := range algorithms {
algs[i] = string(algorithm)
}
return algs
}
func RequestObjectSigAlgorithms(c Configuration) []string {
if !c.RequestObjectSupported() {
return nil
}
return c.RequestObjectSigningAlgorithmsSupported()
}
func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func TokenSigAlgorithms(c Configuration) []string {
if !c.AuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.TokenEndpointSigningAlgorithmsSupported()
}
func IntrospectionSigAlgorithms(c Configuration) []string {
if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.IntrospectionEndpointSigningAlgorithmsSupported()
}
func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodBasic,
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func RevocationSigAlgorithms(c Configuration) []string {
if !c.RevocationAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.RevocationEndpointSigningAlgorithmsSupported()
}
func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func SupportedClaims(c Configuration) []string { func SupportedClaims(c Configuration) []string {
return []string{ //TODO: config return []string{ //TODO: config
"sub", "sub",
@ -113,59 +203,6 @@ func SupportedClaims(c Configuration) []string {
} }
} }
func SigAlgorithms(s Signer) []string {
return []string{string(s.SignatureAlgorithm())}
}
func SubjectTypes(c Configuration) []string {
return []string{"public"} //TODO: config
}
func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func TokenSigAlgorithms(c Configuration) []string {
if !c.AuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.TokenEndpointSigningAlgorithmsSupported()
}
func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodBasic,
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod { func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
codeMethods := make([]oidc.CodeChallengeMethod, 0, 1) codeMethods := make([]oidc.CodeChallengeMethod, 0, 1)
if c.CodeMethodS256Supported() { if c.CodeMethodS256Supported() {
@ -173,24 +210,3 @@ func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
} }
return codeMethods return codeMethods
} }
func IntrospectionSigAlgorithms(c Configuration) []string {
if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.IntrospectionEndpointSigningAlgorithmsSupported()
}
func RevocationSigAlgorithms(c Configuration) []string {
if !c.RevocationAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.RevocationEndpointSigningAlgorithmsSupported()
}
func RequestObjectSigAlgorithms(c Configuration) []string {
if !c.RequestObjectSupported() {
return nil
}
return c.RequestObjectSigningAlgorithmsSupported()
}

View file

@ -1,12 +1,13 @@
package op_test package op_test
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
@ -47,8 +48,9 @@ func TestDiscover(t *testing.T) {
func TestCreateDiscoveryConfig(t *testing.T) { func TestCreateDiscoveryConfig(t *testing.T) {
type args struct { type args struct {
c op.Configuration request *http.Request
s op.Signer c op.Configuration
s op.DiscoverStorage
} }
tests := []struct { tests := []struct {
name string name string
@ -59,9 +61,8 @@ func TestCreateDiscoveryConfig(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) { got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s)
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want) assert.Equal(t, tt.want, got)
}
}) })
} }
} }
@ -83,9 +84,8 @@ func Test_scopes(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := op.Scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) { got := op.Scopes(tt.args.c)
t.Errorf("scopes() = %v, want %v", got, tt.want) assert.Equal(t, tt.want, got)
}
}) })
} }
} }
@ -99,13 +99,16 @@ func Test_ResponseTypes(t *testing.T) {
args args args args
want []string want []string
}{ }{
// TODO: Add test cases. {
"code and implicit flow",
args{},
[]string{"code", "id_token", "id_token token"},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := op.ResponseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) { got := op.ResponseTypes(tt.args.c)
t.Errorf("responseTypes() = %v, want %v", got, tt.want) assert.Equal(t, tt.want, got)
}
}) })
} }
} }
@ -117,63 +120,48 @@ func Test_GrantTypes(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want []string want []oidc.GrantType
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.GrantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
}
})
}
}
func TestSupportedClaims(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.SupportedClaims(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("SupportedClaims() = %v, want %v", got, tt.want)
}
})
}
}
func Test_SigAlgorithms(t *testing.T) {
m := mock.NewMockSigner(gomock.NewController(t))
type args struct {
s op.Signer
}
tests := []struct {
name string
args args
want []string
}{ }{
{ {
"", "code and implicit flow",
args{func() op.Signer { args{
m.EXPECT().SignatureAlgorithm().Return(jose.RS256) func() op.Configuration {
return m c := mock.NewMockConfiguration(gomock.NewController(t))
}()}, c.EXPECT().GrantTypeRefreshTokenSupported().Return(false)
[]string{"RS256"}, c.EXPECT().GrantTypeTokenExchangeSupported().Return(false)
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(false)
return c
}(),
},
[]oidc.GrantType{
oidc.GrantTypeCode,
oidc.GrantTypeImplicit,
},
},
{
"code, implicit flow, refresh token, token exchange, jwt profile",
args{
func() op.Configuration {
c := mock.NewMockConfiguration(gomock.NewController(t))
c.EXPECT().GrantTypeRefreshTokenSupported().Return(true)
c.EXPECT().GrantTypeTokenExchangeSupported().Return(true)
c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(true)
return c
}(),
},
[]oidc.GrantType{
oidc.GrantTypeCode,
oidc.GrantTypeImplicit,
oidc.GrantTypeRefreshToken,
oidc.GrantTypeTokenExchange,
oidc.GrantTypeBearer,
},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := op.SigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) { got := op.GrantTypes(tt.args.c)
t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want) assert.Equal(t, tt.want, got)
}
}) })
} }
} }
@ -195,9 +183,80 @@ func Test_SubjectTypes(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := op.SubjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) { got := op.SubjectTypes(tt.args.c)
t.Errorf("subjectTypes() = %v, want %v", got, tt.want) assert.Equal(t, tt.want, got)
} })
}
}
func Test_SigAlgorithms(t *testing.T) {
m := mock.NewMockDiscoverStorage(gomock.NewController(t))
type args struct {
s op.DiscoverStorage
}
tests := []struct {
name string
args args
want []string
}{
{
"",
args{func() op.DiscoverStorage {
m.EXPECT().SignatureAlgorithms(gomock.Any()).Return([]jose.SignatureAlgorithm{jose.RS256}, nil)
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.SigAlgorithms(context.Background(), tt.args.s)
assert.Equal(t, tt.want, got)
})
}
}
func Test_RequestObjectSigAlgorithms(t *testing.T) {
m := mock.NewMockConfiguration(gomock.NewController(t))
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"not supported, empty",
args{func() op.Configuration {
m.EXPECT().RequestObjectSupported().Return(false)
return m
}()},
nil,
},
{
"supported, empty",
args{func() op.Configuration {
m.EXPECT().RequestObjectSupported().Return(true)
m.EXPECT().RequestObjectSigningAlgorithmsSupported().Return(nil)
return m
}()},
nil,
},
{
"supported, list",
args{func() op.Configuration {
m.EXPECT().RequestObjectSupported().Return(true)
m.EXPECT().RequestObjectSigningAlgorithmsSupported().Return([]string{"RS256"})
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.RequestObjectSigAlgorithms(tt.args.c)
assert.Equal(t, tt.want, got)
}) })
} }
} }
@ -244,9 +303,311 @@ func Test_AuthMethodsTokenEndpoint(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := op.AuthMethodsTokenEndpoint(tt.args.c); !reflect.DeepEqual(got, tt.want) { got := op.AuthMethodsTokenEndpoint(tt.args.c)
t.Errorf("authMethods() = %v, want %v", got, tt.want) assert.Equal(t, tt.want, got)
} })
}
}
func Test_TokenSigAlgorithms(t *testing.T) {
m := mock.NewMockConfiguration(gomock.NewController(t))
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"not supported, empty",
args{func() op.Configuration {
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
nil,
},
{
"supported, empty",
args{func() op.Configuration {
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().TokenEndpointSigningAlgorithmsSupported().Return(nil)
return m
}()},
nil,
},
{
"supported, list",
args{func() op.Configuration {
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().TokenEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.TokenSigAlgorithms(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_IntrospectionSigAlgorithms(t *testing.T) {
m := mock.NewMockConfiguration(gomock.NewController(t))
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"not supported, empty",
args{func() op.Configuration {
m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
nil,
},
{
"supported, empty",
args{func() op.Configuration {
m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().IntrospectionEndpointSigningAlgorithmsSupported().Return(nil)
return m
}()},
nil,
},
{
"supported, list",
args{func() op.Configuration {
m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().IntrospectionEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.IntrospectionSigAlgorithms(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_AuthMethodsIntrospectionEndpoint(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []oidc.AuthMethod
}{
{
"basic only",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodBasic},
},
{
"basic and private_key_jwt",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodBasic, oidc.AuthMethodPrivateKeyJWT},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.AuthMethodsIntrospectionEndpoint(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_RevocationSigAlgorithms(t *testing.T) {
m := mock.NewMockConfiguration(gomock.NewController(t))
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"not supported, empty",
args{func() op.Configuration {
m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
nil,
},
{
"supported, empty",
args{func() op.Configuration {
m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().RevocationEndpointSigningAlgorithmsSupported().Return(nil)
return m
}()},
nil,
},
{
"supported, list",
args{func() op.Configuration {
m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(true)
m.EXPECT().RevocationEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.RevocationSigAlgorithms(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_AuthMethodsRevocationEndpoint(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []oidc.AuthMethod
}{
{
"none and basic",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPostSupported().Return(false)
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic},
},
{
"none, basic and post",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPostSupported().Return(true)
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost},
},
{
"none, basic, post and private_key_jwt",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().AuthMethodPostSupported().Return(true)
m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
return m
}()},
[]oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost, oidc.AuthMethodPrivateKeyJWT},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.AuthMethodsRevocationEndpoint(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func TestSupportedClaims(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"scopes",
args{},
[]string{
"sub",
"aud",
"exp",
"iat",
"iss",
"auth_time",
"nonce",
"acr",
"amr",
"c_hash",
"at_hash",
"act",
"scopes",
"client_id",
"azp",
"preferred_username",
"name",
"family_name",
"given_name",
"locale",
"email",
"email_verified",
"phone_number",
"phone_number_verified",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.SupportedClaims(tt.args.c)
assert.Equal(t, tt.want, got)
})
}
}
func Test_CodeChallengeMethods(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []oidc.CodeChallengeMethod
}{
{
"not supported",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().CodeMethodS256Supported().Return(false)
return m
}()},
[]oidc.CodeChallengeMethod{},
},
{
"S256",
args{func() op.Configuration {
m := mock.NewMockConfiguration(gomock.NewController(t))
m.EXPECT().CodeMethodS256Supported().Return(true)
return m
}()},
[]oidc.CodeChallengeMethod{oidc.CodeChallengeMethodS256},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := op.CodeChallengeMethods(tt.args.c)
assert.Equal(t, tt.want, got)
}) })
} }
} }

View file

@ -10,7 +10,7 @@ import (
) )
type KeyProvider interface { type KeyProvider interface {
GetKeySet(context.Context) (*jose.JSONWebKeySet, error) KeySet(context.Context) ([]Key, error)
} }
func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) { func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
@ -20,10 +20,23 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
} }
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
keySet, err := k.GetKeySet(r.Context()) keySet, err := k.KeySet(r.Context())
if err != nil { if err != nil {
httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError) httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError)
return return
} }
httphelper.MarshalJSON(w, keySet) httphelper.MarshalJSON(w, jsonWebKeySet(keySet))
}
func jsonWebKeySet(keys []Key) *jose.JSONWebKeySet {
webKeys := make([]jose.JSONWebKey, len(keys))
for i, key := range keys {
webKeys[i] = jose.JSONWebKey{
KeyID: key.ID(),
Algorithm: string(key.Algorithm()),
Use: key.Use(),
Key: key.Key(),
}
}
return &jose.JSONWebKeySet{Keys: webKeys}
} }

View file

@ -35,7 +35,7 @@ func TestKeys(t *testing.T) {
args: args{ args: args{
k: func() op.KeyProvider { k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t)) m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, oidc.ErrServerError()) m.EXPECT().KeySet(gomock.Any()).Return(nil, oidc.ErrServerError())
return m return m
}(), }(),
}, },
@ -51,39 +51,39 @@ func TestKeys(t *testing.T) {
args: args{ args: args{
k: func() op.KeyProvider { k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t)) m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, nil) m.EXPECT().KeySet(gomock.Any()).Return(nil, nil)
return m return m
}(), }(),
}, },
res: res{ res: res{
statusCode: http.StatusOK, statusCode: http.StatusOK,
contentType: "application/json", contentType: "application/json",
body: `{"keys":[]}
`,
}, },
}, },
{ {
name: "list", name: "list",
args: args{ args: args{
k: func() op.KeyProvider { k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t)) ctrl := gomock.NewController(t)
m.EXPECT().GetKeySet(gomock.Any()).Return( m := mock.NewMockKeyProvider(ctrl)
&jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ k := mock.NewMockKey(ctrl)
{ k.EXPECT().Key().Return(&rsa.PublicKey{
Key: &rsa.PublicKey{ N: big.NewInt(1),
N: big.NewInt(1), E: 1,
E: 1, })
}, k.EXPECT().ID().Return("id")
KeyID: "id", k.EXPECT().Algorithm().Return(jose.RS256)
}, k.EXPECT().Use().Return("sig")
}}, m.EXPECT().KeySet(gomock.Any()).Return([]op.Key{k}, nil)
nil,
)
return m return m
}(), }(),
}, },
res: res{ res: res{
statusCode: http.StatusOK, statusCode: http.StatusOK,
contentType: "application/json", contentType: "application/json",
body: `{"keys":[{"kty":"RSA","kid":"id","n":"AQ","e":"AQ"}]} body: `{"keys":[{"use":"sig","kty":"RSA","kid":"id","alg":"RS256","n":"AQ","e":"AQ"}]}
`, `,
}, },
}, },

View file

@ -5,6 +5,7 @@
package mock package mock
import ( import (
context "context"
reflect "reflect" reflect "reflect"
http "github.com/caos/oidc/pkg/http" http "github.com/caos/oidc/pkg/http"
@ -78,31 +79,17 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
} }
// IDTokenHintVerifier mocks base method. // IDTokenHintVerifier mocks base method.
func (m *MockAuthorizer) IDTokenHintVerifier() op.IDTokenHintVerifier { func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IDTokenHintVerifier") ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
ret0, _ := ret[0].(op.IDTokenHintVerifier) ret0, _ := ret[0].(op.IDTokenHintVerifier)
return ret0 return ret0
} }
// IDTokenHintVerifier indicates an expected call of IDTokenHintVerifier. // IDTokenHintVerifier indicates an expected call of IDTokenHintVerifier.
func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier() *gomock.Call { func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0)
}
// Issuer mocks base method.
func (m *MockAuthorizer) Issuer() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Issuer")
ret0, _ := ret[0].(string)
return ret0
}
// Issuer indicates an expected call of Issuer.
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
} }
// RequestObjectSupported mocks base method. // RequestObjectSupported mocks base method.
@ -119,20 +106,6 @@ func (mr *MockAuthorizerMockRecorder) RequestObjectSupported() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockAuthorizer)(nil).RequestObjectSupported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockAuthorizer)(nil).RequestObjectSupported))
} }
// Signer mocks base method.
func (m *MockAuthorizer) Signer() op.Signer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Signer")
ret0, _ := ret[0].(op.Signer)
return ret0
}
// Signer indicates an expected call of Signer.
func (mr *MockAuthorizerMockRecorder) Signer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockAuthorizer)(nil).Signer))
}
// Storage mocks base method. // Storage mocks base method.
func (m *MockAuthorizer) Storage() op.Storage { func (m *MockAuthorizer) Storage() op.Storage {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -20,7 +20,7 @@ func NewAuthorizerExpectValid(t *testing.T, wantErr bool) op.Authorizer {
m := NewAuthorizer(t) m := NewAuthorizer(t)
ExpectDecoder(m) ExpectDecoder(m)
ExpectEncoder(m) ExpectEncoder(m)
ExpectSigner(m, t) //ExpectSigner(m, t)
ExpectStorage(m, t) ExpectStorage(m, t)
ExpectVerifier(m, t) ExpectVerifier(m, t)
// ExpectErrorHandler(m, t, wantErr) // ExpectErrorHandler(m, t, wantErr)
@ -47,17 +47,18 @@ func ExpectEncoder(a op.Authorizer) {
mockA.EXPECT().Encoder().AnyTimes().Return(schema.NewEncoder()) mockA.EXPECT().Encoder().AnyTimes().Return(schema.NewEncoder())
} }
func ExpectSigner(a op.Authorizer, t *testing.T) { //
mockA := a.(*MockAuthorizer) //func ExpectSigner(a op.Authorizer, t *testing.T) {
mockA.EXPECT().Signer().DoAndReturn( // mockA := a.(*MockAuthorizer)
func() op.Signer { // mockA.EXPECT().Signer().DoAndReturn(
return &Sig{} // func() op.Signer {
}) // return &Sig{}
} // })
//}
func ExpectVerifier(a op.Authorizer, t *testing.T) { func ExpectVerifier(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer) mockA := a.(*MockAuthorizer)
mockA.EXPECT().IDTokenHintVerifier().DoAndReturn( mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
func() op.IDTokenHintVerifier { func() op.IDTokenHintVerifier {
return op.NewIDTokenHintVerifier("", nil) return op.NewIDTokenHintVerifier("", nil)
}) })

View file

@ -5,6 +5,7 @@
package mock package mock
import ( import (
http "net/http"
reflect "reflect" reflect "reflect"
op "github.com/caos/oidc/pkg/op" op "github.com/caos/oidc/pkg/op"
@ -147,6 +148,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeTokenExchangeSupported() *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeTokenExchangeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeTokenExchangeSupported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeTokenExchangeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeTokenExchangeSupported))
} }
// Insecure mocks base method.
func (m *MockConfiguration) Insecure() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Insecure")
ret0, _ := ret[0].(bool)
return ret0
}
// Insecure indicates an expected call of Insecure.
func (mr *MockConfigurationMockRecorder) Insecure() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insecure", reflect.TypeOf((*MockConfiguration)(nil).Insecure))
}
// IntrospectionAuthMethodPrivateKeyJWTSupported mocks base method. // IntrospectionAuthMethodPrivateKeyJWTSupported mocks base method.
func (m *MockConfiguration) IntrospectionAuthMethodPrivateKeyJWTSupported() bool { func (m *MockConfiguration) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -189,18 +204,18 @@ func (mr *MockConfigurationMockRecorder) IntrospectionEndpointSigningAlgorithmsS
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpointSigningAlgorithmsSupported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpointSigningAlgorithmsSupported))
} }
// Issuer mocks base method. // IssuerFromRequest mocks base method.
func (m *MockConfiguration) Issuer() string { func (m *MockConfiguration) IssuerFromRequest(arg0 *http.Request) string {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Issuer") ret := m.ctrl.Call(m, "IssuerFromRequest", arg0)
ret0, _ := ret[0].(string) ret0, _ := ret[0].(string)
return ret0 return ret0
} }
// Issuer indicates an expected call of Issuer. // IssuerFromRequest indicates an expected call of IssuerFromRequest.
func (mr *MockConfigurationMockRecorder) Issuer() *gomock.Call { func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockConfiguration)(nil).Issuer)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IssuerFromRequest", reflect.TypeOf((*MockConfiguration)(nil).IssuerFromRequest), arg0)
} }
// KeysEndpoint mocks base method. // KeysEndpoint mocks base method.

View file

@ -4,5 +4,6 @@ package mock
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/caos/oidc/pkg/op Authorizer //go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/caos/oidc/pkg/op Authorizer
//go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client //go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration //go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer //go:generate mockgen -package mock -destination ./discovery.mock.go github.com/caos/oidc/pkg/op DiscoverStorage
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op SigningKey,Key
//go:generate mockgen -package mock -destination ./key.mock.go github.com/caos/oidc/pkg/op KeyProvider //go:generate mockgen -package mock -destination ./key.mock.go github.com/caos/oidc/pkg/op KeyProvider

View file

@ -8,8 +8,8 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
) )
// MockKeyProvider is a mock of KeyProvider interface. // MockKeyProvider is a mock of KeyProvider interface.
@ -35,17 +35,17 @@ func (m *MockKeyProvider) EXPECT() *MockKeyProviderMockRecorder {
return m.recorder return m.recorder
} }
// GetKeySet mocks base method. // KeySet mocks base method.
func (m *MockKeyProvider) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) { func (m *MockKeyProvider) KeySet(arg0 context.Context) ([]op.Key, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet", arg0) ret := m.ctrl.Call(m, "KeySet", arg0)
ret0, _ := ret[0].(*jose.JSONWebKeySet) ret0, _ := ret[0].([]op.Key)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetKeySet indicates an expected call of GetKeySet. // KeySet indicates an expected call of KeySet.
func (mr *MockKeyProviderMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { func (mr *MockKeyProviderMockRecorder) KeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockKeyProvider)(nil).GetKeySet), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeySet", reflect.TypeOf((*MockKeyProvider)(nil).KeySet), arg0)
} }

View file

@ -1,56 +1,55 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/oidc/pkg/op (interfaces: Signer) // Source: github.com/caos/oidc/pkg/op (interfaces: SigningKey,Key)
// Package mock is a generated GoMock package. // Package mock is a generated GoMock package.
package mock package mock
import ( import (
context "context"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
) )
// MockSigner is a mock of Signer interface. // MockSigningKey is a mock of SigningKey interface.
type MockSigner struct { type MockSigningKey struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockSignerMockRecorder recorder *MockSigningKeyMockRecorder
} }
// MockSignerMockRecorder is the mock recorder for MockSigner. // MockSigningKeyMockRecorder is the mock recorder for MockSigningKey.
type MockSignerMockRecorder struct { type MockSigningKeyMockRecorder struct {
mock *MockSigner mock *MockSigningKey
} }
// NewMockSigner creates a new mock instance. // NewMockSigningKey creates a new mock instance.
func NewMockSigner(ctrl *gomock.Controller) *MockSigner { func NewMockSigningKey(ctrl *gomock.Controller) *MockSigningKey {
mock := &MockSigner{ctrl: ctrl} mock := &MockSigningKey{ctrl: ctrl}
mock.recorder = &MockSignerMockRecorder{mock} mock.recorder = &MockSigningKeyMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSigner) EXPECT() *MockSignerMockRecorder { func (m *MockSigningKey) EXPECT() *MockSigningKeyMockRecorder {
return m.recorder return m.recorder
} }
// Health mocks base method. // Key mocks base method.
func (m *MockSigner) Health(arg0 context.Context) error { func (m *MockSigningKey) Key() interface{} {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Health", arg0) ret := m.ctrl.Call(m, "Key")
ret0, _ := ret[0].(error) ret0, _ := ret[0].(interface{})
return ret0 return ret0
} }
// Health indicates an expected call of Health. // Key indicates an expected call of Key.
func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call { func (mr *MockSigningKeyMockRecorder) Key() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockSigningKey)(nil).Key))
} }
// SignatureAlgorithm mocks base method. // SignatureAlgorithm mocks base method.
func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm { func (m *MockSigningKey) SignatureAlgorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignatureAlgorithm") ret := m.ctrl.Call(m, "SignatureAlgorithm")
ret0, _ := ret[0].(jose.SignatureAlgorithm) ret0, _ := ret[0].(jose.SignatureAlgorithm)
@ -58,21 +57,86 @@ func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
} }
// SignatureAlgorithm indicates an expected call of SignatureAlgorithm. // SignatureAlgorithm indicates an expected call of SignatureAlgorithm.
func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call { func (mr *MockSigningKeyMockRecorder) SignatureAlgorithm() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigningKey)(nil).SignatureAlgorithm))
} }
// Signer mocks base method. // MockKey is a mock of Key interface.
func (m *MockSigner) Signer() jose.Signer { type MockKey struct {
ctrl *gomock.Controller
recorder *MockKeyMockRecorder
}
// MockKeyMockRecorder is the mock recorder for MockKey.
type MockKeyMockRecorder struct {
mock *MockKey
}
// NewMockKey creates a new mock instance.
func NewMockKey(ctrl *gomock.Controller) *MockKey {
mock := &MockKey{ctrl: ctrl}
mock.recorder = &MockKeyMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockKey) EXPECT() *MockKeyMockRecorder {
return m.recorder
}
// Algorithm mocks base method.
func (m *MockKey) Algorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Signer") ret := m.ctrl.Call(m, "Algorithm")
ret0, _ := ret[0].(jose.Signer) ret0, _ := ret[0].(jose.SignatureAlgorithm)
return ret0 return ret0
} }
// Signer indicates an expected call of Signer. // Algorithm indicates an expected call of Algorithm.
func (mr *MockSignerMockRecorder) Signer() *gomock.Call { func (mr *MockKeyMockRecorder) Algorithm() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockSigner)(nil).Signer)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockKey)(nil).Algorithm))
}
// ID mocks base method.
func (m *MockKey) ID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ID")
ret0, _ := ret[0].(string)
return ret0
}
// ID indicates an expected call of ID.
func (mr *MockKeyMockRecorder) ID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockKey)(nil).ID))
}
// Key mocks base method.
func (m *MockKey) Key() interface{} {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Key")
ret0, _ := ret[0].(interface{})
return ret0
}
// Key indicates an expected call of Key.
func (mr *MockKeyMockRecorder) Key() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockKey)(nil).Key))
}
// Use mocks base method.
func (m *MockKey) Use() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Use")
ret0, _ := ret[0].(string)
return ret0
}
// Use indicates an expected call of Use.
func (mr *MockKeyMockRecorder) Use() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Use", reflect.TypeOf((*MockKey)(nil).Use))
} }

View file

@ -174,21 +174,6 @@ func (mr *MockStorageMockRecorder) GetKeyByIDAndUserID(arg0, arg1, arg2 interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeyByIDAndUserID", reflect.TypeOf((*MockStorage)(nil).GetKeyByIDAndUserID), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeyByIDAndUserID", reflect.TypeOf((*MockStorage)(nil).GetKeyByIDAndUserID), arg0, arg1, arg2)
} }
// GetKeySet mocks base method.
func (m *MockStorage) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*jose.JSONWebKeySet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKeySet indicates an expected call of GetKeySet.
func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
}
// GetPrivateClaimsFromScopes mocks base method. // GetPrivateClaimsFromScopes mocks base method.
func (m *MockStorage) GetPrivateClaimsFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (map[string]interface{}, error) { func (m *MockStorage) GetPrivateClaimsFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (map[string]interface{}, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -204,18 +189,6 @@ func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3)
} }
// GetSigningKey mocks base method.
func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "GetSigningKey", arg0, arg1)
}
// GetSigningKey indicates an expected call of GetSigningKey.
func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0, arg1)
}
// Health mocks base method. // Health mocks base method.
func (m *MockStorage) Health(arg0 context.Context) error { func (m *MockStorage) Health(arg0 context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -230,6 +203,21 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
} }
// KeySet mocks base method.
func (m *MockStorage) KeySet(arg0 context.Context) ([]op.Key, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeySet", arg0)
ret0, _ := ret[0].([]op.Key)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// KeySet indicates an expected call of KeySet.
func (mr *MockStorageMockRecorder) KeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeySet", reflect.TypeOf((*MockStorage)(nil).KeySet), arg0)
}
// RevokeToken mocks base method. // RevokeToken mocks base method.
func (m *MockStorage) RevokeToken(arg0 context.Context, arg1, arg2, arg3 string) *oidc.Error { func (m *MockStorage) RevokeToken(arg0 context.Context, arg1, arg2, arg3 string) *oidc.Error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -300,6 +288,36 @@ func (mr *MockStorageMockRecorder) SetUserinfoFromToken(arg0, arg1, arg2, arg3,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).SetUserinfoFromToken), arg0, arg1, arg2, arg3, arg4) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).SetUserinfoFromToken), arg0, arg1, arg2, arg3, arg4)
} }
// SignatureAlgorithms mocks base method.
func (m *MockStorage) SignatureAlgorithms(arg0 context.Context) ([]jose.SignatureAlgorithm, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignatureAlgorithms", arg0)
ret0, _ := ret[0].([]jose.SignatureAlgorithm)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignatureAlgorithms indicates an expected call of SignatureAlgorithms.
func (mr *MockStorageMockRecorder) SignatureAlgorithms(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithms", reflect.TypeOf((*MockStorage)(nil).SignatureAlgorithms), arg0)
}
// SigningKey mocks base method.
func (m *MockStorage) SigningKey(arg0 context.Context) (op.SigningKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SigningKey", arg0)
ret0, _ := ret[0].(op.SigningKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SigningKey indicates an expected call of SigningKey.
func (mr *MockStorageMockRecorder) SigningKey(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SigningKey", reflect.TypeOf((*MockStorage)(nil).SigningKey), arg0)
}
// TerminateSession mocks base method. // TerminateSession mocks base method.
func (m *MockStorage) TerminateSession(arg0 context.Context, arg1, arg2 string) error { func (m *MockStorage) TerminateSession(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -3,11 +3,10 @@ package mock
import ( import (
"context" "context"
"errors" "errors"
"github.com/caos/oidc/pkg/oidc"
"testing" "testing"
"time" "time"
"gopkg.in/square/go-jose.v2" "github.com/caos/oidc/pkg/oidc"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -40,12 +39,12 @@ func NewMockStorageAny(t *testing.T) op.Storage {
func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage { func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage {
m := NewStorage(t) m := NewStorage(t)
ExpectSigningKeyInvalid(m) //ExpectSigningKeyInvalid(m)
return m return m
} }
func NewMockStorageSigningKey(t *testing.T) op.Storage { func NewMockStorageSigningKey(t *testing.T) op.Storage {
m := NewStorage(t) m := NewStorage(t)
ExpectSigningKey(m) //ExpectSigningKey(m)
return m return m
} }
@ -83,23 +82,24 @@ func ExpectValidClientID(s op.Storage) {
}) })
} }
func ExpectSigningKeyInvalid(s op.Storage) { //
mockS := s.(*MockStorage) //func ExpectSigningKeyInvalid(s op.Storage) {
mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any()).DoAndReturn( // mockS := s.(*MockStorage)
func(_ context.Context, keyCh chan<- jose.SigningKey) { // mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any()).DoAndReturn(
keyCh <- jose.SigningKey{} // func(_ context.Context, keyCh chan<- jose.SigningKey) {
}, // keyCh <- jose.SigningKey{}
) // },
} // )
//}
func ExpectSigningKey(s op.Storage) { //
mockS := s.(*MockStorage) //func ExpectSigningKey(s op.Storage) {
mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any()).DoAndReturn( // mockS := s.(*MockStorage)
func(_ context.Context, keyCh chan<- jose.SigningKey) { // mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any()).DoAndReturn(
keyCh <- jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")} // func(_ context.Context, keyCh chan<- jose.SigningKey) {
}, // keyCh <- jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}
) // },
} // )
//}
type ConfClient struct { type ConfClient struct {
id string id string

View file

@ -46,11 +46,10 @@ type OpenIDProvider interface {
Storage() Storage Storage() Storage
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Encoder() httphelper.Encoder Encoder() httphelper.Encoder
IDTokenHintVerifier() IDTokenHintVerifier IDTokenHintVerifier(context.Context) IDTokenHintVerifier
AccessTokenVerifier() AccessTokenVerifier AccessTokenVerifier(context.Context) AccessTokenVerifier
Crypto() Crypto Crypto() Crypto
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
Signer() Signer
Probes() []ProbesFn Probes() []ProbesFn
HttpHandler() http.Handler HttpHandler() http.Handler
} }
@ -62,31 +61,26 @@ var allowAllOrigins = func(_ string) bool {
} }
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router { func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router {
intercept := buildInterceptor(interceptors...)
router := mux.NewRouter() router := mux.NewRouter()
router.Use(handlers.CORS( router.Use(intercept(o.IssuerFromRequest, interceptors...))
handlers.AllowCredentials(),
handlers.AllowedHeaders([]string{"authorization", "content-type"}),
handlers.AllowedOriginValidator(allowAllOrigins),
))
router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(healthEndpoint, healthHandler)
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Signer())) router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
router.Handle(o.AuthorizationEndpoint().Relative(), intercept(authorizeHandler(o))) router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").Handler(intercept(authorizeCallbackHandler(o))) router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o))
router.Handle(o.TokenEndpoint().Relative(), intercept(tokenHandler(o))) router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o))
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o)) router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o))) router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
return router return router
} }
//AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login //AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login
func AuthCallbackURL(o OpenIDProvider) func(string) string { func AuthCallbackURL(o OpenIDProvider) func(context.Context, string) string {
return func(requestID string) string { return func(ctx context.Context, requestID string) string {
return o.AuthorizationEndpoint().Absolute(o.Issuer()) + authCallbackPathSuffix + "?id=" + requestID return o.AuthorizationEndpoint().Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID
} }
} }
@ -95,7 +89,6 @@ func authCallbackPath(o OpenIDProvider) string {
} }
type Config struct { type Config struct {
Issuer string
CryptoKey [32]byte CryptoKey [32]byte
DefaultLogoutRedirectURI string DefaultLogoutRedirectURI string
CodeMethodS256 bool CodeMethodS256 bool
@ -117,13 +110,16 @@ type endpoints struct {
JwksURI Endpoint JwksURI Endpoint
} }
func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opOpts ...Option) (OpenIDProvider, error) { func NewOpenIDProvider(ctx context.Context, issuer string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
err := ValidateIssuer(config.Issuer) return newProvider(ctx, config, storage, StaticIssuer(issuer), opOpts...)
if err != nil { }
return nil, err
}
o := &openidProvider{ func NewDynamicOpenIDProvider(ctx context.Context, path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
return newProvider(ctx, config, storage, IssuerFromHost(path), opOpts...)
}
func newProvider(ctx context.Context, config *Config, storage Storage, issuer func(bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) {
o := &Provider{
config: config, config: config,
storage: storage, storage: storage,
endpoints: DefaultEndpoints, endpoints: DefaultEndpoints,
@ -136,9 +132,10 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
} }
} }
keyCh := make(chan jose.SigningKey) o.issuer, err = issuer(o.insecure)
go storage.GetSigningKey(ctx, keyCh) if err != nil {
o.signer = NewSigner(ctx, storage, keyCh) return nil, err
}
o.httpHandler = CreateRouter(o, o.interceptors...) o.httpHandler = CreateRouter(o, o.interceptors...)
@ -152,171 +149,159 @@ func NewOpenIDProvider(ctx context.Context, config *Config, storage Storage, opO
return o, nil return o, nil
} }
type openidProvider struct { type Provider struct {
config *Config config *Config
endpoints *endpoints issuer IssuerFromRequest
storage Storage insecure bool
signer Signer endpoints *endpoints
idTokenHintVerifier IDTokenHintVerifier storage Storage
jwtProfileVerifier JWTProfileVerifier keySet *openIDKeySet
accessTokenVerifier AccessTokenVerifier crypto Crypto
keySet *openIDKeySet httpHandler http.Handler
crypto Crypto decoder *schema.Decoder
httpHandler http.Handler encoder *schema.Encoder
decoder *schema.Decoder interceptors []HttpInterceptor
encoder *schema.Encoder timer <-chan time.Time
interceptors []HttpInterceptor
timer <-chan time.Time
} }
func (o *openidProvider) Issuer() string { func (o *Provider) IssuerFromRequest(r *http.Request) string {
return o.config.Issuer return o.issuer(r)
} }
func (o *openidProvider) AuthorizationEndpoint() Endpoint { func (o *Provider) Insecure() bool {
return o.insecure
}
func (o *Provider) AuthorizationEndpoint() Endpoint {
return o.endpoints.Authorization return o.endpoints.Authorization
} }
func (o *openidProvider) TokenEndpoint() Endpoint { func (o *Provider) TokenEndpoint() Endpoint {
return o.endpoints.Token return o.endpoints.Token
} }
func (o *openidProvider) IntrospectionEndpoint() Endpoint { func (o *Provider) IntrospectionEndpoint() Endpoint {
return o.endpoints.Introspection return o.endpoints.Introspection
} }
func (o *openidProvider) UserinfoEndpoint() Endpoint { func (o *Provider) UserinfoEndpoint() Endpoint {
return o.endpoints.Userinfo return o.endpoints.Userinfo
} }
func (o *openidProvider) RevocationEndpoint() Endpoint { func (o *Provider) RevocationEndpoint() Endpoint {
return o.endpoints.Revocation return o.endpoints.Revocation
} }
func (o *openidProvider) EndSessionEndpoint() Endpoint { func (o *Provider) EndSessionEndpoint() Endpoint {
return o.endpoints.EndSession return o.endpoints.EndSession
} }
func (o *openidProvider) KeysEndpoint() Endpoint { func (o *Provider) KeysEndpoint() Endpoint {
return o.endpoints.JwksURI return o.endpoints.JwksURI
} }
func (o *openidProvider) AuthMethodPostSupported() bool { func (o *Provider) AuthMethodPostSupported() bool {
return o.config.AuthMethodPost return o.config.AuthMethodPost
} }
func (o *openidProvider) CodeMethodS256Supported() bool { func (o *Provider) CodeMethodS256Supported() bool {
return o.config.CodeMethodS256 return o.config.CodeMethodS256
} }
func (o *openidProvider) AuthMethodPrivateKeyJWTSupported() bool { func (o *Provider) AuthMethodPrivateKeyJWTSupported() bool {
return o.config.AuthMethodPrivateKeyJWT return o.config.AuthMethodPrivateKeyJWT
} }
func (o *openidProvider) TokenEndpointSigningAlgorithmsSupported() []string { func (o *Provider) TokenEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"} return []string{"RS256"}
} }
func (o *openidProvider) GrantTypeRefreshTokenSupported() bool { func (o *Provider) GrantTypeRefreshTokenSupported() bool {
return o.config.GrantTypeRefreshToken return o.config.GrantTypeRefreshToken
} }
func (o *openidProvider) GrantTypeTokenExchangeSupported() bool { func (o *Provider) GrantTypeTokenExchangeSupported() bool {
return false return false
} }
func (o *openidProvider) GrantTypeJWTAuthorizationSupported() bool { func (o *Provider) GrantTypeJWTAuthorizationSupported() bool {
return true return true
} }
func (o *openidProvider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool { func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
return true return true
} }
func (o *openidProvider) IntrospectionEndpointSigningAlgorithmsSupported() []string { func (o *Provider) IntrospectionEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"} return []string{"RS256"}
} }
func (o *openidProvider) RevocationAuthMethodPrivateKeyJWTSupported() bool { func (o *Provider) RevocationAuthMethodPrivateKeyJWTSupported() bool {
return true return true
} }
func (o *openidProvider) RevocationEndpointSigningAlgorithmsSupported() []string { func (o *Provider) RevocationEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"} return []string{"RS256"}
} }
func (o *openidProvider) RequestObjectSupported() bool { func (o *Provider) RequestObjectSupported() bool {
return o.config.RequestObjectSupported return o.config.RequestObjectSupported
} }
func (o *openidProvider) RequestObjectSigningAlgorithmsSupported() []string { func (o *Provider) RequestObjectSigningAlgorithmsSupported() []string {
return []string{"RS256"} return []string{"RS256"}
} }
func (o *openidProvider) SupportedUILocales() []language.Tag { func (o *Provider) SupportedUILocales() []language.Tag {
return o.config.SupportedUILocales return o.config.SupportedUILocales
} }
func (o *openidProvider) Storage() Storage { func (o *Provider) Storage() Storage {
return o.storage return o.storage
} }
func (o *openidProvider) Decoder() httphelper.Decoder { func (o *Provider) Decoder() httphelper.Decoder {
return o.decoder return o.decoder
} }
func (o *openidProvider) Encoder() httphelper.Encoder { func (o *Provider) Encoder() httphelper.Encoder {
return o.encoder return o.encoder
} }
func (o *openidProvider) IDTokenHintVerifier() IDTokenHintVerifier { func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier {
if o.idTokenHintVerifier == nil { return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet())
o.idTokenHintVerifier = NewIDTokenHintVerifier(o.Issuer(), o.openIDKeySet())
}
return o.idTokenHintVerifier
} }
func (o *openidProvider) JWTProfileVerifier() JWTProfileVerifier { func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier {
if o.jwtProfileVerifier == nil { return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
o.jwtProfileVerifier = NewJWTProfileVerifier(o.Storage(), o.Issuer(), 1*time.Hour, time.Second)
}
return o.jwtProfileVerifier
} }
func (o *openidProvider) AccessTokenVerifier() AccessTokenVerifier { func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier {
if o.accessTokenVerifier == nil { return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet())
o.accessTokenVerifier = NewAccessTokenVerifier(o.Issuer(), o.openIDKeySet())
}
return o.accessTokenVerifier
} }
func (o *openidProvider) openIDKeySet() oidc.KeySet { func (o *Provider) openIDKeySet() oidc.KeySet {
if o.keySet == nil { if o.keySet == nil {
o.keySet = &openIDKeySet{o.Storage()} o.keySet = &openIDKeySet{o.Storage()}
} }
return o.keySet return o.keySet
} }
func (o *openidProvider) Crypto() Crypto { func (o *Provider) Crypto() Crypto {
return o.crypto return o.crypto
} }
func (o *openidProvider) DefaultLogoutRedirectURI() string { func (o *Provider) DefaultLogoutRedirectURI() string {
return o.config.DefaultLogoutRedirectURI return o.config.DefaultLogoutRedirectURI
} }
func (o *openidProvider) Signer() Signer { func (o *Provider) Probes() []ProbesFn {
return o.signer
}
func (o *openidProvider) Probes() []ProbesFn {
return []ProbesFn{ return []ProbesFn{
ReadySigner(o.Signer()),
ReadyStorage(o.Storage()), ReadyStorage(o.Storage()),
} }
} }
func (o *openidProvider) HttpHandler() http.Handler { func (o *Provider) HttpHandler() http.Handler {
return o.httpHandler return o.httpHandler
} }
@ -327,22 +312,31 @@ type openIDKeySet struct {
//VerifySignature implements the oidc.KeySet interface //VerifySignature implements the oidc.KeySet interface
//providing an implementation for the keys stored in the OP Storage interface //providing an implementation for the keys stored in the OP Storage interface
func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
keySet, err := o.Storage.GetKeySet(ctx) keySet, err := o.Storage.KeySet(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("error fetching keys: %w", err) return nil, fmt.Errorf("error fetching keys: %w", err)
} }
keyID, alg := oidc.GetKeyIDAndAlg(jws) keyID, alg := oidc.GetKeyIDAndAlg(jws)
key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keySet.Keys...) key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, jsonWebKeySet(keySet).Keys...)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid signature: %w", err) return nil, fmt.Errorf("invalid signature: %w", err)
} }
return jws.Verify(&key) return jws.Verify(&key)
} }
type Option func(o *openidProvider) error type Option func(o *Provider) error
//WithAllowInsecure allows the use of http (instead of https) for issuers
//this is not recommended for production use and violates the OIDC specification
func WithAllowInsecure() Option {
return func(o *Provider) error {
o.insecure = true
return nil
}
}
func WithCustomAuthEndpoint(endpoint Endpoint) Option { func WithCustomAuthEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
} }
@ -352,7 +346,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option {
} }
func WithCustomTokenEndpoint(endpoint Endpoint) Option { func WithCustomTokenEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
} }
@ -362,7 +356,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option {
} }
func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
} }
@ -372,7 +366,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option {
} }
func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
} }
@ -382,7 +376,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
} }
func WithCustomRevocationEndpoint(endpoint Endpoint) Option { func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
} }
@ -392,7 +386,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
} }
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
} }
@ -402,7 +396,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
} }
func WithCustomKeysEndpoint(endpoint Endpoint) Option { func WithCustomKeysEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error { return func(o *Provider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
return err return err
} }
@ -412,7 +406,7 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
} }
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
return func(o *openidProvider) error { return func(o *Provider) error {
o.endpoints.Authorization = auth o.endpoints.Authorization = auth
o.endpoints.Token = token o.endpoints.Token = token
o.endpoints.Userinfo = userInfo o.endpoints.Userinfo = userInfo
@ -424,24 +418,23 @@ func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys End
} }
func WithHttpInterceptors(interceptors ...HttpInterceptor) Option { func WithHttpInterceptors(interceptors ...HttpInterceptor) Option {
return func(o *openidProvider) error { return func(o *Provider) error {
o.interceptors = append(o.interceptors, interceptors...) o.interceptors = append(o.interceptors, interceptors...)
return nil return nil
} }
} }
func buildInterceptor(interceptors ...HttpInterceptor) func(http.HandlerFunc) http.Handler { func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
return func(handlerFunc http.HandlerFunc) http.Handler { cors := handlers.CORS(
handler := handlerFuncToHandler(handlerFunc) handlers.AllowCredentials(),
handlers.AllowedHeaders([]string{"authorization", "content-type"}),
handlers.AllowedOriginValidator(allowAllOrigins),
)
issuerInterceptor := NewIssuerInterceptor(i)
return func(handler http.Handler) http.Handler {
for i := len(interceptors) - 1; i >= 0; i-- { for i := len(interceptors) - 1; i >= 0; i-- {
handler = interceptors[i](handler) handler = interceptors[i](handler)
} }
return handler return cors(issuerInterceptor.Handler(handler))
} }
} }
func handlerFuncToHandler(handlerFunc http.HandlerFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerFunc(w, r)
})
}

View file

@ -31,14 +31,6 @@ func Readiness(w http.ResponseWriter, r *http.Request, probes ...ProbesFn) {
ok(w) ok(w)
} }
func ReadySigner(s Signer) ProbesFn {
return func(ctx context.Context) error {
if s == nil {
return errors.New("no signer")
}
return s.Health(ctx)
}
}
func ReadyStorage(s Storage) ProbesFn { func ReadyStorage(s Storage) ProbesFn {
return func(ctx context.Context) error { return func(ctx context.Context) error {
if s == nil { if s == nil {

View file

@ -11,7 +11,7 @@ import (
type SessionEnder interface { type SessionEnder interface {
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Storage() Storage Storage() Storage
IDTokenHintVerifier() IDTokenHintVerifier IDTokenHintVerifier(context.Context) IDTokenHintVerifier
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
} }
@ -62,7 +62,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
if req.IdTokenHint == "" { if req.IdTokenHint == "" {
return session, nil return session, nil
} }
claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier()) claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx))
if err != nil { if err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
} }

View file

@ -1,82 +1,38 @@
package op package op
import ( import (
"context"
"errors" "errors"
"github.com/caos/logging"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
) )
type Signer interface { var (
Health(ctx context.Context) error ErrSignerCreationFailed = errors.New("signer creation failed")
Signer() jose.Signer )
type SigningKey interface {
SignatureAlgorithm() jose.SignatureAlgorithm SignatureAlgorithm() jose.SignatureAlgorithm
Key() interface{}
ID() string
} }
type tokenSigner struct { func SignerFromKey(key SigningKey) (jose.Signer, error) {
signer jose.Signer signer, err := jose.NewSigner(jose.SigningKey{
storage AuthStorage Algorithm: key.SignatureAlgorithm(),
alg jose.SignatureAlgorithm Key: &jose.JSONWebKey{
} Key: key.Key(),
KeyID: key.ID(),
func NewSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer { },
s := &tokenSigner{ }, &jose.SignerOptions{})
storage: storage,
}
select {
case <-ctx.Done():
return nil
case key := <-keyCh:
s.exchangeSigningKey(key)
}
go s.refreshSigningKey(ctx, keyCh)
return s
}
func (s *tokenSigner) Health(_ context.Context) error {
if s.signer == nil {
return errors.New("no signer")
}
if string(s.alg) == "" {
return errors.New("no signing algorithm")
}
return nil
}
func (s *tokenSigner) Signer() jose.Signer {
return s.signer
}
func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
for {
select {
case <-ctx.Done():
return
case key := <-keyCh:
s.exchangeSigningKey(key)
}
}
}
func (s *tokenSigner) exchangeSigningKey(key jose.SigningKey) {
s.alg = key.Algorithm
if key.Algorithm == "" || key.Key == nil {
s.signer = nil
logging.Warn("signer has no key")
return
}
var err error
s.signer, err = jose.NewSigner(key, &jose.SignerOptions{})
if err != nil { if err != nil {
logging.New().WithError(err).Error("error creating signer") return nil, ErrSignerCreationFailed //TODO: log / wrap error?
return
} }
logging.Info("signer exchanged signing key") return signer, nil
} }
func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm { type Key interface {
return s.alg ID() string
Algorithm() jose.SignatureAlgorithm
Use() string
Key() interface{}
} }

View file

@ -23,8 +23,9 @@ type AuthStorage interface {
TerminateSession(ctx context.Context, userID string, clientID string) error TerminateSession(ctx context.Context, userID string, clientID string) error
RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error
GetSigningKey(context.Context, chan<- jose.SigningKey) SigningKey(context.Context) (SigningKey, error)
GetKeySet(context.Context) (*jose.JSONWebKeySet, error) SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error)
KeySet(context.Context) ([]Key, error)
} }
type OPStorage interface { type OPStorage interface {

View file

@ -10,8 +10,6 @@ import (
) )
type TokenCreator interface { type TokenCreator interface {
Issuer() string
Signer() Signer
Storage() Storage Storage() Storage
Crypto() Crypto Crypto() Crypto
} }
@ -32,7 +30,7 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli
return nil, err return nil, err
} }
} }
idToken, err := CreateIDToken(ctx, creator.Issuer(), request, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer(), client) idToken, err := CreateIDToken(ctx, IssuerFromContext(ctx), request, client.IDTokenLifetime(), accessToken, code, creator.Storage(), client)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -84,7 +82,7 @@ func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTok
} }
validity = exp.Add(clockSkew).Sub(time.Now().UTC()) validity = exp.Add(clockSkew).Sub(time.Now().UTC())
if accessTokenType == AccessTokenTypeJWT { if accessTokenType == AccessTokenTypeJWT {
accessToken, err = CreateJWT(ctx, creator.Issuer(), tokenRequest, exp, id, creator.Signer(), client, creator.Storage()) accessToken, err = CreateJWT(ctx, IssuerFromContext(ctx), tokenRequest, exp, id, client, creator.Storage())
return return
} }
accessToken, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto()) accessToken, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
@ -95,7 +93,7 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) {
return crypto.Encrypt(tokenID + ":" + subject) return crypto.Encrypt(tokenID + ":" + subject)
} }
func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, signer Signer, client Client, storage Storage) (string, error) { func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, client Client, storage Storage) (string, error) {
claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id, client.GetID(), client.ClockSkew()) claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id, client.GetID(), client.ClockSkew())
if client != nil { if client != nil {
restrictedScopes := client.RestrictAdditionalAccessTokenScopes()(tokenRequest.GetScopes()) restrictedScopes := client.RestrictAdditionalAccessTokenScopes()(tokenRequest.GetScopes())
@ -105,7 +103,15 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex
} }
claims.SetPrivateClaims(privateClaims) claims.SetPrivateClaims(privateClaims)
} }
return crypto.Sign(claims, signer.Signer()) signingKey, err := storage.SigningKey(ctx)
if err != nil {
return "", err
}
signer, err := SignerFromKey(signingKey)
if err != nil {
return "", err
}
return crypto.Sign(claims, signer)
} }
type IDTokenRequest interface { type IDTokenRequest interface {
@ -117,7 +123,7 @@ type IDTokenRequest interface {
GetSubject() string GetSubject() string
} }
func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer, client Client) (string, error) { func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, client Client) (string, error) {
exp := time.Now().UTC().Add(client.ClockSkew()).Add(validity) exp := time.Now().UTC().Add(client.ClockSkew()).Add(validity)
var acr, nonce string var acr, nonce string
if authRequest, ok := request.(AuthRequest); ok { if authRequest, ok := request.(AuthRequest); ok {
@ -126,8 +132,12 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
} }
claims := oidc.NewIDTokenClaims(issuer, request.GetSubject(), request.GetAudience(), exp, request.GetAuthTime(), nonce, acr, request.GetAMR(), request.GetClientID(), client.ClockSkew()) claims := oidc.NewIDTokenClaims(issuer, request.GetSubject(), request.GetAudience(), exp, request.GetAuthTime(), nonce, acr, request.GetAMR(), request.GetClientID(), client.ClockSkew())
scopes := client.RestrictAdditionalIdTokenScopes()(request.GetScopes()) scopes := client.RestrictAdditionalIdTokenScopes()(request.GetScopes())
signingKey, err := storage.SigningKey(ctx)
if err != nil {
return "", err
}
if accessToken != "" { if accessToken != "" {
atHash, err := oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) atHash, err := oidc.ClaimHash(accessToken, signingKey.SignatureAlgorithm())
if err != nil { if err != nil {
return "", err return "", err
} }
@ -145,14 +155,17 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
claims.SetUserinfo(userInfo) claims.SetUserinfo(userInfo)
} }
if code != "" { if code != "" {
codeHash, err := oidc.ClaimHash(code, signer.SignatureAlgorithm()) codeHash, err := oidc.ClaimHash(code, signingKey.SignatureAlgorithm())
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetCodeHash(codeHash) claims.SetCodeHash(codeHash)
} }
signer, err := SignerFromKey(signingKey)
return crypto.Sign(claims, signer.Signer()) if err != nil {
return "", err
}
return crypto.Sign(claims, signer)
} }
func removeUserinfoScopes(scopes []string) []string { func removeUserinfoScopes(scopes []string) []string {

View file

@ -1,6 +1,7 @@
package op package op
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"net/url" "net/url"
@ -13,12 +14,12 @@ type Introspector interface {
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
AccessTokenVerifier() AccessTokenVerifier AccessTokenVerifier(context.Context) AccessTokenVerifier
} }
type IntrospectorJWTProfile interface { type IntrospectorJWTProfile interface {
Introspector Introspector
JWTProfileVerifier() JWTProfileVerifier JWTProfileVerifier(context.Context) JWTProfileVerifier
} }
func introspectionHandler(introspector Introspector) func(http.ResponseWriter, *http.Request) { func introspectionHandler(introspector Introspector) func(http.ResponseWriter, *http.Request) {
@ -62,7 +63,7 @@ func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector)
return "", "", errors.New("unable to parse request") return "", "", errors.New("unable to parse request")
} }
if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" { if introspectorJWTProfile, ok := introspector.(IntrospectorJWTProfile); ok && req.ClientAssertion != "" {
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier()) profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, introspectorJWTProfile.JWTProfileVerifier(r.Context()))
if err == nil { if err == nil {
return req.Token, profile.Issuer, nil return req.Token, profile.Issuer, nil
} }

View file

@ -11,7 +11,7 @@ import (
type JWTAuthorizationGrantExchanger interface { type JWTAuthorizationGrantExchanger interface {
Exchanger Exchanger
JWTProfileVerifier() JWTProfileVerifier JWTProfileVerifier(context.Context) JWTProfileVerifier
} }
//JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1 //JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1
@ -21,7 +21,7 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizati
RequestError(w, r, err) RequestError(w, r, err)
} }
tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier()) tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context()))
if err != nil { if err != nil {
RequestError(w, r, err) RequestError(w, r, err)
return return

View file

@ -10,10 +10,8 @@ import (
) )
type Exchanger interface { type Exchanger interface {
Issuer() string
Storage() Storage Storage() Storage
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Signer() Signer
Crypto() Crypto Crypto() Crypto
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
AuthMethodPrivateKeyJWTSupported() bool AuthMethodPrivateKeyJWTSupported() bool
@ -111,7 +109,7 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.C
//AuthorizePrivateJWTKey authorizes a client by validating the client_assertion's signature with a previously //AuthorizePrivateJWTKey authorizes a client by validating the client_assertion's signature with a previously
//registered public key (JWT Profile) //registered public key (JWT Profile)
func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchanger JWTAuthorizationGrantExchanger) (Client, error) { func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchanger JWTAuthorizationGrantExchanger) (Client, error) {
jwtReq, err := VerifyJWTAssertion(ctx, clientAssertion, exchanger.JWTProfileVerifier()) jwtReq, err := VerifyJWTAssertion(ctx, clientAssertion, exchanger.JWTProfileVerifier(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -14,14 +14,14 @@ type Revoker interface {
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
AccessTokenVerifier() AccessTokenVerifier AccessTokenVerifier(context.Context) AccessTokenVerifier
AuthMethodPrivateKeyJWTSupported() bool AuthMethodPrivateKeyJWTSupported() bool
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
} }
type RevokerJWTProfile interface { type RevokerJWTProfile interface {
Revoker Revoker
JWTProfileVerifier() JWTProfileVerifier JWTProfileVerifier(context.Context) JWTProfileVerifier
} }
func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) { func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {
@ -67,7 +67,7 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token
if !ok || !revoker.AuthMethodPrivateKeyJWTSupported() { if !ok || !revoker.AuthMethodPrivateKeyJWTSupported() {
return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported") return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
} }
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier()) profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier(r.Context()))
if err == nil { if err == nil {
return req.Token, req.TokenTypeHint, profile.Issuer, nil return req.Token, req.TokenTypeHint, profile.Issuer, nil
} }
@ -128,7 +128,7 @@ func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider Use
} }
return splitToken[0], splitToken[1], true return splitToken[0], splitToken[1], true
} }
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier()) accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
if err != nil { if err != nil {
return "", "", false return "", "", false
} }

View file

@ -14,7 +14,7 @@ type UserinfoProvider interface {
Decoder() httphelper.Decoder Decoder() httphelper.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
AccessTokenVerifier() AccessTokenVerifier AccessTokenVerifier(context.Context) AccessTokenVerifier
} }
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
@ -81,7 +81,7 @@ func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider
} }
return splitToken[0], splitToken[1], true return splitToken[0], splitToken[1], true
} }
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier()) accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
if err != nil { if err != nil {
return "", "", false return "", "", false
} }