refactor and add access types

This commit is contained in:
Livio Amstutz 2020-01-28 14:29:25 +01:00
parent be6737328c
commit 42099c8207
12 changed files with 250 additions and 77 deletions

View file

@ -160,17 +160,21 @@ func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Clie
}
var appType op.ApplicationType
var authMethod op.AuthMethod
var accessTokenType op.AccessTokenType
if id == "web" {
appType = op.ApplicationTypeWeb
authMethod = op.AuthMethodBasic
accessTokenType = op.AccessTokenTypeBearer
} else if id == "native" {
appType = op.ApplicationTypeNative
authMethod = op.AuthMethodNone
accessTokenType = op.AccessTokenTypeBearer
} else {
appType = op.ApplicationTypeUserAgent
authMethod = op.AuthMethodNone
accessTokenType = op.AccessTokenTypeJWT
}
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
}
func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
@ -205,6 +209,7 @@ type ConfClient struct {
applicationType op.ApplicationType
authMethod op.AuthMethod
ID string
accessTokenType op.AccessTokenType
}
func (c *ConfClient) GetID() string {
@ -233,3 +238,13 @@ func (c *ConfClient) ApplicationType() op.ApplicationType {
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
return c.authMethod
}
func (c *ConfClient) AccessTokenLifetime() time.Duration {
return time.Duration(5 * time.Minute)
}
func (c *ConfClient) IDTokenLifetime() time.Duration {
return time.Duration(5 * time.Minute)
}
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
return c.accessTokenType
}

View file

@ -89,6 +89,15 @@ type Tokens struct {
IDToken string
}
type AccessTokenClaims struct {
Issuer string
Subject string
Audiences []string
Expiration time.Time
IssuedAt time.Time
NotBefore time.Time
}
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
if err != nil {

View file

@ -6,7 +6,6 @@ import (
"fmt"
"net/http"
"strings"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/schema"
@ -36,13 +35,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
return
}
authReq := new(oidc.AuthRequest)
err = authorizer.Decoder().Decode(authReq, r.Form)
if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
return
}
validation := ValidateAuthRequest
if validater, ok := authorizer.(ValidationAuthorizer); ok {
validation = validater.ValidateAuthRequest
@ -51,13 +48,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, err, authorizer.Encoder())
@ -157,46 +152,44 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
}
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
var callback string
if authReq.GetResponseType() == oidc.ResponseTypeCode {
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
if authReq.GetState() != "" {
callback = callback + "&state=" + authReq.GetState()
}
} else {
var accessToken string
var err error
var exp uint64
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
}
idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
resp := &oidc.AccessTokenResponse{
AccessToken: accessToken,
IDToken: idToken,
TokenType: oidc.BearerToken,
ExpiresIn: exp,
}
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
if err != nil {
}
if authReq.GetResponseType() == oidc.ResponseTypeCode {
AuthResponseCode(w, r, authReq, authorizer)
return
}
AuthResponseToken(w, r, authReq, authorizer, client)
return
}
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
if authReq.GetState() != "" {
callback = callback + "&state=" + authReq.GetState()
}
http.Redirect(w, r, callback, http.StatusFound)
}
func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) {
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := CreateTokenResponse(authReq, client, authorizer, createAccessToken, "")
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
http.Redirect(w, r, callback, http.StatusFound)
}

View file

@ -1,9 +1,14 @@
package op
import "time"
const (
ApplicationTypeWeb ApplicationType = iota
ApplicationTypeUserAgent
ApplicationTypeNative
AccessTokenTypeBearer AccessTokenType = iota
AccessTokenTypeJWT
)
type Client interface {
@ -12,6 +17,9 @@ type Client interface {
ApplicationType() ApplicationType
GetAuthMethod() AuthMethod
LoginURL(string) string
AccessTokenType() AccessTokenType
AccessTokenLifetime() time.Duration
IDTokenLifetime() time.Duration
}
func IsConfidentialType(c Client) bool {
@ -21,3 +29,5 @@ func IsConfidentialType(c Client) bool {
type ApplicationType int
type AuthMethod string
type AccessTokenType int

View file

@ -2,6 +2,8 @@ package op
import "testing"
import "os"
func TestValidateIssuer(t *testing.T) {
type args struct {
issuer string
@ -54,7 +56,7 @@ func TestValidateIssuer(t *testing.T) {
{
"localhost with http ok",
args{"http://localhost:9999"},
false,
true,
},
}
for _, tt := range tests {
@ -65,3 +67,28 @@ func TestValidateIssuer(t *testing.T) {
})
}
}
func TestValidateIssuerDevLocalAllowed(t *testing.T) {
type args struct {
issuer string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
"localhost with http ok",
args{"http://localhost:9999"},
false,
},
}
os.Setenv("CAOS_OIDC_DEV", "")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -70,6 +70,9 @@ type Sig struct{}
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
return "", nil
}
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
return jose.HS256
}

View file

@ -8,6 +8,7 @@ import (
op "github.com/caos/oidc/pkg/op"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
)
// MockClient is a mock of Client interface
@ -33,6 +34,34 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder {
return m.recorder
}
// AccessTokenLifetime mocks base method
func (m *MockClient) AccessTokenLifetime() time.Duration {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AccessTokenLifetime")
ret0, _ := ret[0].(time.Duration)
return ret0
}
// AccessTokenLifetime indicates an expected call of AccessTokenLifetime
func (mr *MockClientMockRecorder) AccessTokenLifetime() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenLifetime", reflect.TypeOf((*MockClient)(nil).AccessTokenLifetime))
}
// AccessTokenType mocks base method
func (m *MockClient) AccessTokenType() op.AccessTokenType {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AccessTokenType")
ret0, _ := ret[0].(op.AccessTokenType)
return ret0
}
// AccessTokenType indicates an expected call of AccessTokenType
func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
}
// ApplicationType mocks base method
func (m *MockClient) ApplicationType() op.ApplicationType {
m.ctrl.T.Helper()
@ -75,6 +104,20 @@ func (mr *MockClientMockRecorder) GetID() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
}
// IDTokenLifetime mocks base method
func (m *MockClient) IDTokenLifetime() time.Duration {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IDTokenLifetime")
ret0, _ := ret[0].(time.Duration)
return ret0
}
// IDTokenLifetime indicates an expected call of IDTokenLifetime
func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime))
}
// LoginURL mocks base method
func (m *MockClient) LoginURL(arg0 string) string {
m.ctrl.T.Helper()

View file

@ -34,6 +34,21 @@ func (m *MockSigner) EXPECT() *MockSignerMockRecorder {
return m.recorder
}
// SignAccessToken mocks base method
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SignAccessToken indicates an expected call of SignAccessToken
func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
}
// SignIDToken mocks base method
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
m.ctrl.T.Helper()

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"testing"
"time"
"gopkg.in/square/go-jose.v2"
@ -64,18 +65,22 @@ func ExpectValidClientID(s op.Storage) {
func(_ context.Context, id string) (op.Client, error) {
var appType op.ApplicationType
var authMethod op.AuthMethod
var accessTokenType op.AccessTokenType
switch id {
case "web_client":
appType = op.ApplicationTypeWeb
authMethod = op.AuthMethodBasic
accessTokenType = op.AccessTokenTypeBearer
case "native_client":
appType = op.ApplicationTypeNative
authMethod = op.AuthMethodNone
accessTokenType = op.AccessTokenTypeBearer
case "useragent_client":
appType = op.ApplicationTypeUserAgent
authMethod = op.AuthMethodBasic
accessTokenType = op.AccessTokenTypeJWT
}
return &ConfClient{id: id, appType: appType, authMethod: authMethod}, nil
return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
})
}
@ -95,9 +100,10 @@ func ExpectSigningKey(s op.Storage) {
}
type ConfClient struct {
id string
appType op.ApplicationType
authMethod op.AuthMethod
id string
appType op.ApplicationType
authMethod op.AuthMethod
accessTokenType op.AccessTokenType
}
func (c *ConfClient) RedirectURIs() []string {
@ -124,3 +130,13 @@ func (c *ConfClient) GetAuthMethod() op.AuthMethod {
func (c *ConfClient) GetID() string {
return c.id
}
func (c *ConfClient) AccessTokenLifetime() time.Duration {
return time.Duration(5 * time.Minute)
}
func (c *ConfClient) IDTokenLifetime() time.Duration {
return time.Duration(5 * time.Minute)
}
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
return c.accessTokenType
}

View file

@ -11,6 +11,7 @@ import (
type Signer interface {
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
SignatureAlgorithm() jose.SignatureAlgorithm
}
@ -56,6 +57,14 @@ func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error)
return s.Sign(payload)
}
func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
return s.Sign(payload)
}
func (s *idTokenSigner) Sign(payload []byte) (string, error) {
result, err := s.signer.Sign(payload)
if err != nil {

View file

@ -1,17 +1,64 @@
package op
import (
"fmt"
"time"
"github.com/caos/oidc/pkg/oidc"
)
func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) {
var err error
accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes())
exp := time.Duration(5 * time.Minute)
return accessToken, uint64(exp.Seconds()), err
type TokenCreator interface {
Issuer() string
Signer() Signer
Storage() Storage
Crypto() Crypto
}
func CreateTokenResponse(authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) {
var accessToken string
if createAccessToken {
var err error
accessToken, err = CreateAccessToken(authReq, client, creator)
if err != nil {
return nil, err
}
}
idToken, err := CreateIDToken(creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Signer())
if err != nil {
return nil, err
}
exp := uint64(client.AccessTokenLifetime().Seconds())
return &oidc.AccessTokenResponse{
AccessToken: accessToken,
IDToken: idToken,
TokenType: oidc.BearerToken,
ExpiresIn: exp,
}, nil
}
func CreateAccessToken(authReq AuthRequest, client Client, creator TokenCreator) (string, error) {
if client.AccessTokenType() == AccessTokenTypeJWT {
return CreateJWT(creator.Issuer(), authReq, client, creator.Signer())
}
return CreateBearerToken(authReq, creator.Crypto())
}
func CreateBearerToken(authReq AuthRequest, crypto Crypto) (string, error) {
return crypto.Encrypt(authReq.GetID())
}
func CreateJWT(issuer string, authReq AuthRequest, client Client, signer Signer) (string, error) {
now := time.Now().UTC()
nbf := now
exp := now.Add(client.AccessTokenLifetime())
claims := &oidc.AccessTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: now,
NotBefore: nbf,
}
return signer.SignAccessToken(claims)
}
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {

View file

@ -31,35 +31,21 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
return
}
authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
if err != nil {
ExchangeRequestError(w, r, err)
return
}
accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer())
resp, err := CreateTokenResponse(authReq, client, exchanger, true, tokenReq.Code)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer())
if err != nil {
ExchangeRequestError(w, r, err)
return
}
resp := &oidc.AccessTokenResponse{
AccessToken: accessToken,
IDToken: idToken,
TokenType: oidc.BearerToken,
ExpiresIn: exp,
}
utils.MarshalJSON(w, resp)
}
@ -82,18 +68,18 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac
return tokenReq, nil
}
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
if err != nil {
return nil, err
return nil, nil, err
}
if client.GetID() != authReq.GetClientID() {
return nil, ErrInvalidRequest("invalid auth code")
return nil, nil, ErrInvalidRequest("invalid auth code")
}
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return nil, ErrInvalidRequest("redirect_uri does no correspond")
return nil, nil, ErrInvalidRequest("redirect_uri does no correspond")
}
return authReq, nil
return authReq, client, nil
}
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {