From d3d9e676c09b2d084d4a884ec2ed86682d843874 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Tue, 17 Dec 2019 10:03:09 +0100 Subject: [PATCH] crypto --- example/internal/mock/storage.go | 19 ++++++----- example/server/default/default.go | 52 ++++++++++++++++++++++++++----- go.mod | 1 + go.sum | 3 ++ pkg/op/authrequest.go | 21 ++++++++++--- pkg/op/crypto.go | 24 ++++++++++++++ pkg/op/default_op.go | 20 ++++++------ pkg/op/storage.go | 3 +- pkg/op/tokenrequest.go | 17 +++++++--- 9 files changed, 126 insertions(+), 34 deletions(-) create mode 100644 pkg/op/crypto.go diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 690f9e2..e3c8f33 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -16,9 +16,7 @@ type AuthStorage struct { key *rsa.PrivateKey } -type OPStorage struct{} - -func NewAuthStorage() op.AuthStorage { +func NewAuthStorage() op.Storage { reader := rand.Reader bitSize := 2048 key, err := rsa.GenerateKey(reader, bitSize) @@ -106,6 +104,7 @@ func (a *AuthRequest) GetSubject() string { var ( a = &AuthRequest{} + t bool ) func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) { @@ -116,15 +115,20 @@ func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthReque Method: authReq.CodeChallengeMethod, } } + t = false return a, nil } func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) { return a, nil } -func (s *AuthStorage) DeleteAuthRequestAndCode(string, string) error { +func (s *AuthStorage) DeleteAuthRequest(string) error { + t = true return nil } func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) { + if id != "id" || t { + return nil, errors.New("not found") + } return a, nil } func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) { @@ -142,7 +146,7 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) { }, nil } -func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) { +func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) { if id == "none" { return nil, errors.New("not found") } @@ -161,10 +165,11 @@ func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) { return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil } -func (s *OPStorage) AuthorizeClientIDSecret(id string, _ string) error { +func (s *AuthStorage) AuthorizeClientIDSecret(id string, _ string) error { return nil } -func (s *OPStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) { + +func (s *AuthStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) { return &oidc.Userinfo{ Subject: a.GetSubject(), Address: &oidc.UserinfoAddress{ diff --git a/example/server/default/default.go b/example/server/default/default.go index af61a45..3ad6feb 100644 --- a/example/server/default/default.go +++ b/example/server/default/default.go @@ -2,7 +2,12 @@ package main import ( "context" + "crypto/sha256" + "html/template" "log" + "net/http" + + "github.com/gorilla/mux" "github.com/caos/oidc/example/internal/mock" "github.com/caos/oidc/pkg/op" @@ -11,17 +16,50 @@ import ( func main() { ctx := context.Background() config := &op.Config{ - Issuer: "http://localhost:9998/", - - Port: "9998", + Issuer: "http://localhost:9998/", + CryptoKey: sha256.Sum256([]byte("test")), + Port: "9998", } - authStorage := mock.NewAuthStorage() - opStorage := &mock.OPStorage{} - handler, err := op.NewDefaultOP(config, authStorage, opStorage, op.WithCustomTokenEndpoint("test")) + storage := mock.NewAuthStorage() + handler, err := op.NewDefaultOP(config, storage, op.WithCustomTokenEndpoint("test")) if err != nil { log.Fatal(err) } + router := handler.HttpHandler().Handler.(*mux.Router) + router.Methods("GET").Path("/login").HandlerFunc(HandleLogin) + router.Methods("POST").Path("/login").HandlerFunc(HandleCallback) op.Start(ctx, handler) <-ctx.Done() - +} + +func HandleLogin(w http.ResponseWriter, r *http.Request) { + tpl := ` + + + + + Login + + +
+ + +
+ + ` + t, err := template.New("login").Parse(tpl) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + err = t.Execute(w, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func HandleCallback(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + client := r.FormValue("client") + http.Redirect(w, r, "/authorize/"+client, http.StatusFound) } diff --git a/go.mod b/go.mod index b35882a..bcf9786 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/caos/oidc go 1.13 require ( + github.com/caos/utils/crypto v0.0.0-20191210140001-db9d0ce57f21 github.com/golang/mock v1.3.1 github.com/golang/protobuf v1.3.2 // indirect github.com/google/uuid v1.1.1 diff --git a/go.sum b/go.sum index 54a2ca8..b613331 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,7 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/caos/utils v0.0.0-20191210140001-db9d0ce57f21 h1:SZTPN44SgN2/pMYDOzxmsG3kk7IaioLI8ujgk5Atp5M= +github.com/caos/utils/crypto v0.0.0-20191210140001-db9d0ce57f21 h1:BtzwMln/KEyQCQ0n/pzInTkBzw3fYUW3x+8EodCbXEo= +github.com/caos/utils/crypto v0.0.0-20191210140001-db9d0ce57f21/go.mod h1:X4Iy86UaICQ1ocpN/DfZHfXyaeXqBg8+Y6lYBOH9kN8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index 43da556..bdfa585 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -18,6 +18,7 @@ type Authorizer interface { Decoder() *schema.Decoder Encoder() *schema.Encoder Signer() Signer + Crypto() Crypto Issuer() string } @@ -152,7 +153,12 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { var callback string if authReq.GetResponseType() == oidc.ResponseTypeCode { - callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), authReq.GetCode()) + code, err := BuildAuthRequestCode(authReq, authorizer.Crypto()) + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code) } else { var accessToken string var err error @@ -160,12 +166,14 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly { accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer()) if err != nil { - + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return } } idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer()) if err != nil { - + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return } resp := &oidc.AccessTokenResponse{ AccessToken: accessToken, @@ -175,9 +183,14 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri } params, err := utils.URLEncodeResponse(resp, authorizer.Encoder()) if err != nil { - + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return } callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params) } http.Redirect(w, r, callback, http.StatusFound) } + +func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) { + return crypto.Encrypt(authReq.GetID()) +} diff --git a/pkg/op/crypto.go b/pkg/op/crypto.go new file mode 100644 index 0000000..420c32f --- /dev/null +++ b/pkg/op/crypto.go @@ -0,0 +1,24 @@ +package op + +import "github.com/caos/utils/crypto" + +type Crypto interface { + Encrypt(string) (string, error) + Decrypt(string) (string, error) +} + +type aesCrypto struct { + key string +} + +func NewAESCrypto(key [32]byte) Crypto { + return &aesCrypto{key: string(key[:32])} +} + +func (c *aesCrypto) Encrypt(s string) (string, error) { + return crypto.EncryptAES(s, c.key) +} + +func (c *aesCrypto) Decrypt(s string) (string, error) { + return crypto.DecryptAES(s, c.key) +} diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go index 783db82..baccc2d 100644 --- a/pkg/op/default_op.go +++ b/pkg/op/default_op.go @@ -39,6 +39,7 @@ type DefaultOP struct { discoveryConfig *oidc.DiscoveryConfiguration storage Storage signer Signer + crypto Crypto http *http.Server decoder *schema.Decoder encoder *schema.Encoder @@ -47,6 +48,7 @@ type DefaultOP struct { type Config struct { Issuer string IDTokenValidity time.Duration + CryptoKey [32]byte // ScopesSupported: oidc.SupportedScopes, // ResponseTypesSupported: responseTypes, // GrantTypesSupported: oidc.SupportedGrantTypes, @@ -99,27 +101,19 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts { } } -func NewDefaultOP(config *Config, authStorage AuthStorage, opStorage OPStorage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) { +func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) { err := ValidateIssuer(config.Issuer) if err != nil { return nil, err } - storage := struct { - AuthStorage - OPStorage - }{ - AuthStorage: authStorage, - OPStorage: opStorage, - } - p := &DefaultOP{ config: config, storage: storage, endpoints: DefaultEndpoints, } - p.signer, err = NewDefaultSigner(authStorage) + p.signer, err = NewDefaultSigner(storage) if err != nil { return nil, err } @@ -142,6 +136,8 @@ func NewDefaultOP(config *Config, authStorage AuthStorage, opStorage OPStorage, p.encoder = schema.NewEncoder() + p.crypto = NewAESCrypto(config.CryptoKey) + return p, nil } @@ -197,6 +193,10 @@ func (p *DefaultOP) Signer() Signer { return p.signer } +func (p *DefaultOP) Crypto() Crypto { + return p.crypto +} + func (p *DefaultOP) IDTokenValidity() time.Duration { if p.config.IDTokenValidity == 0 { p.config.IDTokenValidity = DefaultIDTokenValidity diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 8ec7aea..0971532 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -11,8 +11,7 @@ import ( type AuthStorage interface { CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error) AuthRequestByID(string) (AuthRequest, error) - AuthRequestByCode(string) (AuthRequest, error) - DeleteAuthRequestAndCode(string, string) error + DeleteAuthRequest(string) error GetSigningKey() (*jose.SigningKey, error) GetKeySet() (*jose.JSONWebKeySet, error) diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index f895d8c..935589f 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -17,6 +17,7 @@ type Exchanger interface { Storage() Storage Decoder() *schema.Decoder Signer() Signer + Crypto() Crypto AuthMethodPostSupported() bool } @@ -36,7 +37,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { return } - err = exchanger.Storage().DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code) + err = exchanger.Storage().DeleteAuthRequest(authReq.GetID()) if err != nil { ExchangeRequestError(w, r, err) return @@ -116,9 +117,9 @@ func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Au if err != nil { return nil, nil, err } - authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code) + authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) if err != nil { - return nil, nil, err + return nil, nil, ErrInvalidRequest("invalid code") } return authReq, client, nil } @@ -131,7 +132,7 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora if tokenReq.CodeVerifier == "" { return nil, ErrInvalidRequest("code_challenge required") } - authReq, err := storage.AuthRequestByCode(tokenReq.Code) + authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage) if err != nil { return nil, ErrInvalidRequest("invalid code") } @@ -141,6 +142,14 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora return authReq, nil } +func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) { + id, err := crypto.Decrypt(code) + if err != nil { + return nil, err + } + return storage.AuthRequestByID(id) +} + func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { tokenRequest, err := ParseTokenExchangeRequest(w, r) if err != nil {