initial commit

This commit is contained in:
Livio Amstutz 2020-01-31 15:22:16 +01:00
commit 6d0890e280
68 changed files with 5986 additions and 0 deletions

198
pkg/op/authrequest.go Normal file
View file

@ -0,0 +1,198 @@
package op
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"github.com/gorilla/mux"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
type Authorizer interface {
Storage() Storage
Decoder() *schema.Decoder
Encoder() *schema.Encoder
Signer() Signer
Crypto() Crypto
Issuer() string
}
type ValidationAuthorizer interface {
Authorizer
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error
}
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err := r.ParseForm()
if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder())
return
}
authReq := new(oidc.AuthRequest)
err = authorizer.Decoder().Decode(authReq, r.Form)
if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
return
}
validation := ValidateAuthRequest
if validater, ok := authorizer.(ValidationAuthorizer); ok {
validation = validater.ValidateAuthRequest
}
if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, err, authorizer.Encoder())
return
}
RedirectToLogin(req.GetID(), client, w, r)
}
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error {
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
return err
}
if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
return err
}
if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil {
return err
}
// if NeedsExistingSession(authReq) {
// session, err := storage.CheckSession(authReq.IDTokenHint)
// if err != nil {
// return err
// }
// }
return nil
}
func ValidateAuthReqScopes(scopes []string) error {
if len(scopes) == 0 {
return ErrInvalidRequest("scope missing")
}
if !utils.Contains(scopes, oidc.ScopeOpenID) {
return ErrInvalidRequest("scope openid missing")
}
return nil
}
func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
if uri == "" {
return ErrInvalidRequestRedirectURI("redirect_uri must not be empty")
}
client, err := storage.GetClientByClientID(ctx, client_id)
if err != nil {
return ErrServerError(err.Error())
}
if !utils.Contains(client.RedirectURIs(), uri) {
return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
}
if strings.HasPrefix(uri, "https://") {
return nil
}
if responseType == oidc.ResponseTypeCode {
if strings.HasPrefix(uri, "http://") && IsConfidentialType(client) {
return nil
}
if client.ApplicationType() == ApplicationTypeNative {
return nil
}
return ErrInvalidRequest("redirect_uri not allowed")
} else {
if client.ApplicationType() != ApplicationTypeNative {
return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
}
if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) {
return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
}
}
return nil
}
func ValidateAuthReqResponseType(responseType oidc.ResponseType) error {
if responseType == "" {
return ErrInvalidRequest("response_type empty")
}
return nil
}
func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) {
login := client.LoginURL(authReqID)
http.Redirect(w, r, login, http.StatusFound)
}
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
params := mux.Vars(r)
id := params["id"]
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
if err != nil {
AuthRequestError(w, r, nil, err, authorizer.Encoder())
return
}
if !authReq.Done() {
AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder())
return
}
AuthResponse(authReq, authorizer, w, r)
}
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
if err != nil {
}
if authReq.GetResponseType() == oidc.ResponseTypeCode {
AuthResponseCode(w, r, authReq, authorizer)
return
}
AuthResponseToken(w, r, authReq, authorizer, client)
return
}
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
if authReq.GetState() != "" {
callback = callback + "&state=" + authReq.GetState()
}
http.Redirect(w, r, callback, http.StatusFound)
}
func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) {
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := CreateTokenResponse(authReq, client, authorizer, createAccessToken, "")
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
http.Redirect(w, r, callback, http.StatusFound)
}
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
return crypto.Encrypt(authReq.GetID())
}

296
pkg/op/authrequest_test.go Normal file
View file

@ -0,0 +1,296 @@
package op_test
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
"github.com/caos/oidc/pkg/op/mock"
)
func TestAuthorize(t *testing.T) {
// testCallback := func(t *testing.T, clienID string) callbackHandler {
// return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) {
// // require.Equal(t, clientID, client.)
// }
// }
// testErr := func(t *testing.T, expected error) errorHandler {
// return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
// require.Equal(t, expected, err)
// }
// }
type args struct {
w http.ResponseWriter
r *http.Request
authorizer op.Authorizer
}
tests := []struct {
name string
args args
}{
{
"parsing fails",
args{
httptest.NewRecorder(),
&http.Request{Method: "POST", Body: nil},
mock.NewAuthorizerExpectValid(t, true),
// testCallback(t, ""),
// testErr(t, ErrInvalidRequest("cannot parse form")),
},
},
{
"decoding fails",
args{
httptest.NewRecorder(),
func() *http.Request {
r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo"))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return r
}(),
mock.NewAuthorizerExpectValid(t, true),
// testCallback(t, ""),
// testErr(t, ErrInvalidRequest("cannot parse auth request")),
},
},
// {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer)
})
}
}
func TestValidateAuthRequest(t *testing.T) {
type args struct {
authRequest *oidc.AuthRequest
storage op.Storage
}
tests := []struct {
name string
args args
wantErr bool
}{
//TODO:
// {
// "oauth2 spec"
// }
{
"scope missing fails",
args{&oidc.AuthRequest{}, nil},
true,
},
{
"scope openid missing fails",
args{&oidc.AuthRequest{Scopes: []string{"profile"}}, nil},
true,
},
{
"response_type missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}}, nil},
true,
},
{
"client_id missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, nil},
true,
},
{
"redirect_uri missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, nil},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage); (err != nil) != tt.wantErr {
t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestValidateAuthReqScopes(t *testing.T) {
type args struct {
scopes []string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"scopes missing fails", args{}, true,
},
{
"scope openid missing fails", args{[]string{"email"}}, true,
},
{
"scope ok", args{[]string{"openid"}}, false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr {
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestValidateAuthReqRedirectURI(t *testing.T) {
type args struct {
uri string
clientID string
responseType oidc.ResponseType
storage op.OPStorage
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"empty fails",
args{"", "", oidc.ResponseTypeCode, nil},
true,
},
{
"unregistered fails",
args{"https://unregistered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
true,
},
{
"storage error fails",
args{"https://registered.com/callback", "non_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectInvalidClientID(t)},
true,
},
{
"code flow registered http not confidential fails",
args{"http://registered.com/callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
true,
},
{
"code flow registered http confidential ok",
args{"http://registered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
false,
},
{
"code flow registered custom not native fails",
args{"custom://callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
true,
},
{
"code flow registered custom native ok",
args{"http://registered.com/callback", "native_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
false,
},
{
"implicit flow registered ok",
args{"https://registered.com/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
false,
},
{
"implicit flow registered http localhost native ok",
args{"http://localhost:9999/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
false,
},
{
"implicit flow registered http localhost user agent fails",
args{"http://localhost:9999/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
true,
},
{
"implicit flow http non localhost fails",
args{"http://registered.com/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
true,
},
{
"implicit flow custom fails",
args{"custom://callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthReqRedirectURI(nil, tt.args.uri, tt.args.clientID, tt.args.responseType, tt.args.storage); (err != nil) != tt.wantErr {
t.Errorf("ValidateRedirectURI() error = %v, wantErr %v", err.Error(), tt.wantErr)
}
})
}
}
func TestRedirectToLogin(t *testing.T) {
type args struct {
authReqID string
client op.Client
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
args args
}{
{
"redirect ok",
args{
"id",
mock.NewClientExpectAny(t, op.ApplicationTypeNative),
httptest.NewRecorder(),
httptest.NewRequest("GET", "/authorize", nil),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.RedirectToLogin(tt.args.authReqID, tt.args.client, tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusFound, rec.Code)
require.Equal(t, "/login?id=id", rec.Header().Get("location"))
})
}
}
func TestAuthorizeCallback(t *testing.T) {
type args struct {
w http.ResponseWriter
r *http.Request
authorizer op.Authorizer
}
tests := []struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer)
})
}
}
func TestAuthResponse(t *testing.T) {
type args struct {
authReq op.AuthRequest
authorizer op.Authorizer
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r)
})
}
}

33
pkg/op/client.go Normal file
View file

@ -0,0 +1,33 @@
package op
import "time"
const (
ApplicationTypeWeb ApplicationType = iota
ApplicationTypeUserAgent
ApplicationTypeNative
AccessTokenTypeBearer AccessTokenType = iota
AccessTokenTypeJWT
)
type Client interface {
GetID() string
RedirectURIs() []string
ApplicationType() ApplicationType
GetAuthMethod() AuthMethod
LoginURL(string) string
AccessTokenType() AccessTokenType
AccessTokenLifetime() time.Duration
IDTokenLifetime() time.Duration
}
func IsConfidentialType(c Client) bool {
return c.ApplicationType() == ApplicationTypeWeb
}
type ApplicationType int
type AuthMethod string
type AccessTokenType int

54
pkg/op/config.go Normal file
View file

@ -0,0 +1,54 @@
package op
import (
"errors"
"net/url"
"os"
"strings"
)
type Configuration interface {
Issuer() string
AuthorizationEndpoint() Endpoint
TokenEndpoint() Endpoint
UserinfoEndpoint() Endpoint
KeysEndpoint() Endpoint
AuthMethodPostSupported() bool
Port() string
}
func ValidateIssuer(issuer string) error {
if issuer == "" {
return errors.New("missing issuer")
}
u, err := url.Parse(issuer)
if err != nil {
return errors.New("invalid url for issuer")
}
if u.Host == "" {
return errors.New("host for issuer missing")
}
if u.Scheme != "https" {
if !devLocalAllowed(u) {
return errors.New("scheme for issuer must be `https`")
}
}
if u.Fragment != "" || len(u.Query()) > 0 {
return errors.New("no fragments or query allowed for issuer")
}
return nil
}
func devLocalAllowed(url *url.URL) bool {
_, b := os.LookupEnv("CAOS_OIDC_DEV")
if !b {
return b
}
return url.Scheme == "http" &&
url.Host == "localhost" ||
url.Host == "127.0.0.1" ||
url.Host == "::1" ||
strings.HasPrefix(url.Host, "localhost:")
}

94
pkg/op/config_test.go Normal file
View file

@ -0,0 +1,94 @@
package op
import "testing"
import "os"
func TestValidateIssuer(t *testing.T) {
type args struct {
issuer string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"missing issuer fails",
args{""},
true,
},
{
"invalid url for issuer fails",
args{":issuer"},
true,
},
{
"invalid url for issuer fails",
args{":issuer"},
true,
},
{
"host for issuer missing fails",
args{"https:///issuer"},
true,
},
{
"host for not https fails",
args{"http://issuer.com"},
true,
},
{
"host with fragment fails",
args{"https://issuer.com/#issuer"},
true,
},
{
"host with query fails",
args{"https://issuer.com?issuer=me"},
true,
},
{
"host with https ok",
args{"https://issuer.com"},
false,
},
{
"localhost with http ok",
args{"http://localhost:9999"},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestValidateIssuerDevLocalAllowed(t *testing.T) {
type args struct {
issuer string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"localhost with http ok",
args{"http://localhost:9999"},
false,
},
}
os.Setenv("CAOS_OIDC_DEV", "")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

26
pkg/op/crypto.go Normal file
View file

@ -0,0 +1,26 @@
package op
import (
"github.com/caos/oidc/pkg/utils"
)
type Crypto interface {
Encrypt(string) (string, error)
Decrypt(string) (string, error)
}
type aesCrypto struct {
key string
}
func NewAESCrypto(key [32]byte) Crypto {
return &aesCrypto{key: string(key[:32])}
}
func (c *aesCrypto) Encrypt(s string) (string, error) {
return utils.EncryptAES(s, c.key)
}
func (c *aesCrypto) Decrypt(s string) (string, error) {
return utils.DecryptAES(s, c.key)
}

224
pkg/op/default_op.go Normal file
View file

@ -0,0 +1,224 @@
package op
import (
"context"
"net/http"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc"
)
const (
defaultAuthorizationEndpoint = "authorize"
defaulTokenEndpoint = "oauth/token"
defaultIntrospectEndpoint = "introspect"
defaultUserinfoEndpoint = "userinfo"
defaultKeysEndpoint = "keys"
AuthMethodBasic AuthMethod = "client_secret_basic"
AuthMethodPost = "client_secret_post"
AuthMethodNone = "none"
)
var (
DefaultEndpoints = &endpoints{
Authorization: defaultAuthorizationEndpoint,
Token: defaulTokenEndpoint,
IntrospectionEndpoint: defaultIntrospectEndpoint,
Userinfo: defaultUserinfoEndpoint,
JwksURI: defaultKeysEndpoint,
}
)
type DefaultOP struct {
config *Config
endpoints *endpoints
discoveryConfig *oidc.DiscoveryConfiguration
storage Storage
signer Signer
crypto Crypto
http *http.Server
decoder *schema.Decoder
encoder *schema.Encoder
}
type Config struct {
Issuer string
CryptoKey [32]byte
// ScopesSupported: oidc.SupportedScopes,
// ResponseTypesSupported: responseTypes,
// GrantTypesSupported: oidc.SupportedGrantTypes,
// ClaimsSupported: oidc.SupportedClaims,
// IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm},
// SubjectTypesSupported: []string{"public"},
// TokenEndpointAuthMethodsSupported:
Port string
}
type endpoints struct {
Authorization Endpoint
Token Endpoint
IntrospectionEndpoint Endpoint
Userinfo Endpoint
EndSessionEndpoint Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
}
type DefaultOPOpts func(o *DefaultOP) error
func WithCustomAuthEndpoint(endpoint Endpoint) DefaultOPOpts {
return func(o *DefaultOP) error {
if err := endpoint.Validate(); err != nil {
return err
}
o.endpoints.Authorization = endpoint
return nil
}
}
func WithCustomTokenEndpoint(endpoint Endpoint) DefaultOPOpts {
return func(o *DefaultOP) error {
if err := endpoint.Validate(); err != nil {
return err
}
o.endpoints.Token = endpoint
return nil
}
}
func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
return func(o *DefaultOP) error {
if err := endpoint.Validate(); err != nil {
return err
}
o.endpoints.Userinfo = endpoint
return nil
}
}
func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
err := ValidateIssuer(config.Issuer)
if err != nil {
return nil, err
}
p := &DefaultOP{
config: config,
storage: storage,
endpoints: DefaultEndpoints,
}
p.signer, err = NewDefaultSigner(ctx, storage)
if err != nil {
return nil, err
}
for _, optFunc := range opOpts {
if err := optFunc(p); err != nil {
return nil, err
}
}
p.discoveryConfig = CreateDiscoveryConfig(p, p.signer)
router := CreateRouter(p)
p.http = &http.Server{
Addr: ":" + config.Port,
Handler: router,
}
p.decoder = schema.NewDecoder()
p.decoder.IgnoreUnknownKeys(true)
p.encoder = schema.NewEncoder()
p.crypto = NewAESCrypto(config.CryptoKey)
return p, nil
}
func (p *DefaultOP) Issuer() string {
return p.config.Issuer
}
func (p *DefaultOP) AuthorizationEndpoint() Endpoint {
return p.endpoints.Authorization
}
func (p *DefaultOP) TokenEndpoint() Endpoint {
return Endpoint(p.endpoints.Token)
}
func (p *DefaultOP) UserinfoEndpoint() Endpoint {
return Endpoint(p.endpoints.Userinfo)
}
func (p *DefaultOP) KeysEndpoint() Endpoint {
return Endpoint(p.endpoints.JwksURI)
}
func (p *DefaultOP) AuthMethodPostSupported() bool {
return true //TODO: config
}
func (p *DefaultOP) Port() string {
return p.config.Port
}
func (p *DefaultOP) HttpHandler() *http.Server {
return p.http
}
func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) {
Discover(w, p.discoveryConfig)
}
func (p *DefaultOP) Decoder() *schema.Decoder {
return p.decoder
}
func (p *DefaultOP) Encoder() *schema.Encoder {
return p.encoder
}
func (p *DefaultOP) Storage() Storage {
return p.storage
}
func (p *DefaultOP) Signer() Signer {
return p.signer
}
func (p *DefaultOP) Crypto() Crypto {
return p.crypto
}
func (p *DefaultOP) HandleKeys(w http.ResponseWriter, r *http.Request) {
Keys(w, r, p)
}
func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
Authorize(w, r, p)
}
func (p *DefaultOP) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) {
AuthorizeCallback(w, r, p)
}
func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) {
reqType := r.FormValue("grant_type")
if reqType == "" {
ExchangeRequestError(w, r, ErrInvalidRequest("grant_type missing"))
return
}
if reqType == string(oidc.GrantTypeCode) {
CodeExchange(w, r, p)
return
}
TokenExchange(w, r, p)
}
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
Userinfo(w, r, p)
}

49
pkg/op/default_op_test.go Normal file
View file

@ -0,0 +1,49 @@
package op
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc"
)
func TestDefaultOP_HandleDiscovery(t *testing.T) {
type fields struct {
config *Config
endpoints *endpoints
discoveryConfig *oidc.DiscoveryConfiguration
storage Storage
http *http.Server
}
type args struct {
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
fields fields
args args
want string
wantCode int
}{
{"OK", fields{config: nil, endpoints: nil, discoveryConfig: &oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"}}, args{httptest.NewRecorder(), nil}, `{"issuer":"https://issuer.com"}`, 200},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &DefaultOP{
config: tt.fields.config,
endpoints: tt.fields.endpoints,
discoveryConfig: tt.fields.discoveryConfig,
storage: tt.fields.storage,
http: tt.fields.http,
}
p.HandleDiscovery(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, tt.want, rec.Body.String())
require.Equal(t, tt.wantCode, rec.Code)
})
}
}

119
pkg/op/discovery.go Normal file
View file

@ -0,0 +1,119 @@
package op
import (
"net/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
utils.MarshalJSON(w, config)
}
func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration {
return &oidc.DiscoveryConfiguration{
Issuer: c.Issuer(),
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
// IntrospectionEndpoint: c.Intro().Absolute(c.Issuer()),
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
// EndSessionEndpoint: c.TokenEndpoint().Absolute(c.Issuer())(c.EndSessionEndpoint),
// CheckSessionIframe: c.TokenEndpoint().Absolute(c.Issuer())(c.CheckSessionIframe),
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
ScopesSupported: Scopes(c),
ResponseTypesSupported: ResponseTypes(c),
GrantTypesSupported: GrantTypes(c),
ClaimsSupported: SupportedClaims(c),
IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
SubjectTypesSupported: SubjectTypes(c),
TokenEndpointAuthMethodsSupported: AuthMethods(c),
}
}
const (
ScopeOpenID = "openid"
ScopeProfile = "profile"
ScopeEmail = "email"
ScopePhone = "phone"
ScopeAddress = "address"
)
var DefaultSupportedScopes = []string{
ScopeOpenID,
ScopeProfile,
ScopeEmail,
ScopePhone,
ScopeAddress,
}
func Scopes(c Configuration) []string {
return DefaultSupportedScopes //TODO: config
}
func ResponseTypes(c Configuration) []string {
return []string{
"code",
"id_token",
// "code token",
// "code id_token",
"id_token token",
// "code id_token token"
}
}
func GrantTypes(c Configuration) []string {
return []string{
"client_credentials",
"authorization_code",
// "password",
"urn:ietf:params:oauth:grant-type:token-exchange",
}
}
func SupportedClaims(c Configuration) []string {
return []string{ //TODO: config
"sub",
"aud",
"exp",
"iat",
"iss",
"auth_time",
"nonce",
"acr",
"amr",
"c_hash",
"at_hash",
"act",
"scopes",
"client_id",
"azp",
"preferred_username",
"name",
"family_name",
"given_name",
"locale",
"email",
"email_verified",
"phone_number",
"phone_number_verified",
}
}
func SigAlgorithms(s Signer) []string {
return []string{string(s.SignatureAlgorithm())}
}
func SubjectTypes(c Configuration) []string {
return []string{"public"} //TODO: config
}
func AuthMethods(c Configuration) []string {
authMethods := []string{
string(AuthMethodBasic),
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, string(AuthMethodPost))
}
return authMethods
}

235
pkg/op/discovery_test.go Normal file
View file

@ -0,0 +1,235 @@
package op_test
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
"github.com/caos/oidc/pkg/op/mock"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
)
func TestDiscover(t *testing.T) {
type args struct {
w http.ResponseWriter
config *oidc.DiscoveryConfiguration
}
tests := []struct {
name string
args args
}{
{
"OK",
args{
httptest.NewRecorder(),
&oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.Discover(tt.args.w, tt.args.config)
rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
})
}
}
func TestCreateDiscoveryConfig(t *testing.T) {
type args struct {
c op.Configuration
s op.Signer
}
tests := []struct {
name string
args args
want *oidc.DiscoveryConfiguration
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
}
})
}
}
func Test_scopes(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"default Scopes",
args{},
op.DefaultSupportedScopes,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.Scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("scopes() = %v, want %v", got, tt.want)
}
})
}
}
func Test_ResponseTypes(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.ResponseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
}
})
}
}
func Test_GrantTypes(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.GrantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
}
})
}
}
func TestSupportedClaims(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.SupportedClaims(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("SupportedClaims() = %v, want %v", got, tt.want)
}
})
}
}
func Test_SigAlgorithms(t *testing.T) {
m := mock.NewMockSigner(gomock.NewController((t)))
type args struct {
s op.Signer
}
tests := []struct {
name string
args args
want []string
}{
{
"",
args{func() op.Signer {
m.EXPECT().SignatureAlgorithm().Return(jose.RS256)
return m
}()},
[]string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.SigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
}
})
}
}
func Test_SubjectTypes(t *testing.T) {
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"none",
args{},
[]string{"public"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.SubjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
}
})
}
}
func Test_AuthMethods(t *testing.T) {
m := mock.NewMockConfiguration(gomock.NewController((t)))
type args struct {
c op.Configuration
}
tests := []struct {
name string
args args
want []string
}{
{
"imlicit basic",
args{func() op.Configuration {
m.EXPECT().AuthMethodPostSupported().Return(false)
return m
}()},
[]string{string(op.AuthMethodBasic)},
},
{
"basic and post",
args{func() op.Configuration {
m.EXPECT().AuthMethodPostSupported().Return(true)
return m
}()},
[]string{string(op.AuthMethodBasic), string(op.AuthMethodPost)},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.AuthMethods(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("authMethods() = %v, want %v", got, tt.want)
}
})
}
}

25
pkg/op/endpoint.go Normal file
View file

@ -0,0 +1,25 @@
package op
import "strings"
type Endpoint string
func (e Endpoint) Relative() string {
return relativeEndpoint(string(e))
}
func (e Endpoint) Absolute(host string) string {
return absoluteEndpoint(host, string(e))
}
func (e Endpoint) Validate() error {
return nil //TODO:
}
func absoluteEndpoint(host, endpoint string) string {
return strings.TrimSuffix(host, "/") + relativeEndpoint(endpoint)
}
func relativeEndpoint(endpoint string) string {
return "/" + strings.TrimPrefix(endpoint, "/")
}

95
pkg/op/endpoint_test.go Normal file
View file

@ -0,0 +1,95 @@
package op_test
import (
"testing"
"github.com/caos/oidc/pkg/op"
)
func TestEndpoint_Relative(t *testing.T) {
tests := []struct {
name string
e op.Endpoint
want string
}{
{
"without starting /",
op.Endpoint("test"),
"/test",
},
{
"with starting /",
op.Endpoint("/test"),
"/test",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.e.Relative(); got != tt.want {
t.Errorf("Endpoint.Relative() = %v, want %v", got, tt.want)
}
})
}
}
func TestEndpoint_Absolute(t *testing.T) {
type args struct {
host string
}
tests := []struct {
name string
e op.Endpoint
args args
want string
}{
{
"no /",
op.Endpoint("test"),
args{"https://host"},
"https://host/test",
},
{
"endpoint without /",
op.Endpoint("test"),
args{"https://host/"},
"https://host/test",
},
{
"host without /",
op.Endpoint("/test"),
args{"https://host"},
"https://host/test",
},
{
"both /",
op.Endpoint("/test"),
args{"https://host/"},
"https://host/test",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.e.Absolute(tt.args.host); got != tt.want {
t.Errorf("Endpoint.Absolute() = %v, want %v", got, tt.want)
}
})
}
}
//TODO: impl test
func TestEndpoint_Validate(t *testing.T) {
tests := []struct {
name string
e op.Endpoint
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

99
pkg/op/error.go Normal file
View file

@ -0,0 +1,99 @@
package op
import (
"fmt"
"net/http"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
const (
InvalidRequest errorType = "invalid_request"
ServerError errorType = "server_error"
)
var (
ErrInvalidRequest = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrInvalidRequestRedirectURI = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
redirectDisabled: true,
}
}
ErrServerError = func(description string) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
type errorType string
type ErrAuthRequest interface {
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetState() string
}
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
if authReq == nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
e, ok := err.(*OAuthError)
if !ok {
e = new(OAuthError)
e.ErrorType = ServerError
e.Description = err.Error()
}
e.state = authReq.GetState()
if authReq.GetRedirectURI() == "" || e.redirectDisabled {
http.Error(w, e.Description, http.StatusBadRequest)
return
}
params, err := utils.URLEncodeResponse(e, encoder)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
url := authReq.GetRedirectURI()
responseType := authReq.GetResponseType()
if responseType == "" || responseType == oidc.ResponseTypeCode {
url += "?" + params
} else {
url += "#" + params
}
http.Redirect(w, r, url, http.StatusFound)
}
func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
e, ok := err.(*OAuthError)
if !ok {
e = new(OAuthError)
e.ErrorType = ServerError
e.Description = err.Error()
}
w.WriteHeader(http.StatusBadRequest)
utils.MarshalJSON(w, e)
}
type OAuthError struct {
ErrorType errorType `json:"error" schema:"error"`
Description string `json:"description" schema:"description"`
state string `json:"state" schema:"state"`
redirectDisabled bool
}
func (e *OAuthError) Error() string {
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
}

19
pkg/op/keys.go Normal file
View file

@ -0,0 +1,19 @@
package op
import (
"net/http"
"github.com/caos/oidc/pkg/utils"
)
type KeyProvider interface {
Storage() Storage
}
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
keySet, err := k.Storage().GetKeySet(r.Context())
if err != nil {
}
utils.MarshalJSON(w, keySet)
}

View file

@ -0,0 +1,119 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/oidc/pkg/op (interfaces: Authorizer)
// Package mock is a generated GoMock package.
package mock
import (
op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock"
schema "github.com/gorilla/schema"
reflect "reflect"
)
// MockAuthorizer is a mock of Authorizer interface
type MockAuthorizer struct {
ctrl *gomock.Controller
recorder *MockAuthorizerMockRecorder
}
// MockAuthorizerMockRecorder is the mock recorder for MockAuthorizer
type MockAuthorizerMockRecorder struct {
mock *MockAuthorizer
}
// NewMockAuthorizer creates a new mock instance
func NewMockAuthorizer(ctrl *gomock.Controller) *MockAuthorizer {
mock := &MockAuthorizer{ctrl: ctrl}
mock.recorder = &MockAuthorizerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockAuthorizer) EXPECT() *MockAuthorizerMockRecorder {
return m.recorder
}
// Crypto mocks base method
func (m *MockAuthorizer) Crypto() op.Crypto {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Crypto")
ret0, _ := ret[0].(op.Crypto)
return ret0
}
// Crypto indicates an expected call of Crypto
func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Crypto", reflect.TypeOf((*MockAuthorizer)(nil).Crypto))
}
// Decoder mocks base method
func (m *MockAuthorizer) Decoder() *schema.Decoder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Decoder")
ret0, _ := ret[0].(*schema.Decoder)
return ret0
}
// Decoder indicates an expected call of Decoder
func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decoder", reflect.TypeOf((*MockAuthorizer)(nil).Decoder))
}
// Encoder mocks base method
func (m *MockAuthorizer) Encoder() *schema.Encoder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Encoder")
ret0, _ := ret[0].(*schema.Encoder)
return ret0
}
// Encoder indicates an expected call of Encoder
func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
}
// Issuer mocks base method
func (m *MockAuthorizer) Issuer() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Issuer")
ret0, _ := ret[0].(string)
return ret0
}
// Issuer indicates an expected call of Issuer
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
}
// Signer mocks base method
func (m *MockAuthorizer) Signer() op.Signer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Signer")
ret0, _ := ret[0].(op.Signer)
return ret0
}
// Signer indicates an expected call of Signer
func (mr *MockAuthorizerMockRecorder) Signer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockAuthorizer)(nil).Signer))
}
// Storage mocks base method
func (m *MockAuthorizer) Storage() op.Storage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Storage")
ret0, _ := ret[0].(op.Storage)
return ret0
}
// Storage indicates an expected call of Storage
func (mr *MockAuthorizerMockRecorder) Storage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Storage", reflect.TypeOf((*MockAuthorizer)(nil).Storage))
}

View file

@ -0,0 +1,89 @@
package mock
import (
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/schema"
"gopkg.in/square/go-jose.v2"
oidc "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
)
func NewAuthorizer(t *testing.T) op.Authorizer {
return NewMockAuthorizer(gomock.NewController(t))
}
func NewAuthorizerExpectValid(t *testing.T, wantErr bool) op.Authorizer {
m := NewAuthorizer(t)
ExpectDecoder(m)
ExpectEncoder(m)
ExpectSigner(m, t)
ExpectStorage(m, t)
// ExpectErrorHandler(m, t, wantErr)
return m
}
// func NewAuthorizerExpectDecoderFails(t *testing.T) op.Authorizer {
// m := NewAuthorizer(t)
// ExpectDecoderFails(m)
// ExpectEncoder(m)
// ExpectSigner(m, t)
// ExpectStorage(m, t)
// ExpectErrorHandler(m, t)
// return m
// }
func ExpectDecoder(a op.Authorizer) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Decoder().AnyTimes().Return(schema.NewDecoder())
}
func ExpectEncoder(a op.Authorizer) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Encoder().AnyTimes().Return(schema.NewEncoder())
}
func ExpectSigner(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Signer().DoAndReturn(
func() op.Signer {
return &Sig{}
})
}
// func ExpectErrorHandler(a op.Authorizer, t *testing.T, wantErr bool) {
// mockA := a.(*MockAuthorizer)
// mockA.EXPECT().ErrorHandler().AnyTimes().
// Return(func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
// if wantErr {
// require.Error(t, err)
// return
// }
// require.NoError(t, err)
// })
// }
type Sig struct{}
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
return jose.HS256
}
func ExpectStorage(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t))
}
// func NewMockSignerAny(t *testing.T) op.Signer {
// m := NewMockSigner(gomock.NewController(t))
// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil)
// return m
// }

29
pkg/op/mock/client.go Normal file
View file

@ -0,0 +1,29 @@
package mock
import (
"testing"
gomock "github.com/golang/mock/gomock"
op "github.com/caos/oidc/pkg/op"
)
func NewClient(t *testing.T) op.Client {
return NewMockClient(gomock.NewController(t))
}
func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
c := NewClient(t)
m := c.(*MockClient)
m.EXPECT().RedirectURIs().AnyTimes().Return([]string{
"https://registered.com/callback",
"http://registered.com/callback",
"http://localhost:9999/callback",
"custom://callback"})
m.EXPECT().ApplicationType().AnyTimes().Return(appType)
m.EXPECT().LoginURL(gomock.Any()).AnyTimes().DoAndReturn(
func(id string) string {
return "login?id=" + id
})
return c
}

147
pkg/op/mock/client.mock.go Normal file
View file

@ -0,0 +1,147 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/oidc/pkg/op (interfaces: Client)
// Package mock is a generated GoMock package.
package mock
import (
op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
)
// MockClient is a mock of Client interface
type MockClient struct {
ctrl *gomock.Controller
recorder *MockClientMockRecorder
}
// MockClientMockRecorder is the mock recorder for MockClient
type MockClientMockRecorder struct {
mock *MockClient
}
// NewMockClient creates a new mock instance
func NewMockClient(ctrl *gomock.Controller) *MockClient {
mock := &MockClient{ctrl: ctrl}
mock.recorder = &MockClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockClient) EXPECT() *MockClientMockRecorder {
return m.recorder
}
// AccessTokenLifetime mocks base method
func (m *MockClient) AccessTokenLifetime() time.Duration {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AccessTokenLifetime")
ret0, _ := ret[0].(time.Duration)
return ret0
}
// AccessTokenLifetime indicates an expected call of AccessTokenLifetime
func (mr *MockClientMockRecorder) AccessTokenLifetime() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenLifetime", reflect.TypeOf((*MockClient)(nil).AccessTokenLifetime))
}
// AccessTokenType mocks base method
func (m *MockClient) AccessTokenType() op.AccessTokenType {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AccessTokenType")
ret0, _ := ret[0].(op.AccessTokenType)
return ret0
}
// AccessTokenType indicates an expected call of AccessTokenType
func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
}
// ApplicationType mocks base method
func (m *MockClient) ApplicationType() op.ApplicationType {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ApplicationType")
ret0, _ := ret[0].(op.ApplicationType)
return ret0
}
// ApplicationType indicates an expected call of ApplicationType
func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
}
// GetAuthMethod mocks base method
func (m *MockClient) GetAuthMethod() op.AuthMethod {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthMethod")
ret0, _ := ret[0].(op.AuthMethod)
return ret0
}
// GetAuthMethod indicates an expected call of GetAuthMethod
func (mr *MockClientMockRecorder) GetAuthMethod() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthMethod", reflect.TypeOf((*MockClient)(nil).GetAuthMethod))
}
// GetID mocks base method
func (m *MockClient) GetID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetID")
ret0, _ := ret[0].(string)
return ret0
}
// GetID indicates an expected call of GetID
func (mr *MockClientMockRecorder) GetID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
}
// IDTokenLifetime mocks base method
func (m *MockClient) IDTokenLifetime() time.Duration {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IDTokenLifetime")
ret0, _ := ret[0].(time.Duration)
return ret0
}
// IDTokenLifetime indicates an expected call of IDTokenLifetime
func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime))
}
// LoginURL mocks base method
func (m *MockClient) LoginURL(arg0 string) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoginURL", arg0)
ret0, _ := ret[0].(string)
return ret0
}
// LoginURL indicates an expected call of LoginURL
func (mr *MockClientMockRecorder) LoginURL(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockClient)(nil).LoginURL), arg0)
}
// RedirectURIs mocks base method
func (m *MockClient) RedirectURIs() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RedirectURIs")
ret0, _ := ret[0].([]string)
return ret0
}
// RedirectURIs indicates an expected call of RedirectURIs
func (mr *MockClientMockRecorder) RedirectURIs() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIs", reflect.TypeOf((*MockClient)(nil).RedirectURIs))
}

View file

@ -0,0 +1,132 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/oidc/pkg/op (interfaces: Configuration)
// Package mock is a generated GoMock package.
package mock
import (
op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockConfiguration is a mock of Configuration interface
type MockConfiguration struct {
ctrl *gomock.Controller
recorder *MockConfigurationMockRecorder
}
// MockConfigurationMockRecorder is the mock recorder for MockConfiguration
type MockConfigurationMockRecorder struct {
mock *MockConfiguration
}
// NewMockConfiguration creates a new mock instance
func NewMockConfiguration(ctrl *gomock.Controller) *MockConfiguration {
mock := &MockConfiguration{ctrl: ctrl}
mock.recorder = &MockConfigurationMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockConfiguration) EXPECT() *MockConfigurationMockRecorder {
return m.recorder
}
// AuthMethodPostSupported mocks base method
func (m *MockConfiguration) AuthMethodPostSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthMethodPostSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// AuthMethodPostSupported indicates an expected call of AuthMethodPostSupported
func (mr *MockConfigurationMockRecorder) AuthMethodPostSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethodPostSupported", reflect.TypeOf((*MockConfiguration)(nil).AuthMethodPostSupported))
}
// AuthorizationEndpoint mocks base method
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// AuthorizationEndpoint indicates an expected call of AuthorizationEndpoint
func (mr *MockConfigurationMockRecorder) AuthorizationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).AuthorizationEndpoint))
}
// Issuer mocks base method
func (m *MockConfiguration) Issuer() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Issuer")
ret0, _ := ret[0].(string)
return ret0
}
// Issuer indicates an expected call of Issuer
func (mr *MockConfigurationMockRecorder) Issuer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockConfiguration)(nil).Issuer))
}
// KeysEndpoint mocks base method
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeysEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// KeysEndpoint indicates an expected call of KeysEndpoint
func (mr *MockConfigurationMockRecorder) KeysEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeysEndpoint", reflect.TypeOf((*MockConfiguration)(nil).KeysEndpoint))
}
// Port mocks base method
func (m *MockConfiguration) Port() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Port")
ret0, _ := ret[0].(string)
return ret0
}
// Port indicates an expected call of Port
func (mr *MockConfigurationMockRecorder) Port() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Port", reflect.TypeOf((*MockConfiguration)(nil).Port))
}
// TokenEndpoint mocks base method
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TokenEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// TokenEndpoint indicates an expected call of TokenEndpoint
func (mr *MockConfigurationMockRecorder) TokenEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpoint", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpoint))
}
// UserinfoEndpoint mocks base method
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserinfoEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// UserinfoEndpoint indicates an expected call of UserinfoEndpoint
func (mr *MockConfigurationMockRecorder) UserinfoEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserinfoEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserinfoEndpoint))
}

7
pkg/op/mock/generate.go Normal file
View file

@ -0,0 +1,7 @@
package mock
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/caos/oidc/pkg/op Storage
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/caos/oidc/pkg/op Authorizer
//go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer

View file

@ -0,0 +1,79 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/oidc/pkg/op (interfaces: Signer)
// Package mock is a generated GoMock package.
package mock
import (
oidc "github.com/caos/oidc/pkg/oidc"
gomock "github.com/golang/mock/gomock"
go_jose_v2 "gopkg.in/square/go-jose.v2"
reflect "reflect"
)
// MockSigner is a mock of Signer interface
type MockSigner struct {
ctrl *gomock.Controller
recorder *MockSignerMockRecorder
}
// MockSignerMockRecorder is the mock recorder for MockSigner
type MockSignerMockRecorder struct {
mock *MockSigner
}
// NewMockSigner creates a new mock instance
func NewMockSigner(ctrl *gomock.Controller) *MockSigner {
mock := &MockSigner{ctrl: ctrl}
mock.recorder = &MockSignerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockSigner) EXPECT() *MockSignerMockRecorder {
return m.recorder
}
// SignAccessToken mocks base method
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignAccessToken indicates an expected call of SignAccessToken
func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
}
// SignIDToken mocks base method
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignIDToken", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignIDToken indicates an expected call of SignIDToken
func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0)
}
// SignatureAlgorithm mocks base method
func (m *MockSigner) SignatureAlgorithm() go_jose_v2.SignatureAlgorithm {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignatureAlgorithm")
ret0, _ := ret[0].(go_jose_v2.SignatureAlgorithm)
return ret0
}
// SignatureAlgorithm indicates an expected call of SignatureAlgorithm
func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
}

170
pkg/op/mock/storage.mock.go Normal file
View file

@ -0,0 +1,170 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/oidc/pkg/op (interfaces: Storage)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
oidc "github.com/caos/oidc/pkg/oidc"
op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock"
go_jose_v2 "gopkg.in/square/go-jose.v2"
reflect "reflect"
)
// MockStorage is a mock of Storage interface
type MockStorage struct {
ctrl *gomock.Controller
recorder *MockStorageMockRecorder
}
// MockStorageMockRecorder is the mock recorder for MockStorage
type MockStorageMockRecorder struct {
mock *MockStorage
}
// NewMockStorage creates a new mock instance
func NewMockStorage(ctrl *gomock.Controller) *MockStorage {
mock := &MockStorage{ctrl: ctrl}
mock.recorder = &MockStorageMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
return m.recorder
}
// AuthRequestByID mocks base method
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1)
ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AuthRequestByID indicates an expected call of AuthRequestByID
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1)
}
// AuthorizeClientIDSecret mocks base method
func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2)
}
// CreateAuthRequest mocks base method
func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1)
ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateAuthRequest indicates an expected call of CreateAuthRequest
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1)
}
// DeleteAuthRequest mocks base method
func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1)
}
// GetClientByClientID mocks base method
func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1)
ret0, _ := ret[0].(op.Client)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetClientByClientID indicates an expected call of GetClientByClientID
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
}
// GetKeySet mocks base method
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKeySet indicates an expected call of GetKeySet
func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
}
// GetSigningKey mocks base method
func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSigningKey", arg0)
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSigningKey indicates an expected call of GetSigningKey
func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0)
}
// GetUserinfoFromScopes mocks base method
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 []string) (*oidc.Userinfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1)
ret0, _ := ret[0].(*oidc.Userinfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1)
}
// SaveKeyPair mocks base method
func (m *MockStorage) SaveKeyPair(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveKeyPair", arg0)
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SaveKeyPair indicates an expected call of SaveKeyPair
func (mr *MockStorageMockRecorder) SaveKeyPair(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveKeyPair), arg0)
}

View file

@ -0,0 +1,142 @@
package mock
import (
"context"
"errors"
"testing"
"time"
"gopkg.in/square/go-jose.v2"
"github.com/golang/mock/gomock"
"github.com/caos/oidc/pkg/op"
)
func NewStorage(t *testing.T) op.Storage {
return NewMockStorage(gomock.NewController(t))
}
func NewMockStorageExpectValidClientID(t *testing.T) op.Storage {
m := NewStorage(t)
ExpectValidClientID(m)
return m
}
func NewMockStorageExpectInvalidClientID(t *testing.T) op.Storage {
m := NewStorage(t)
ExpectInvalidClientID(m)
return m
}
func NewMockStorageAny(t *testing.T) op.Storage {
m := NewStorage(t)
mockS := m.(*MockStorage)
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).AnyTimes().Return(&ConfClient{}, nil)
mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
return m
}
func NewMockStorageSigningKeyError(t *testing.T) op.Storage {
m := NewStorage(t)
ExpectSigningKeyError(m)
return m
}
func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage {
m := NewStorage(t)
ExpectSigningKeyInvalid(m)
return m
}
func NewMockStorageSigningKey(t *testing.T) op.Storage {
m := NewStorage(t)
ExpectSigningKey(m)
return m
}
func ExpectInvalidClientID(s op.Storage) {
mockS := s.(*MockStorage)
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).Return(nil, errors.New("client not found"))
}
func ExpectValidClientID(s op.Storage) {
mockS := s.(*MockStorage)
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, id string) (op.Client, error) {
var appType op.ApplicationType
var authMethod op.AuthMethod
var accessTokenType op.AccessTokenType
switch id {
case "web_client":
appType = op.ApplicationTypeWeb
authMethod = op.AuthMethodBasic
accessTokenType = op.AccessTokenTypeBearer
case "native_client":
appType = op.ApplicationTypeNative
authMethod = op.AuthMethodNone
accessTokenType = op.AccessTokenTypeBearer
case "useragent_client":
appType = op.ApplicationTypeUserAgent
authMethod = op.AuthMethodBasic
accessTokenType = op.AccessTokenTypeJWT
}
return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
})
}
func ExpectSigningKeyError(s op.Storage) {
mockS := s.(*MockStorage)
mockS.EXPECT().GetSigningKey(gomock.Any()).Return(nil, errors.New("error"))
}
func ExpectSigningKeyInvalid(s op.Storage) {
mockS := s.(*MockStorage)
mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{}, nil)
}
func ExpectSigningKey(s op.Storage) {
mockS := s.(*MockStorage)
mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, nil)
}
type ConfClient struct {
id string
appType op.ApplicationType
authMethod op.AuthMethod
accessTokenType op.AccessTokenType
}
func (c *ConfClient) RedirectURIs() []string {
return []string{
"https://registered.com/callback",
"http://registered.com/callback",
"http://localhost:9999/callback",
"custom://callback",
}
}
func (c *ConfClient) LoginURL(id string) string {
return "login?id=" + id
}
func (c *ConfClient) ApplicationType() op.ApplicationType {
return c.appType
}
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
return c.authMethod
}
func (c *ConfClient) GetID() string {
return c.id
}
func (c *ConfClient) AccessTokenLifetime() time.Duration {
return time.Duration(5 * time.Minute)
}
func (c *ConfClient) IDTokenLifetime() time.Duration {
return time.Duration(5 * time.Minute)
}
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
return c.accessTokenType
}

51
pkg/op/op.go Normal file
View file

@ -0,0 +1,51 @@
package op
import (
"context"
"net/http"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/caos/oidc/pkg/oidc"
)
type OpenIDProvider interface {
Configuration
HandleDiscovery(w http.ResponseWriter, r *http.Request)
HandleAuthorize(w http.ResponseWriter, r *http.Request)
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
HandleExchange(w http.ResponseWriter, r *http.Request)
HandleUserinfo(w http.ResponseWriter, r *http.Request)
HandleKeys(w http.ResponseWriter, r *http.Request)
HttpHandler() *http.Server
}
func CreateRouter(o OpenIDProvider) *mux.Router {
router := mux.NewRouter()
router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery)
router.HandleFunc(o.AuthorizationEndpoint().Relative(), o.HandleAuthorize)
router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", o.HandleAuthorizeCallback)
router.HandleFunc(o.TokenEndpoint().Relative(), o.HandleExchange)
router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo)
router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
return router
}
func Start(ctx context.Context, o OpenIDProvider) {
go func() {
<-ctx.Done()
err := o.HttpHandler().Shutdown(ctx)
if err != nil {
logrus.Error("graceful shutdown of oidc server failed")
}
}()
go func() {
err := o.HttpHandler().ListenAndServe()
if err != nil {
logrus.Panicf("oidc server serve failed: %v", err)
}
}()
logrus.Infof("oidc server is listening on %s", o.Port())
}

13
pkg/op/session.go Normal file
View file

@ -0,0 +1,13 @@
package op
import "github.com/caos/oidc/pkg/oidc"
func NeedsExistingSession(authRequest *oidc.AuthRequest) bool {
if authRequest == nil {
return true
}
if authRequest.Prompt == oidc.PromptNone {
return true
}
return false
}

78
pkg/op/signer.go Normal file
View file

@ -0,0 +1,78 @@
package op
import (
"encoding/json"
"golang.org/x/net/context"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
)
type Signer interface {
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
SignatureAlgorithm() jose.SignatureAlgorithm
}
type idTokenSigner struct {
signer jose.Signer
storage AuthStorage
algorithm jose.SignatureAlgorithm
}
func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) {
s := &idTokenSigner{
storage: storage,
}
if err := s.initialize(ctx); err != nil {
return nil, err
}
return s, nil
}
func (s *idTokenSigner) initialize(ctx context.Context) error {
var key *jose.SigningKey
var err error
key, err = s.storage.GetSigningKey(ctx)
if err != nil {
key, err = s.storage.SaveKeyPair(ctx)
if err != nil {
return err
}
}
s.signer, err = jose.NewSigner(*key, &jose.SignerOptions{})
if err != nil {
return err
}
s.algorithm = key.Algorithm
return nil
}
func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *idTokenSigner) Sign(payload []byte) (string, error) {
result, err := s.signer.Sign(payload)
if err != nil {
return "", err
}
return result.CompactSerialize()
}
func (s *idTokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
return s.algorithm
}

95
pkg/op/signer_test.go Normal file
View file

@ -0,0 +1,95 @@
package op
import (
"testing"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
)
// func TestNewDefaultSigner(t *testing.T) {
// type args struct {
// storage Storage
// }
// tests := []struct {
// name string
// args args
// want Signer
// wantErr bool
// }{
// {
// "err initialize storage fails",
// args{mock.NewMockStorageSigningKeyError(t)},
// nil,
// true,
// },
// {
// "err initialize storage fails",
// args{mock.NewMockStorageSigningKeyInvalid(t)},
// nil,
// true,
// },
// {
// "initialize ok",
// args{mock.NewMockStorageSigningKey(t)},
// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)},
// false,
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// got, err := op.NewDefaultSigner(tt.args.storage)
// if (err != nil) != tt.wantErr {
// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
// return
// }
// if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
// }
// })
// }
// }
func Test_idTokenSigner_Sign(t *testing.T) {
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{})
require.NoError(t, err)
type fields struct {
signer jose.Signer
storage Storage
}
type args struct {
payload []byte
}
tests := []struct {
name string
fields fields
args args
want string
wantErr bool
}{
{
"ok",
fields{signer, nil},
args{[]byte("test")},
"eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww",
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &idTokenSigner{
signer: tt.fields.signer,
storage: tt.fields.storage,
}
got, err := s.Sign(tt.args.payload)
if (err != nil) != tt.wantErr {
t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want)
}
})
}
}

48
pkg/op/storage.go Normal file
View file

@ -0,0 +1,48 @@
package op
import (
"context"
"time"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
)
type AuthStorage interface {
CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error)
AuthRequestByID(context.Context, string) (AuthRequest, error)
DeleteAuthRequest(context.Context, string) error
GetSigningKey(context.Context) (*jose.SigningKey, error)
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
SaveKeyPair(context.Context) (*jose.SigningKey, error)
}
type OPStorage interface {
GetClientByClientID(context.Context, string) (Client, error)
AuthorizeClientIDSecret(context.Context, string, string) error
GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error)
}
type Storage interface {
AuthStorage
OPStorage
}
type AuthRequest interface {
GetID() string
GetACR() string
GetAMR() []string
GetAudience() []string
GetAuthTime() time.Time
GetClientID() string
GetCodeChallenge() *oidc.CodeChallenge
GetNonce() string
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetScopes() []string
GetState() string
GetSubject() string
Done() bool
}

93
pkg/op/token.go Normal file
View file

@ -0,0 +1,93 @@
package op
import (
"time"
"github.com/caos/oidc/pkg/oidc"
)
type TokenCreator interface {
Issuer() string
Signer() Signer
Storage() Storage
Crypto() Crypto
}
func CreateTokenResponse(authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) {
var accessToken string
if createAccessToken {
var err error
accessToken, err = CreateAccessToken(authReq, client, creator)
if err != nil {
return nil, err
}
}
idToken, err := CreateIDToken(creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Signer())
if err != nil {
return nil, err
}
exp := uint64(client.AccessTokenLifetime().Seconds())
return &oidc.AccessTokenResponse{
AccessToken: accessToken,
IDToken: idToken,
TokenType: oidc.BearerToken,
ExpiresIn: exp,
}, nil
}
func CreateAccessToken(authReq AuthRequest, client Client, creator TokenCreator) (string, error) {
if client.AccessTokenType() == AccessTokenTypeJWT {
return CreateJWT(creator.Issuer(), authReq, client, creator.Signer())
}
return CreateBearerToken(authReq, creator.Crypto())
}
func CreateBearerToken(authReq AuthRequest, crypto Crypto) (string, error) {
return crypto.Encrypt(authReq.GetID())
}
func CreateJWT(issuer string, authReq AuthRequest, client Client, signer Signer) (string, error) {
now := time.Now().UTC()
nbf := now
exp := now.Add(client.AccessTokenLifetime())
claims := &oidc.AccessTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: now,
NotBefore: nbf,
}
return signer.SignAccessToken(claims)
}
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
var err error
exp := time.Now().UTC().Add(validity)
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
}
if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
if code != "" {
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
return signer.SignIDToken(claims)
}

151
pkg/op/tokenrequest.go Normal file
View file

@ -0,0 +1,151 @@
package op
import (
"context"
"errors"
"net/http"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
type Exchanger interface {
Issuer() string
Storage() Storage
Decoder() *schema.Decoder
Signer() Signer
Crypto() Crypto
AuthMethodPostSupported() bool
}
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
if err != nil {
ExchangeRequestError(w, r, err)
}
if tokenReq.Code == "" {
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
return
}
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
if err != nil {
ExchangeRequestError(w, r, err)
return
}
resp, err := CreateTokenResponse(authReq, client, exchanger, true, tokenReq.Code)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
utils.MarshalJSON(w, resp)
}
func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) {
err := r.ParseForm()
if err != nil {
return nil, ErrInvalidRequest("error parsing form")
}
tokenReq := new(oidc.AccessTokenRequest)
err = decoder.Decode(tokenReq, r.Form)
if err != nil {
return nil, ErrInvalidRequest("error decoding form")
}
clientID, clientSecret, ok := r.BasicAuth()
if ok {
tokenReq.ClientID = clientID
tokenReq.ClientSecret = clientSecret
}
return tokenReq, nil
}
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
if err != nil {
return nil, nil, err
}
if client.GetID() != authReq.GetClientID() {
return nil, nil, ErrInvalidRequest("invalid auth code")
}
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return nil, nil, ErrInvalidRequest("redirect_uri does no correspond")
}
return authReq, client, nil
}
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
if err != nil {
return nil, nil, err
}
if client.GetAuthMethod() == AuthMethodNone {
authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger)
return authReq, client, err
}
if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("basic not supported")
}
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
if err != nil {
return nil, nil, err
}
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
if err != nil {
return nil, nil, ErrInvalidRequest("invalid code")
}
return authReq, client, nil
}
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
}
func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required")
}
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
if err != nil {
return nil, ErrInvalidRequest("invalid code")
}
if !oidc.VerifyCodeChallenge(authReq.GetCodeChallenge(), tokenReq.CodeVerifier) {
return nil, ErrInvalidRequest("code_challenge invalid")
}
return authReq, nil
}
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
id, err := crypto.Decrypt(code)
if err != nil {
return nil, err
}
return storage.AuthRequestByID(ctx, id)
}
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenRequest, err := ParseTokenExchangeRequest(w, r)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage())
if err != nil {
ExchangeRequestError(w, r, err)
return
}
}
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
return nil, errors.New("Unimplemented") //TODO: impl
}
func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
return errors.New("Unimplemented") //TODO: impl
}

28
pkg/op/userinfo.go Normal file
View file

@ -0,0 +1,28 @@
package op
import (
"net/http"
"github.com/caos/oidc/pkg/utils"
)
type UserinfoProvider interface {
Storage() Storage
}
func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoProvider) {
scopes, err := ScopesFromAccessToken(w, r)
if err != nil {
return
}
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(r.Context(), scopes)
if err != nil {
utils.MarshalJSON(w, err)
return
}
utils.MarshalJSON(w, info)
}
func ScopesFromAccessToken(w http.ResponseWriter, r *http.Request) ([]string, error) {
return []string{}, nil
}