This commit is contained in:
Livio Amstutz 2019-12-17 10:03:09 +01:00
parent 3d276c59b4
commit d3d9e676c0
9 changed files with 126 additions and 34 deletions

View file

@ -16,9 +16,7 @@ type AuthStorage struct {
key *rsa.PrivateKey key *rsa.PrivateKey
} }
type OPStorage struct{} func NewAuthStorage() op.Storage {
func NewAuthStorage() op.AuthStorage {
reader := rand.Reader reader := rand.Reader
bitSize := 2048 bitSize := 2048
key, err := rsa.GenerateKey(reader, bitSize) key, err := rsa.GenerateKey(reader, bitSize)
@ -106,6 +104,7 @@ func (a *AuthRequest) GetSubject() string {
var ( var (
a = &AuthRequest{} a = &AuthRequest{}
t bool
) )
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) { func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
@ -116,15 +115,20 @@ func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthReque
Method: authReq.CodeChallengeMethod, Method: authReq.CodeChallengeMethod,
} }
} }
t = false
return a, nil return a, nil
} }
func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) { func (s *AuthStorage) AuthRequestByCode(string) (op.AuthRequest, error) {
return a, nil return a, nil
} }
func (s *AuthStorage) DeleteAuthRequestAndCode(string, string) error { func (s *AuthStorage) DeleteAuthRequest(string) error {
t = true
return nil return nil
} }
func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) { func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
if id != "id" || t {
return nil, errors.New("not found")
}
return a, nil return a, nil
} }
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) { func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
@ -142,7 +146,7 @@ func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
}, nil }, nil
} }
func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) { func (s *AuthStorage) GetClientByClientID(id string) (op.Client, error) {
if id == "none" { if id == "none" {
return nil, errors.New("not found") return nil, errors.New("not found")
} }
@ -161,10 +165,11 @@ func (s *OPStorage) GetClientByClientID(id string) (op.Client, error) {
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
} }
func (s *OPStorage) AuthorizeClientIDSecret(id string, _ string) error { func (s *AuthStorage) AuthorizeClientIDSecret(id string, _ string) error {
return nil return nil
} }
func (s *OPStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) {
func (s *AuthStorage) GetUserinfoFromScopes([]string) (*oidc.Userinfo, error) {
return &oidc.Userinfo{ return &oidc.Userinfo{
Subject: a.GetSubject(), Subject: a.GetSubject(),
Address: &oidc.UserinfoAddress{ Address: &oidc.UserinfoAddress{

View file

@ -2,7 +2,12 @@ package main
import ( import (
"context" "context"
"crypto/sha256"
"html/template"
"log" "log"
"net/http"
"github.com/gorilla/mux"
"github.com/caos/oidc/example/internal/mock" "github.com/caos/oidc/example/internal/mock"
"github.com/caos/oidc/pkg/op" "github.com/caos/oidc/pkg/op"
@ -11,17 +16,50 @@ import (
func main() { func main() {
ctx := context.Background() ctx := context.Background()
config := &op.Config{ config := &op.Config{
Issuer: "http://localhost:9998/", Issuer: "http://localhost:9998/",
CryptoKey: sha256.Sum256([]byte("test")),
Port: "9998", Port: "9998",
} }
authStorage := mock.NewAuthStorage() storage := mock.NewAuthStorage()
opStorage := &mock.OPStorage{} handler, err := op.NewDefaultOP(config, storage, op.WithCustomTokenEndpoint("test"))
handler, err := op.NewDefaultOP(config, authStorage, opStorage, op.WithCustomTokenEndpoint("test"))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
router := handler.HttpHandler().Handler.(*mux.Router)
router.Methods("GET").Path("/login").HandlerFunc(HandleLogin)
router.Methods("POST").Path("/login").HandlerFunc(HandleCallback)
op.Start(ctx, handler) op.Start(ctx, handler)
<-ctx.Done() <-ctx.Done()
}
func HandleLogin(w http.ResponseWriter, r *http.Request) {
tpl := `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Login</title>
</head>
<body>
<form method="POST" action="/login">
<input name="client"/>
<button type="submit">Login</button>
</form>
</body>
</html>`
t, err := template.New("login").Parse(tpl)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
err = t.Execute(w, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func HandleCallback(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
client := r.FormValue("client")
http.Redirect(w, r, "/authorize/"+client, http.StatusFound)
} }

1
go.mod
View file

@ -3,6 +3,7 @@ module github.com/caos/oidc
go 1.13 go 1.13
require ( require (
github.com/caos/utils/crypto v0.0.0-20191210140001-db9d0ce57f21
github.com/golang/mock v1.3.1 github.com/golang/mock v1.3.1
github.com/golang/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.3.2 // indirect
github.com/google/uuid v1.1.1 github.com/google/uuid v1.1.1

3
go.sum
View file

@ -1,4 +1,7 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/caos/utils v0.0.0-20191210140001-db9d0ce57f21 h1:SZTPN44SgN2/pMYDOzxmsG3kk7IaioLI8ujgk5Atp5M=
github.com/caos/utils/crypto v0.0.0-20191210140001-db9d0ce57f21 h1:BtzwMln/KEyQCQ0n/pzInTkBzw3fYUW3x+8EodCbXEo=
github.com/caos/utils/crypto v0.0.0-20191210140001-db9d0ce57f21/go.mod h1:X4Iy86UaICQ1ocpN/DfZHfXyaeXqBg8+Y6lYBOH9kN8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View file

@ -18,6 +18,7 @@ type Authorizer interface {
Decoder() *schema.Decoder Decoder() *schema.Decoder
Encoder() *schema.Encoder Encoder() *schema.Encoder
Signer() Signer Signer() Signer
Crypto() Crypto
Issuer() string Issuer() string
} }
@ -152,7 +153,12 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
var callback string var callback string
if authReq.GetResponseType() == oidc.ResponseTypeCode { if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), authReq.GetCode()) code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
} else { } else {
var accessToken string var accessToken string
var err error var err error
@ -160,12 +166,14 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly { if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer()) accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
} }
} }
idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer()) idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
} }
resp := &oidc.AccessTokenResponse{ resp := &oidc.AccessTokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
@ -175,9 +183,14 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri
} }
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder()) params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
} }
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params) callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
} }
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)
} }
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
return crypto.Encrypt(authReq.GetID())
}

24
pkg/op/crypto.go Normal file
View file

@ -0,0 +1,24 @@
package op
import "github.com/caos/utils/crypto"
type Crypto interface {
Encrypt(string) (string, error)
Decrypt(string) (string, error)
}
type aesCrypto struct {
key string
}
func NewAESCrypto(key [32]byte) Crypto {
return &aesCrypto{key: string(key[:32])}
}
func (c *aesCrypto) Encrypt(s string) (string, error) {
return crypto.EncryptAES(s, c.key)
}
func (c *aesCrypto) Decrypt(s string) (string, error) {
return crypto.DecryptAES(s, c.key)
}

View file

@ -39,6 +39,7 @@ type DefaultOP struct {
discoveryConfig *oidc.DiscoveryConfiguration discoveryConfig *oidc.DiscoveryConfiguration
storage Storage storage Storage
signer Signer signer Signer
crypto Crypto
http *http.Server http *http.Server
decoder *schema.Decoder decoder *schema.Decoder
encoder *schema.Encoder encoder *schema.Encoder
@ -47,6 +48,7 @@ type DefaultOP struct {
type Config struct { type Config struct {
Issuer string Issuer string
IDTokenValidity time.Duration IDTokenValidity time.Duration
CryptoKey [32]byte
// ScopesSupported: oidc.SupportedScopes, // ScopesSupported: oidc.SupportedScopes,
// ResponseTypesSupported: responseTypes, // ResponseTypesSupported: responseTypes,
// GrantTypesSupported: oidc.SupportedGrantTypes, // GrantTypesSupported: oidc.SupportedGrantTypes,
@ -99,27 +101,19 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
} }
} }
func NewDefaultOP(config *Config, authStorage AuthStorage, opStorage OPStorage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) { func NewDefaultOP(config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
err := ValidateIssuer(config.Issuer) err := ValidateIssuer(config.Issuer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
storage := struct {
AuthStorage
OPStorage
}{
AuthStorage: authStorage,
OPStorage: opStorage,
}
p := &DefaultOP{ p := &DefaultOP{
config: config, config: config,
storage: storage, storage: storage,
endpoints: DefaultEndpoints, endpoints: DefaultEndpoints,
} }
p.signer, err = NewDefaultSigner(authStorage) p.signer, err = NewDefaultSigner(storage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -142,6 +136,8 @@ func NewDefaultOP(config *Config, authStorage AuthStorage, opStorage OPStorage,
p.encoder = schema.NewEncoder() p.encoder = schema.NewEncoder()
p.crypto = NewAESCrypto(config.CryptoKey)
return p, nil return p, nil
} }
@ -197,6 +193,10 @@ func (p *DefaultOP) Signer() Signer {
return p.signer return p.signer
} }
func (p *DefaultOP) Crypto() Crypto {
return p.crypto
}
func (p *DefaultOP) IDTokenValidity() time.Duration { func (p *DefaultOP) IDTokenValidity() time.Duration {
if p.config.IDTokenValidity == 0 { if p.config.IDTokenValidity == 0 {
p.config.IDTokenValidity = DefaultIDTokenValidity p.config.IDTokenValidity = DefaultIDTokenValidity

View file

@ -11,8 +11,7 @@ import (
type AuthStorage interface { type AuthStorage interface {
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error) CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
AuthRequestByID(string) (AuthRequest, error) AuthRequestByID(string) (AuthRequest, error)
AuthRequestByCode(string) (AuthRequest, error) DeleteAuthRequest(string) error
DeleteAuthRequestAndCode(string, string) error
GetSigningKey() (*jose.SigningKey, error) GetSigningKey() (*jose.SigningKey, error)
GetKeySet() (*jose.JSONWebKeySet, error) GetKeySet() (*jose.JSONWebKeySet, error)

View file

@ -17,6 +17,7 @@ type Exchanger interface {
Storage() Storage Storage() Storage
Decoder() *schema.Decoder Decoder() *schema.Decoder
Signer() Signer Signer() Signer
Crypto() Crypto
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
} }
@ -36,7 +37,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
return return
} }
err = exchanger.Storage().DeleteAuthRequestAndCode(authReq.GetID(), tokenReq.Code) err = exchanger.Storage().DeleteAuthRequest(authReq.GetID())
if err != nil { if err != nil {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
@ -116,9 +117,9 @@ func AuthorizeClient(tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Au
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code) authReq, err := AuthRequestByCode(tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, ErrInvalidRequest("invalid code")
} }
return authReq, client, nil return authReq, client, nil
} }
@ -131,7 +132,7 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
if tokenReq.CodeVerifier == "" { if tokenReq.CodeVerifier == "" {
return nil, ErrInvalidRequest("code_challenge required") return nil, ErrInvalidRequest("code_challenge required")
} }
authReq, err := storage.AuthRequestByCode(tokenReq.Code) authReq, err := AuthRequestByCode(tokenReq.Code, nil, storage)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("invalid code") return nil, ErrInvalidRequest("invalid code")
} }
@ -141,6 +142,14 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, storage AuthStora
return authReq, nil return authReq, nil
} }
func AuthRequestByCode(code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
id, err := crypto.Decrypt(code)
if err != nil {
return nil, err
}
return storage.AuthRequestByID(id)
}
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenRequest, err := ParseTokenExchangeRequest(w, r) tokenRequest, err := ParseTokenExchangeRequest(w, r)
if err != nil { if err != nil {