From 42099c82076ddf648b378511b15481b3ea41cfb5 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Tue, 28 Jan 2020 14:29:25 +0100 Subject: [PATCH] refactor and add access types --- example/internal/mock/storage.go | 17 +++++- pkg/oidc/token.go | 9 ++++ pkg/op/authrequest.go | 81 +++++++++++++---------------- pkg/op/client.go | 10 ++++ pkg/op/config_test.go | 29 ++++++++++- pkg/op/mock/authorizer.mock.impl.go | 3 ++ pkg/op/mock/client.mock.go | 43 +++++++++++++++ pkg/op/mock/signer.mock.go | 15 ++++++ pkg/op/mock/storage.mock.impl.go | 24 +++++++-- pkg/op/signer.go | 9 ++++ pkg/op/token.go | 59 ++++++++++++++++++--- pkg/op/tokenrequest.go | 28 +++------- 12 files changed, 250 insertions(+), 77 deletions(-) diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 3581224..37f47b1 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -160,17 +160,21 @@ func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Clie } var appType op.ApplicationType var authMethod op.AuthMethod + var accessTokenType op.AccessTokenType if id == "web" { appType = op.ApplicationTypeWeb authMethod = op.AuthMethodBasic + accessTokenType = op.AccessTokenTypeBearer } else if id == "native" { appType = op.ApplicationTypeNative authMethod = op.AuthMethodNone + accessTokenType = op.AccessTokenTypeBearer } else { appType = op.ApplicationTypeUserAgent authMethod = op.AuthMethodNone + accessTokenType = op.AccessTokenTypeJWT } - return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil + return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil } func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error { @@ -205,6 +209,7 @@ type ConfClient struct { applicationType op.ApplicationType authMethod op.AuthMethod ID string + accessTokenType op.AccessTokenType } func (c *ConfClient) GetID() string { @@ -233,3 +238,13 @@ func (c *ConfClient) ApplicationType() op.ApplicationType { func (c *ConfClient) GetAuthMethod() op.AuthMethod { return c.authMethod } + +func (c *ConfClient) AccessTokenLifetime() time.Duration { + return time.Duration(5 * time.Minute) +} +func (c *ConfClient) IDTokenLifetime() time.Duration { + return time.Duration(5 * time.Minute) +} +func (c *ConfClient) AccessTokenType() op.AccessTokenType { + return c.accessTokenType +} diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index b041c34..2a52a23 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -89,6 +89,15 @@ type Tokens struct { IDToken string } +type AccessTokenClaims struct { + Issuer string + Subject string + Audiences []string + Expiration time.Time + IssuedAt time.Time + NotBefore time.Time +} + func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { hash, err := utils.GetHashAlgorithm(sigAlgorithm) if err != nil { diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go index a63a620..9f9505d 100644 --- a/pkg/op/authrequest.go +++ b/pkg/op/authrequest.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "strings" - "time" "github.com/gorilla/mux" "github.com/gorilla/schema" @@ -36,13 +35,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { return } authReq := new(oidc.AuthRequest) - err = authorizer.Decoder().Decode(authReq, r.Form) if err != nil { AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder()) return } - validation := ValidateAuthRequest if validater, ok := authorizer.(ValidationAuthorizer); ok { validation = validater.ValidateAuthRequest @@ -51,13 +48,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } - req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq) if err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } - client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID()) if err != nil { AuthRequestError(w, r, req, err, authorizer.Encoder()) @@ -157,46 +152,44 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author } func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { - var callback string - if authReq.GetResponseType() == oidc.ResponseTypeCode { - code, err := BuildAuthRequestCode(authReq, authorizer.Crypto()) - if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) - return - } - callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code) - if authReq.GetState() != "" { - callback = callback + "&state=" + authReq.GetState() - } - } else { - var accessToken string - var err error - var exp uint64 - if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly { - accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer()) - if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) - return - } - } - idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer()) - if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) - return - } - resp := &oidc.AccessTokenResponse{ - AccessToken: accessToken, - IDToken: idToken, - TokenType: oidc.BearerToken, - ExpiresIn: exp, - } - params, err := utils.URLEncodeResponse(resp, authorizer.Encoder()) - if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) - return - } - callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params) + client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) + if err != nil { + } + if authReq.GetResponseType() == oidc.ResponseTypeCode { + AuthResponseCode(w, r, authReq, authorizer) + return + } + AuthResponseToken(w, r, authReq, authorizer, client) + return +} + +func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { + code, err := BuildAuthRequestCode(authReq, authorizer.Crypto()) + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code) + if authReq.GetState() != "" { + callback = callback + "&state=" + authReq.GetState() + } + http.Redirect(w, r, callback, http.StatusFound) +} + +func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) { + createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly + resp, err := CreateTokenResponse(authReq, client, authorizer, createAccessToken, "") + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + params, err := utils.URLEncodeResponse(resp, authorizer.Encoder()) + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params) http.Redirect(w, r, callback, http.StatusFound) } diff --git a/pkg/op/client.go b/pkg/op/client.go index cbd69fb..33c30a4 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -1,9 +1,14 @@ package op +import "time" + const ( ApplicationTypeWeb ApplicationType = iota ApplicationTypeUserAgent ApplicationTypeNative + + AccessTokenTypeBearer AccessTokenType = iota + AccessTokenTypeJWT ) type Client interface { @@ -12,6 +17,9 @@ type Client interface { ApplicationType() ApplicationType GetAuthMethod() AuthMethod LoginURL(string) string + AccessTokenType() AccessTokenType + AccessTokenLifetime() time.Duration + IDTokenLifetime() time.Duration } func IsConfidentialType(c Client) bool { @@ -21,3 +29,5 @@ func IsConfidentialType(c Client) bool { type ApplicationType int type AuthMethod string + +type AccessTokenType int diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go index b5f508b..56cf2eb 100644 --- a/pkg/op/config_test.go +++ b/pkg/op/config_test.go @@ -2,6 +2,8 @@ package op import "testing" +import "os" + func TestValidateIssuer(t *testing.T) { type args struct { issuer string @@ -54,7 +56,7 @@ func TestValidateIssuer(t *testing.T) { { "localhost with http ok", args{"http://localhost:9999"}, - false, + true, }, } for _, tt := range tests { @@ -65,3 +67,28 @@ func TestValidateIssuer(t *testing.T) { }) } } + +func TestValidateIssuerDevLocalAllowed(t *testing.T) { + type args struct { + issuer string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "localhost with http ok", + args{"http://localhost:9999"}, + false, + }, + } + os.Setenv("CAOS_OIDC_DEV", "") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { + t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index e90a11e..0091877 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -70,6 +70,9 @@ type Sig struct{} func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) { return "", nil } +func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) { + return "", nil +} func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm { return jose.HS256 } diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index e856860..9ae2201 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -8,6 +8,7 @@ import ( op "github.com/caos/oidc/pkg/op" gomock "github.com/golang/mock/gomock" reflect "reflect" + time "time" ) // MockClient is a mock of Client interface @@ -33,6 +34,34 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { return m.recorder } +// AccessTokenLifetime mocks base method +func (m *MockClient) AccessTokenLifetime() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AccessTokenLifetime") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// AccessTokenLifetime indicates an expected call of AccessTokenLifetime +func (mr *MockClientMockRecorder) AccessTokenLifetime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenLifetime", reflect.TypeOf((*MockClient)(nil).AccessTokenLifetime)) +} + +// AccessTokenType mocks base method +func (m *MockClient) AccessTokenType() op.AccessTokenType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AccessTokenType") + ret0, _ := ret[0].(op.AccessTokenType) + return ret0 +} + +// AccessTokenType indicates an expected call of AccessTokenType +func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType)) +} + // ApplicationType mocks base method func (m *MockClient) ApplicationType() op.ApplicationType { m.ctrl.T.Helper() @@ -75,6 +104,20 @@ func (mr *MockClientMockRecorder) GetID() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID)) } +// IDTokenLifetime mocks base method +func (m *MockClient) IDTokenLifetime() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IDTokenLifetime") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// IDTokenLifetime indicates an expected call of IDTokenLifetime +func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime)) +} + // LoginURL mocks base method func (m *MockClient) LoginURL(arg0 string) string { m.ctrl.T.Helper() diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index d9f6613..5c7b669 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -34,6 +34,21 @@ func (m *MockSigner) EXPECT() *MockSignerMockRecorder { return m.recorder } +// SignAccessToken mocks base method +func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SignAccessToken", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SignAccessToken indicates an expected call of SignAccessToken +func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0) +} + // SignIDToken mocks base method func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 74ded21..7cd62b9 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" "gopkg.in/square/go-jose.v2" @@ -64,18 +65,22 @@ func ExpectValidClientID(s op.Storage) { func(_ context.Context, id string) (op.Client, error) { var appType op.ApplicationType var authMethod op.AuthMethod + var accessTokenType op.AccessTokenType switch id { case "web_client": appType = op.ApplicationTypeWeb authMethod = op.AuthMethodBasic + accessTokenType = op.AccessTokenTypeBearer case "native_client": appType = op.ApplicationTypeNative authMethod = op.AuthMethodNone + accessTokenType = op.AccessTokenTypeBearer case "useragent_client": appType = op.ApplicationTypeUserAgent authMethod = op.AuthMethodBasic + accessTokenType = op.AccessTokenTypeJWT } - return &ConfClient{id: id, appType: appType, authMethod: authMethod}, nil + return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil }) } @@ -95,9 +100,10 @@ func ExpectSigningKey(s op.Storage) { } type ConfClient struct { - id string - appType op.ApplicationType - authMethod op.AuthMethod + id string + appType op.ApplicationType + authMethod op.AuthMethod + accessTokenType op.AccessTokenType } func (c *ConfClient) RedirectURIs() []string { @@ -124,3 +130,13 @@ func (c *ConfClient) GetAuthMethod() op.AuthMethod { func (c *ConfClient) GetID() string { return c.id } + +func (c *ConfClient) AccessTokenLifetime() time.Duration { + return time.Duration(5 * time.Minute) +} +func (c *ConfClient) IDTokenLifetime() time.Duration { + return time.Duration(5 * time.Minute) +} +func (c *ConfClient) AccessTokenType() op.AccessTokenType { + return c.accessTokenType +} diff --git a/pkg/op/signer.go b/pkg/op/signer.go index 1aa9619..6235931 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -11,6 +11,7 @@ import ( type Signer interface { SignIDToken(claims *oidc.IDTokenClaims) (string, error) + SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) SignatureAlgorithm() jose.SignatureAlgorithm } @@ -56,6 +57,14 @@ func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) return s.Sign(payload) } +func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) { + payload, err := json.Marshal(claims) + if err != nil { + return "", err + } + return s.Sign(payload) +} + func (s *idTokenSigner) Sign(payload []byte) (string, error) { result, err := s.signer.Sign(payload) if err != nil { diff --git a/pkg/op/token.go b/pkg/op/token.go index fd759b2..7c1dedc 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -1,17 +1,64 @@ package op import ( - "fmt" "time" "github.com/caos/oidc/pkg/oidc" ) -func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) { - var err error - accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes()) - exp := time.Duration(5 * time.Minute) - return accessToken, uint64(exp.Seconds()), err +type TokenCreator interface { + Issuer() string + Signer() Signer + Storage() Storage + Crypto() Crypto +} + +func CreateTokenResponse(authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) { + var accessToken string + if createAccessToken { + var err error + accessToken, err = CreateAccessToken(authReq, client, creator) + if err != nil { + return nil, err + } + } + idToken, err := CreateIDToken(creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Signer()) + if err != nil { + return nil, err + } + exp := uint64(client.AccessTokenLifetime().Seconds()) + return &oidc.AccessTokenResponse{ + AccessToken: accessToken, + IDToken: idToken, + TokenType: oidc.BearerToken, + ExpiresIn: exp, + }, nil +} + +func CreateAccessToken(authReq AuthRequest, client Client, creator TokenCreator) (string, error) { + if client.AccessTokenType() == AccessTokenTypeJWT { + return CreateJWT(creator.Issuer(), authReq, client, creator.Signer()) + } + return CreateBearerToken(authReq, creator.Crypto()) +} + +func CreateBearerToken(authReq AuthRequest, crypto Crypto) (string, error) { + return crypto.Encrypt(authReq.GetID()) +} + +func CreateJWT(issuer string, authReq AuthRequest, client Client, signer Signer) (string, error) { + now := time.Now().UTC() + nbf := now + exp := now.Add(client.AccessTokenLifetime()) + claims := &oidc.AccessTokenClaims{ + Issuer: issuer, + Subject: authReq.GetSubject(), + Audiences: authReq.GetAudience(), + Expiration: exp, + IssuedAt: now, + NotBefore: nbf, + } + return signer.SignAccessToken(claims) } func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) { diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index 8b8ad61..cf32432 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -31,35 +31,21 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { ExchangeRequestError(w, r, ErrInvalidRequest("code missing")) return } - - authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) + authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { ExchangeRequestError(w, r, err) return } - err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID()) if err != nil { ExchangeRequestError(w, r, err) return } - accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer()) + resp, err := CreateTokenResponse(authReq, client, exchanger, true, tokenReq.Code) if err != nil { ExchangeRequestError(w, r, err) return } - idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer()) - if err != nil { - ExchangeRequestError(w, r, err) - return - } - - resp := &oidc.AccessTokenResponse{ - AccessToken: accessToken, - IDToken: idToken, - TokenType: oidc.BearerToken, - ExpiresIn: exp, - } utils.MarshalJSON(w, resp) } @@ -82,18 +68,18 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac return tokenReq, nil } -func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) { +func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) { authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger) if err != nil { - return nil, err + return nil, nil, err } if client.GetID() != authReq.GetClientID() { - return nil, ErrInvalidRequest("invalid auth code") + return nil, nil, ErrInvalidRequest("invalid auth code") } if tokenReq.RedirectURI != authReq.GetRedirectURI() { - return nil, ErrInvalidRequest("redirect_uri does no correspond") + return nil, nil, ErrInvalidRequest("redirect_uri does no correspond") } - return authReq, nil + return authReq, client, nil } func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {