refactor and add access types
This commit is contained in:
parent
be6737328c
commit
42099c8207
12 changed files with 250 additions and 77 deletions
|
@ -160,17 +160,21 @@ func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Clie
|
||||||
}
|
}
|
||||||
var appType op.ApplicationType
|
var appType op.ApplicationType
|
||||||
var authMethod op.AuthMethod
|
var authMethod op.AuthMethod
|
||||||
|
var accessTokenType op.AccessTokenType
|
||||||
if id == "web" {
|
if id == "web" {
|
||||||
appType = op.ApplicationTypeWeb
|
appType = op.ApplicationTypeWeb
|
||||||
authMethod = op.AuthMethodBasic
|
authMethod = op.AuthMethodBasic
|
||||||
|
accessTokenType = op.AccessTokenTypeBearer
|
||||||
} else if id == "native" {
|
} else if id == "native" {
|
||||||
appType = op.ApplicationTypeNative
|
appType = op.ApplicationTypeNative
|
||||||
authMethod = op.AuthMethodNone
|
authMethod = op.AuthMethodNone
|
||||||
|
accessTokenType = op.AccessTokenTypeBearer
|
||||||
} else {
|
} else {
|
||||||
appType = op.ApplicationTypeUserAgent
|
appType = op.ApplicationTypeUserAgent
|
||||||
authMethod = op.AuthMethodNone
|
authMethod = op.AuthMethodNone
|
||||||
|
accessTokenType = op.AccessTokenTypeJWT
|
||||||
}
|
}
|
||||||
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
|
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
|
func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
|
||||||
|
@ -205,6 +209,7 @@ type ConfClient struct {
|
||||||
applicationType op.ApplicationType
|
applicationType op.ApplicationType
|
||||||
authMethod op.AuthMethod
|
authMethod op.AuthMethod
|
||||||
ID string
|
ID string
|
||||||
|
accessTokenType op.AccessTokenType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConfClient) GetID() string {
|
func (c *ConfClient) GetID() string {
|
||||||
|
@ -233,3 +238,13 @@ func (c *ConfClient) ApplicationType() op.ApplicationType {
|
||||||
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
||||||
return c.authMethod
|
return c.authMethod
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConfClient) AccessTokenLifetime() time.Duration {
|
||||||
|
return time.Duration(5 * time.Minute)
|
||||||
|
}
|
||||||
|
func (c *ConfClient) IDTokenLifetime() time.Duration {
|
||||||
|
return time.Duration(5 * time.Minute)
|
||||||
|
}
|
||||||
|
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
|
||||||
|
return c.accessTokenType
|
||||||
|
}
|
||||||
|
|
|
@ -89,6 +89,15 @@ type Tokens struct {
|
||||||
IDToken string
|
IDToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AccessTokenClaims struct {
|
||||||
|
Issuer string
|
||||||
|
Subject string
|
||||||
|
Audiences []string
|
||||||
|
Expiration time.Time
|
||||||
|
IssuedAt time.Time
|
||||||
|
NotBefore time.Time
|
||||||
|
}
|
||||||
|
|
||||||
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||||
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/gorilla/schema"
|
"github.com/gorilla/schema"
|
||||||
|
@ -36,13 +35,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authReq := new(oidc.AuthRequest)
|
authReq := new(oidc.AuthRequest)
|
||||||
|
|
||||||
err = authorizer.Decoder().Decode(authReq, r.Form)
|
err = authorizer.Decoder().Decode(authReq, r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
|
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
validation := ValidateAuthRequest
|
validation := ValidateAuthRequest
|
||||||
if validater, ok := authorizer.(ValidationAuthorizer); ok {
|
if validater, ok := authorizer.(ValidationAuthorizer); ok {
|
||||||
validation = validater.ValidateAuthRequest
|
validation = validater.ValidateAuthRequest
|
||||||
|
@ -51,13 +48,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
|
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
|
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
||||||
|
@ -157,46 +152,44 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||||
var callback string
|
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
}
|
||||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||||
|
AuthResponseCode(w, r, authReq, authorizer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
AuthResponseToken(w, r, authReq, authorizer, client)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||||
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
|
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
|
callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
|
||||||
if authReq.GetState() != "" {
|
if authReq.GetState() != "" {
|
||||||
callback = callback + "&state=" + authReq.GetState()
|
callback = callback + "&state=" + authReq.GetState()
|
||||||
}
|
}
|
||||||
} else {
|
http.Redirect(w, r, callback, http.StatusFound)
|
||||||
var accessToken string
|
}
|
||||||
var err error
|
|
||||||
var exp uint64
|
func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) {
|
||||||
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
|
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
|
||||||
accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
|
resp, err := CreateTokenResponse(authReq, client, authorizer, createAccessToken, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
|
|
||||||
if err != nil {
|
|
||||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp := &oidc.AccessTokenResponse{
|
|
||||||
AccessToken: accessToken,
|
|
||||||
IDToken: idToken,
|
|
||||||
TokenType: oidc.BearerToken,
|
|
||||||
ExpiresIn: exp,
|
|
||||||
}
|
|
||||||
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
|
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
|
callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
|
||||||
}
|
|
||||||
http.Redirect(w, r, callback, http.StatusFound)
|
http.Redirect(w, r, callback, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ApplicationTypeWeb ApplicationType = iota
|
ApplicationTypeWeb ApplicationType = iota
|
||||||
ApplicationTypeUserAgent
|
ApplicationTypeUserAgent
|
||||||
ApplicationTypeNative
|
ApplicationTypeNative
|
||||||
|
|
||||||
|
AccessTokenTypeBearer AccessTokenType = iota
|
||||||
|
AccessTokenTypeJWT
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
|
@ -12,6 +17,9 @@ type Client interface {
|
||||||
ApplicationType() ApplicationType
|
ApplicationType() ApplicationType
|
||||||
GetAuthMethod() AuthMethod
|
GetAuthMethod() AuthMethod
|
||||||
LoginURL(string) string
|
LoginURL(string) string
|
||||||
|
AccessTokenType() AccessTokenType
|
||||||
|
AccessTokenLifetime() time.Duration
|
||||||
|
IDTokenLifetime() time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsConfidentialType(c Client) bool {
|
func IsConfidentialType(c Client) bool {
|
||||||
|
@ -21,3 +29,5 @@ func IsConfidentialType(c Client) bool {
|
||||||
type ApplicationType int
|
type ApplicationType int
|
||||||
|
|
||||||
type AuthMethod string
|
type AuthMethod string
|
||||||
|
|
||||||
|
type AccessTokenType int
|
||||||
|
|
|
@ -2,6 +2,8 @@ package op
|
||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
|
import "os"
|
||||||
|
|
||||||
func TestValidateIssuer(t *testing.T) {
|
func TestValidateIssuer(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
issuer string
|
issuer string
|
||||||
|
@ -54,7 +56,7 @@ func TestValidateIssuer(t *testing.T) {
|
||||||
{
|
{
|
||||||
"localhost with http ok",
|
"localhost with http ok",
|
||||||
args{"http://localhost:9999"},
|
args{"http://localhost:9999"},
|
||||||
false,
|
true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -65,3 +67,28 @@ func TestValidateIssuer(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateIssuerDevLocalAllowed(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
issuer string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"localhost with http ok",
|
||||||
|
args{"http://localhost:9999"},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
os.Setenv("CAOS_OIDC_DEV", "")
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -70,6 +70,9 @@ type Sig struct{}
|
||||||
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
|
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
|
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||||
return jose.HS256
|
return jose.HS256
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
op "github.com/caos/oidc/pkg/op"
|
op "github.com/caos/oidc/pkg/op"
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
time "time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockClient is a mock of Client interface
|
// MockClient is a mock of Client interface
|
||||||
|
@ -33,6 +34,34 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder {
|
||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccessTokenLifetime mocks base method
|
||||||
|
func (m *MockClient) AccessTokenLifetime() time.Duration {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AccessTokenLifetime")
|
||||||
|
ret0, _ := ret[0].(time.Duration)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessTokenLifetime indicates an expected call of AccessTokenLifetime
|
||||||
|
func (mr *MockClientMockRecorder) AccessTokenLifetime() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenLifetime", reflect.TypeOf((*MockClient)(nil).AccessTokenLifetime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessTokenType mocks base method
|
||||||
|
func (m *MockClient) AccessTokenType() op.AccessTokenType {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "AccessTokenType")
|
||||||
|
ret0, _ := ret[0].(op.AccessTokenType)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessTokenType indicates an expected call of AccessTokenType
|
||||||
|
func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
|
||||||
|
}
|
||||||
|
|
||||||
// ApplicationType mocks base method
|
// ApplicationType mocks base method
|
||||||
func (m *MockClient) ApplicationType() op.ApplicationType {
|
func (m *MockClient) ApplicationType() op.ApplicationType {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
@ -75,6 +104,20 @@ func (mr *MockClientMockRecorder) GetID() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IDTokenLifetime mocks base method
|
||||||
|
func (m *MockClient) IDTokenLifetime() time.Duration {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "IDTokenLifetime")
|
||||||
|
ret0, _ := ret[0].(time.Duration)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDTokenLifetime indicates an expected call of IDTokenLifetime
|
||||||
|
func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime))
|
||||||
|
}
|
||||||
|
|
||||||
// LoginURL mocks base method
|
// LoginURL mocks base method
|
||||||
func (m *MockClient) LoginURL(arg0 string) string {
|
func (m *MockClient) LoginURL(arg0 string) string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -34,6 +34,21 @@ func (m *MockSigner) EXPECT() *MockSignerMockRecorder {
|
||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignAccessToken mocks base method
|
||||||
|
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignAccessToken indicates an expected call of SignAccessToken
|
||||||
|
func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
// SignIDToken mocks base method
|
// SignIDToken mocks base method
|
||||||
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
|
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
|
@ -64,18 +65,22 @@ func ExpectValidClientID(s op.Storage) {
|
||||||
func(_ context.Context, id string) (op.Client, error) {
|
func(_ context.Context, id string) (op.Client, error) {
|
||||||
var appType op.ApplicationType
|
var appType op.ApplicationType
|
||||||
var authMethod op.AuthMethod
|
var authMethod op.AuthMethod
|
||||||
|
var accessTokenType op.AccessTokenType
|
||||||
switch id {
|
switch id {
|
||||||
case "web_client":
|
case "web_client":
|
||||||
appType = op.ApplicationTypeWeb
|
appType = op.ApplicationTypeWeb
|
||||||
authMethod = op.AuthMethodBasic
|
authMethod = op.AuthMethodBasic
|
||||||
|
accessTokenType = op.AccessTokenTypeBearer
|
||||||
case "native_client":
|
case "native_client":
|
||||||
appType = op.ApplicationTypeNative
|
appType = op.ApplicationTypeNative
|
||||||
authMethod = op.AuthMethodNone
|
authMethod = op.AuthMethodNone
|
||||||
|
accessTokenType = op.AccessTokenTypeBearer
|
||||||
case "useragent_client":
|
case "useragent_client":
|
||||||
appType = op.ApplicationTypeUserAgent
|
appType = op.ApplicationTypeUserAgent
|
||||||
authMethod = op.AuthMethodBasic
|
authMethod = op.AuthMethodBasic
|
||||||
|
accessTokenType = op.AccessTokenTypeJWT
|
||||||
}
|
}
|
||||||
return &ConfClient{id: id, appType: appType, authMethod: authMethod}, nil
|
return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,6 +103,7 @@ type ConfClient struct {
|
||||||
id string
|
id string
|
||||||
appType op.ApplicationType
|
appType op.ApplicationType
|
||||||
authMethod op.AuthMethod
|
authMethod op.AuthMethod
|
||||||
|
accessTokenType op.AccessTokenType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConfClient) RedirectURIs() []string {
|
func (c *ConfClient) RedirectURIs() []string {
|
||||||
|
@ -124,3 +130,13 @@ func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
||||||
func (c *ConfClient) GetID() string {
|
func (c *ConfClient) GetID() string {
|
||||||
return c.id
|
return c.id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConfClient) AccessTokenLifetime() time.Duration {
|
||||||
|
return time.Duration(5 * time.Minute)
|
||||||
|
}
|
||||||
|
func (c *ConfClient) IDTokenLifetime() time.Duration {
|
||||||
|
return time.Duration(5 * time.Minute)
|
||||||
|
}
|
||||||
|
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
|
||||||
|
return c.accessTokenType
|
||||||
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
|
|
||||||
type Signer interface {
|
type Signer interface {
|
||||||
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
|
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
|
||||||
|
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
|
||||||
SignatureAlgorithm() jose.SignatureAlgorithm
|
SignatureAlgorithm() jose.SignatureAlgorithm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,6 +57,14 @@ func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error)
|
||||||
return s.Sign(payload)
|
return s.Sign(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
|
||||||
|
payload, err := json.Marshal(claims)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return s.Sign(payload)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *idTokenSigner) Sign(payload []byte) (string, error) {
|
func (s *idTokenSigner) Sign(payload []byte) (string, error) {
|
||||||
result, err := s.signer.Sign(payload)
|
result, err := s.signer.Sign(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,17 +1,64 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) {
|
type TokenCreator interface {
|
||||||
|
Issuer() string
|
||||||
|
Signer() Signer
|
||||||
|
Storage() Storage
|
||||||
|
Crypto() Crypto
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateTokenResponse(authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) {
|
||||||
|
var accessToken string
|
||||||
|
if createAccessToken {
|
||||||
var err error
|
var err error
|
||||||
accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes())
|
accessToken, err = CreateAccessToken(authReq, client, creator)
|
||||||
exp := time.Duration(5 * time.Minute)
|
if err != nil {
|
||||||
return accessToken, uint64(exp.Seconds()), err
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
idToken, err := CreateIDToken(creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Signer())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
exp := uint64(client.AccessTokenLifetime().Seconds())
|
||||||
|
return &oidc.AccessTokenResponse{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
IDToken: idToken,
|
||||||
|
TokenType: oidc.BearerToken,
|
||||||
|
ExpiresIn: exp,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateAccessToken(authReq AuthRequest, client Client, creator TokenCreator) (string, error) {
|
||||||
|
if client.AccessTokenType() == AccessTokenTypeJWT {
|
||||||
|
return CreateJWT(creator.Issuer(), authReq, client, creator.Signer())
|
||||||
|
}
|
||||||
|
return CreateBearerToken(authReq, creator.Crypto())
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateBearerToken(authReq AuthRequest, crypto Crypto) (string, error) {
|
||||||
|
return crypto.Encrypt(authReq.GetID())
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateJWT(issuer string, authReq AuthRequest, client Client, signer Signer) (string, error) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
nbf := now
|
||||||
|
exp := now.Add(client.AccessTokenLifetime())
|
||||||
|
claims := &oidc.AccessTokenClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Subject: authReq.GetSubject(),
|
||||||
|
Audiences: authReq.GetAudience(),
|
||||||
|
Expiration: exp,
|
||||||
|
IssuedAt: now,
|
||||||
|
NotBefore: nbf,
|
||||||
|
}
|
||||||
|
return signer.SignAccessToken(claims)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
|
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
|
||||||
|
|
|
@ -31,35 +31,21 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
|
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||||
authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer())
|
resp, err := CreateTokenResponse(authReq, client, exchanger, true, tokenReq.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer())
|
|
||||||
if err != nil {
|
|
||||||
ExchangeRequestError(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &oidc.AccessTokenResponse{
|
|
||||||
AccessToken: accessToken,
|
|
||||||
IDToken: idToken,
|
|
||||||
TokenType: oidc.BearerToken,
|
|
||||||
ExpiresIn: exp,
|
|
||||||
}
|
|
||||||
utils.MarshalJSON(w, resp)
|
utils.MarshalJSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,18 +68,18 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac
|
||||||
return tokenReq, nil
|
return tokenReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||||
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
|
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if client.GetID() != authReq.GetClientID() {
|
if client.GetID() != authReq.GetClientID() {
|
||||||
return nil, ErrInvalidRequest("invalid auth code")
|
return nil, nil, ErrInvalidRequest("invalid auth code")
|
||||||
}
|
}
|
||||||
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
|
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
|
||||||
return nil, ErrInvalidRequest("redirect_uri does no correspond")
|
return nil, nil, ErrInvalidRequest("redirect_uri does no correspond")
|
||||||
}
|
}
|
||||||
return authReq, nil
|
return authReq, client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue