This commit is contained in:
Livio Amstutz 2019-12-17 10:03:09 +01:00
parent 3d276c59b4
commit d3d9e676c0
9 changed files with 126 additions and 34 deletions

View file

@ -18,6 +18,7 @@ type Authorizer interface {
Decoder() *schema.Decoder
Encoder() *schema.Encoder
Signer() Signer
Crypto() Crypto
Issuer() string
}
@ -152,7 +153,12 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
var callback string
if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), authReq.GetCode())
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
} else {
var accessToken string
var err error
@ -160,12 +166,14 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
}
idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
resp := &oidc.AccessTokenResponse{
AccessToken: accessToken,
@ -175,9 +183,14 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
}
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
}
http.Redirect(w, r, callback, http.StatusFound)
}
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
return crypto.Encrypt(authReq.GetID())
}

24
pkg/op/crypto.go Normal file
View file

@ -0,0 +1,24 @@
package op
import "github.com/caos/utils/crypto"
type Crypto interface {
Encrypt(string) (string, error)
Decrypt(string) (string, error)
}
type aesCrypto struct {
key string
}
func NewAESCrypto(key [32]byte) Crypto {
return &aesCrypto{key: string(key[:32])}
}
func (c *aesCrypto) Encrypt(s string) (string, error) {
return crypto.EncryptAES(s, c.key)
}
func (c *aesCrypto) Decrypt(s string) (string, error) {
return crypto.DecryptAES(s, c.key)
}

View file

@ -39,6 +39,7 @@ type DefaultOP struct {
discoveryConfig *oidc.DiscoveryConfiguration
storage Storage
signer Signer
crypto Crypto
http *http.Server
decoder *schema.Decoder
encoder *schema.Encoder
@ -47,6 +48,7 @@ type DefaultOP struct {
type Config struct {
Issuer string
IDTokenValidity time.Duration
CryptoKey [32]byte
// ScopesSupported: oidc.SupportedScopes,
// ResponseTypesSupported: responseTypes,
// GrantTypesSupported: oidc.SupportedGrantTypes,
@ -99,27 +101,19 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
}
}
func NewDefaultOP(config *Config, authStorage AuthStorage, opStorage OPStorage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
err := ValidateIssuer(config.Issuer)
if err != nil {
return nil, err
}
storage := struct {
AuthStorage
OPStorage
}{
AuthStorage: authStorage,
OPStorage: opStorage,
}
p := &DefaultOP{
config: config,
storage: storage,
endpoints: DefaultEndpoints,
}
p.signer, err = NewDefaultSigner(authStorage)
p.signer, err = NewDefaultSigner(storage)
if err != nil {
return nil, err
}
@ -142,6 +136,8 @@ func NewDefaultOP(config *Config, authStorage AuthStorage, opStorage OPStorage,
p.encoder = schema.NewEncoder()
p.crypto = NewAESCrypto(config.CryptoKey)
return p, nil
}
@ -197,6 +193,10 @@ func (p *DefaultOP) Signer() Signer {
return p.signer
}
func (p *DefaultOP) Crypto() Crypto {
return p.crypto
}
func (p *DefaultOP) IDTokenValidity() time.Duration {
if p.config.IDTokenValidity == 0 {
p.config.IDTokenValidity = DefaultIDTokenValidity

View file

@ -11,8 +11,7 @@ import (
type AuthStorage interface {
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
AuthRequestByID(string) (AuthRequest, error)
AuthRequestByCode(string) (AuthRequest, error)
DeleteAuthRequestAndCode(string, string) error
DeleteAuthRequest(string) error
GetSigningKey() (*jose.SigningKey, error)
GetKeySet() (*jose.JSONWebKeySet, error)

View file

@ -17,6 +17,7 @@ type Exchanger interface {
Storage() Storage
Decoder() *schema.Decoder
Signer() Signer
Crypto() Crypto
AuthMethodPostSupported() bool
}
@ -36,7 +37,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
return
}
err = exchanger.Storage().DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code)
err = exchanger.Storage().DeleteAuthRequest(authReq.GetID())
if err != nil {
ExchangeRequestError(w, r, err)
return
@ -116,9 +117,9 @@ func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Au
if err != nil {
return nil, nil, err
}
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
if err != nil {
return nil, nil, err
return nil, nil, ErrInvalidRequest("invalid code")
}
return authReq, client, nil
}
@ -131,7 +132,7 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required")
}
authReq, err := storage.AuthRequestByCode(tokenReq.Code)
authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage)
if err != nil {
return nil, ErrInvalidRequest("invalid code")
}
@ -141,6 +142,14 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
return authReq, nil
}
func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
id, err := crypto.Decrypt(code)
if err != nil {
return nil, err
}
return storage.AuthRequestByID(id)
}
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenRequest, err := ParseTokenExchangeRequest(w, r)
if err != nil {