refactoring

This commit is contained in:
Livio Amstutz 2019-12-06 10:42:17 +01:00
parent a793e77679
commit 310220d38e
17 changed files with 346 additions and 149 deletions

View file

@ -3,7 +3,6 @@ package op
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
@ -19,6 +18,7 @@ type Authorizer interface {
Decoder() *schema.Decoder
Encoder() *schema.Encoder
Signer() Signer
Issuer() string
// ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
}
@ -37,7 +37,7 @@ type ValidationAuthorizer interface {
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err := r.ParseForm()
if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"))
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder())
// AuthRequestError(w, r, nil, )
return
}
@ -45,7 +45,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err = authorizer.Decoder().Decode(authReq, r.Form)
if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)))
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
return
}
@ -54,19 +54,19 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
validation = validater.ValidateAuthRequest
}
if err := validation(authReq, authorizer.Storage()); err != nil {
AuthRequestError(w, r, authReq, err)
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
req, err := authorizer.Storage().CreateAuthRequest(authReq)
if err != nil {
AuthRequestError(w, r, authReq, err)
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, err)
AuthRequestError(w, r, req, err, authorizer.Encoder())
return
}
RedirectToLogin(req.GetID(), client, w, r)
@ -100,7 +100,7 @@ func ValidateAuthReqScopes(scopes []string) error {
return nil
}
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage Storage) error {
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
if uri == "" {
return ErrInvalidRequest("redirect_uri must not be empty")
}
@ -144,7 +144,7 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
authReq, err := authorizer.Storage().AuthRequestByID(id)
if err != nil {
AuthRequestError(w, r, nil, err)
AuthRequestError(w, r, nil, err, authorizer.Encoder())
return
}
AuthResponse(authReq, authorizer, w, r)
@ -153,29 +153,32 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
var callback string
if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test")
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), authReq.GetCode())
} else {
var accessToken string
var err error
var exp uint64
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
accessToken, err = CreateAccessToken()
accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
if err != nil {
}
}
idToken, err := CreateIDToken("", authReq, time.Duration(0), accessToken, authorizer.Signer())
idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
if err != nil {
}
resp := &oidc.AccessTokenResponse{
AccessToken: accessToken,
IDToken: idToken,
TokenType: "Bearer",
TokenType: oidc.BearerToken,
ExpiresIn: exp,
}
values := make(map[string][]string)
authorizer.Encoder().Encode(resp, values)
v := url.Values(values)
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode())
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
if err != nil {
}
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
}
http.Redirect(w, r, callback, http.StatusFound)
}

View file

@ -148,7 +148,7 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
uri string
clientID string
responseType oidc.ResponseType
storage op.Storage
storage op.OPStorage
}
tests := []struct {
name string

View file

@ -18,6 +18,8 @@ const (
authMethodBasic = "client_secret_basic"
authMethodPost = "client_secret_post"
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
)
var (
@ -28,7 +30,6 @@ var (
Userinfo: defaultUserinfoEndpoint,
JwksURI: defaultKeysEndpoint,
}
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
)
type DefaultOP struct {
@ -250,5 +251,5 @@ func (p *DefaultOP) handleTokenExchange(w http.ResponseWriter, r *http.Request)
}
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok"))
}

View file

@ -27,7 +27,7 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati
// ClaimsSupported: oidc.SupportedClaims,
IDTokenSigningAlgValuesSupported: sigAlgorithms(s),
SubjectTypesSupported: subjectTypes(c),
TokenEndpointAuthMethodsSupported: authMethods(c),
TokenEndpointAuthMethodsSupported: authMethods(c.AuthMethodBasicSupported(), c.AuthMethodPostSupported()),
}
}
@ -68,12 +68,14 @@ func subjectTypes(c Configuration) []string {
return []string{"public"} //TODO: config
}
func authMethods(c Configuration) []string {
func authMethods(basic, post bool) []string {
authMethods := make([]string, 0, 2)
if c.AuthMethodBasicSupported() {
if basic {
// if c.AuthMethodBasicSupported() {
authMethods = append(authMethods, authMethodBasic)
}
if c.AuthMethodPostSupported() {
if post {
// if c.AuthMethodPostSupported() {
authMethods = append(authMethods, authMethodPost)
}
return authMethods

View file

@ -1,4 +1,4 @@
package op_test
package op
import (
"net/http"
@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
)
func TestDiscover(t *testing.T) {
@ -31,7 +30,7 @@ func TestDiscover(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.Discover(tt.args.w, tt.args.config)
Discover(tt.args.w, tt.args.config)
rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
@ -41,8 +40,8 @@ func TestDiscover(t *testing.T) {
func TestCreateDiscoveryConfig(t *testing.T) {
type args struct {
c op.Configuration
s op.Signer
c Configuration
s Signer
}
tests := []struct {
name string
@ -53,9 +52,156 @@ func TestCreateDiscoveryConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
if got := CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
}
})
}
}
func Test_scopes(t *testing.T) {
type args struct {
c Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("scopes() = %v, want %v", got, tt.want)
}
})
}
}
func Test_responseTypes(t *testing.T) {
type args struct {
c Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := responseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
}
})
}
}
func Test_grantTypes(t *testing.T) {
type args struct {
c Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := grantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
}
})
}
}
// func Test_sigAlgorithms(t *testing.T) {
// type args struct {
// s Signer
// }
// tests := []struct {
// name string
// args args
// want []string
// }{
// {
// "",
// args{},
// []string{"RS256"},
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// if got := sigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
// t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
// }
// })
// }
// }
// func Test_subjectTypes(t *testing.T) {
// type args struct {
// c Configuration
// }
// tests := []struct {
// name string
// args args
// want []string
// }{
// {
// "none",
// args{func()}
// }
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// if got := subjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
// t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
// }
// })
// }
// }
func Test_authMethods(t *testing.T) {
type args struct {
basic bool
post bool
}
tests := []struct {
name string
args args
want []string
}{
{
"none",
args{false, false},
[]string{},
},
{
"basic",
args{true, false},
[]string{authMethodBasic},
},
{
"post",
args{false, true},
[]string{authMethodPost},
},
{
"basic and post",
args{true, true},
[]string{authMethodBasic, authMethodPost},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := authMethods(tt.args.basic, tt.args.post); !reflect.DeepEqual(got, tt.want) {
t.Errorf("authMethods() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,8 +1,10 @@
package op
import (
"fmt"
"net/http"
"net/url"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
@ -13,6 +15,21 @@ const (
ServerError errorType = "server_error"
)
var (
ErrInvalidRequest = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrServerError = func(description string) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
type errorType string
type ErrAuthRequest interface {
@ -21,7 +38,7 @@ type ErrAuthRequest interface {
GetState() string
}
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error) {
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
if authReq == nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -30,27 +47,23 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
e, ok := err.(*OAuthError)
if !ok {
e = new(OAuthError)
e.ErrorType = ServerError
e.Description = err.Error()
}
e.state = authReq.GetState()
params, err := utils.URLEncodeResponse(e, encoder)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
url := authReq.GetRedirectURI()
if authReq.GetResponseType() == oidc.ResponseTypeCode {
url += "?"
url += "?" + params
} else {
url += "#"
}
var errorType errorType
var description string
if e, ok := err.(*OAuthError); ok {
errorType = e.ErrorType
description = e.Description
} else {
errorType = ServerError
description = err.Error()
}
url += "error=" + string(errorType)
if description != "" {
url += "&error_description=" + description
}
if authReq.GetState() != "" {
url += "&state=" + authReq.GetState()
url += "#" + params
}
http.Redirect(w, r, url, http.StatusFound)
}
@ -67,50 +80,11 @@ func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
}
type OAuthError struct {
ErrorType errorType `json:"error"`
Description string `json:"description"`
}
var (
ErrInvalidRequest = func(description string, args ...interface{}) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrServerError = func(description string, args ...interface{}) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest) {
if authReq == nil {
http.Error(w, e.Error(), http.StatusBadRequest)
return
}
if authReq.GetRedirectURI() == "" {
http.Error(w, e.Error(), http.StatusBadRequest)
return
}
callback := authReq.GetRedirectURI()
if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback += "?"
} else {
callback += "#"
}
callback += "error=" + string(e.ErrorType)
if e.Description != "" {
callback += "&error_description=" + url.QueryEscape(e.Description)
}
if authReq.GetState() != "" {
callback += "&state=" + authReq.GetState()
}
http.Redirect(w, r, callback, http.StatusFound)
ErrorType errorType `json:"error" schema:"error"`
Description string `json:"description" schema:"description"`
state string `json:"state" schema:"state"`
}
func (e *OAuthError) Error() string {
return ""
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
}

View file

@ -62,6 +62,20 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
}
// Issuer mocks base method
func (m *MockAuthorizer) Issuer() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Issuer")
ret0, _ := ret[0].(string)
return ret0
}
// Issuer indicates an expected call of Issuer
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
}
// Signer mocks base method
func (m *MockAuthorizer) Signer() op.Signer {
m.ctrl.T.Helper()

View file

@ -12,14 +12,12 @@ import (
type OpenIDProvider interface {
Configuration
// Storage() Storage
HandleDiscovery(w http.ResponseWriter, r *http.Request)
HandleAuthorize(w http.ResponseWriter, r *http.Request)
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
HandleExchange(w http.ResponseWriter, r *http.Request)
HandleUserinfo(w http.ResponseWriter, r *http.Request)
HandleKeys(w http.ResponseWriter, r *http.Request)
// Storage() Storage
HttpHandler() *http.Server
}

View file

@ -36,9 +36,11 @@ type AuthRequest interface {
GetAudience() []string
GetAuthTime() time.Time
GetClientID() string
GetCode() string
GetNonce() string
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetScopes() []string
GetState() string
GetSubject() string
}

46
pkg/op/token.go Normal file
View file

@ -0,0 +1,46 @@
package op
import (
"fmt"
"time"
"github.com/caos/oidc/pkg/oidc"
)
func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) {
var err error
accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes())
exp := time.Duration(5 * time.Minute)
return accessToken, uint64(exp.Seconds()), err
}
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
var err error
exp := time.Now().UTC().Add(validity)
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
}
if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
if code != "" {
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
return signer.SignIDToken(claims)
}

View file

@ -54,12 +54,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ExchangeRequestError(w, r, err)
return
}
accessToken, err := CreateAccessToken()
accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer())
if err != nil {
ExchangeRequestError(w, r, err)
return
}
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, exchanger.Signer())
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer())
if err != nil {
ExchangeRequestError(w, r, err)
return
@ -68,39 +68,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
resp := &oidc.AccessTokenResponse{
AccessToken: accessToken,
IDToken: idToken,
TokenType: oidc.BearerToken,
ExpiresIn: exp,
}
utils.MarshalJSON(w, resp)
}
func CreateAccessToken() (string, error) {
return "accessToken", nil
}
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken string, signer Signer) (string, error) {
var err error
exp := time.Now().UTC().Add(validity)
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
}
if accessToken != "" {
claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
return signer.SignIDToken(claims)
}
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
if tokenReq.ClientID == "" {
if !exchanger.AuthMethodBasicSupported() {