refactor and add access types
This commit is contained in:
parent
be6737328c
commit
42099c8207
12 changed files with 250 additions and 77 deletions
|
@ -160,17 +160,21 @@ func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Clie
|
|||
}
|
||||
var appType op.ApplicationType
|
||||
var authMethod op.AuthMethod
|
||||
var accessTokenType op.AccessTokenType
|
||||
if id == "web" {
|
||||
appType = op.ApplicationTypeWeb
|
||||
authMethod = op.AuthMethodBasic
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
} else if id == "native" {
|
||||
appType = op.ApplicationTypeNative
|
||||
authMethod = op.AuthMethodNone
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
} else {
|
||||
appType = op.ApplicationTypeUserAgent
|
||||
authMethod = op.AuthMethodNone
|
||||
accessTokenType = op.AccessTokenTypeJWT
|
||||
}
|
||||
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod}, nil
|
||||
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
|
||||
}
|
||||
|
||||
func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
|
||||
|
@ -205,6 +209,7 @@ type ConfClient struct {
|
|||
applicationType op.ApplicationType
|
||||
authMethod op.AuthMethod
|
||||
ID string
|
||||
accessTokenType op.AccessTokenType
|
||||
}
|
||||
|
||||
func (c *ConfClient) GetID() string {
|
||||
|
@ -233,3 +238,13 @@ func (c *ConfClient) ApplicationType() op.ApplicationType {
|
|||
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
||||
return c.authMethod
|
||||
}
|
||||
|
||||
func (c *ConfClient) AccessTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) IDTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
|
||||
return c.accessTokenType
|
||||
}
|
||||
|
|
|
@ -89,6 +89,15 @@ type Tokens struct {
|
|||
IDToken string
|
||||
}
|
||||
|
||||
type AccessTokenClaims struct {
|
||||
Issuer string
|
||||
Subject string
|
||||
Audiences []string
|
||||
Expiration time.Time
|
||||
IssuedAt time.Time
|
||||
NotBefore time.Time
|
||||
}
|
||||
|
||||
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
||||
if err != nil {
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/schema"
|
||||
|
@ -36,13 +35,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
return
|
||||
}
|
||||
authReq := new(oidc.AuthRequest)
|
||||
|
||||
err = authorizer.Decoder().Decode(authReq, r.Form)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
|
||||
validation := ValidateAuthRequest
|
||||
if validater, ok := authorizer.(ValidationAuthorizer); ok {
|
||||
validation = validater.ValidateAuthRequest
|
||||
|
@ -51,13 +48,11 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
|
||||
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
||||
|
@ -157,46 +152,44 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
|||
}
|
||||
|
||||
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||
var callback string
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
|
||||
if authReq.GetState() != "" {
|
||||
callback = callback + "&state=" + authReq.GetState()
|
||||
}
|
||||
} else {
|
||||
var accessToken string
|
||||
var err error
|
||||
var exp uint64
|
||||
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
|
||||
accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
}
|
||||
idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
resp := &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: exp,
|
||||
}
|
||||
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
AuthResponseCode(w, r, authReq, authorizer)
|
||||
return
|
||||
}
|
||||
AuthResponseToken(w, r, authReq, authorizer, client)
|
||||
return
|
||||
}
|
||||
|
||||
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
|
||||
if authReq.GetState() != "" {
|
||||
callback = callback + "&state=" + authReq.GetState()
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
||||
func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) {
|
||||
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
|
||||
resp, err := CreateTokenResponse(authReq, client, authorizer, createAccessToken, "")
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
package op
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
ApplicationTypeWeb ApplicationType = iota
|
||||
ApplicationTypeUserAgent
|
||||
ApplicationTypeNative
|
||||
|
||||
AccessTokenTypeBearer AccessTokenType = iota
|
||||
AccessTokenTypeJWT
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
|
@ -12,6 +17,9 @@ type Client interface {
|
|||
ApplicationType() ApplicationType
|
||||
GetAuthMethod() AuthMethod
|
||||
LoginURL(string) string
|
||||
AccessTokenType() AccessTokenType
|
||||
AccessTokenLifetime() time.Duration
|
||||
IDTokenLifetime() time.Duration
|
||||
}
|
||||
|
||||
func IsConfidentialType(c Client) bool {
|
||||
|
@ -21,3 +29,5 @@ func IsConfidentialType(c Client) bool {
|
|||
type ApplicationType int
|
||||
|
||||
type AuthMethod string
|
||||
|
||||
type AccessTokenType int
|
||||
|
|
|
@ -2,6 +2,8 @@ package op
|
|||
|
||||
import "testing"
|
||||
|
||||
import "os"
|
||||
|
||||
func TestValidateIssuer(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
|
@ -54,7 +56,7 @@ func TestValidateIssuer(t *testing.T) {
|
|||
{
|
||||
"localhost with http ok",
|
||||
args{"http://localhost:9999"},
|
||||
false,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -65,3 +67,28 @@ func TestValidateIssuer(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIssuerDevLocalAllowed(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"localhost with http ok",
|
||||
args{"http://localhost:9999"},
|
||||
false,
|
||||
},
|
||||
}
|
||||
os.Setenv("CAOS_OIDC_DEV", "")
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -70,6 +70,9 @@ type Sig struct{}
|
|||
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return jose.HS256
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
)
|
||||
|
||||
// MockClient is a mock of Client interface
|
||||
|
@ -33,6 +34,34 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder {
|
|||
return m.recorder
|
||||
}
|
||||
|
||||
// AccessTokenLifetime mocks base method
|
||||
func (m *MockClient) AccessTokenLifetime() time.Duration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AccessTokenLifetime")
|
||||
ret0, _ := ret[0].(time.Duration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AccessTokenLifetime indicates an expected call of AccessTokenLifetime
|
||||
func (mr *MockClientMockRecorder) AccessTokenLifetime() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenLifetime", reflect.TypeOf((*MockClient)(nil).AccessTokenLifetime))
|
||||
}
|
||||
|
||||
// AccessTokenType mocks base method
|
||||
func (m *MockClient) AccessTokenType() op.AccessTokenType {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AccessTokenType")
|
||||
ret0, _ := ret[0].(op.AccessTokenType)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AccessTokenType indicates an expected call of AccessTokenType
|
||||
func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
|
||||
}
|
||||
|
||||
// ApplicationType mocks base method
|
||||
func (m *MockClient) ApplicationType() op.ApplicationType {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -75,6 +104,20 @@ func (mr *MockClientMockRecorder) GetID() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
|
||||
}
|
||||
|
||||
// IDTokenLifetime mocks base method
|
||||
func (m *MockClient) IDTokenLifetime() time.Duration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IDTokenLifetime")
|
||||
ret0, _ := ret[0].(time.Duration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IDTokenLifetime indicates an expected call of IDTokenLifetime
|
||||
func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime))
|
||||
}
|
||||
|
||||
// LoginURL mocks base method
|
||||
func (m *MockClient) LoginURL(arg0 string) string {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -34,6 +34,21 @@ func (m *MockSigner) EXPECT() *MockSignerMockRecorder {
|
|||
return m.recorder
|
||||
}
|
||||
|
||||
// SignAccessToken mocks base method
|
||||
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SignAccessToken indicates an expected call of SignAccessToken
|
||||
func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
|
||||
}
|
||||
|
||||
// SignIDToken mocks base method
|
||||
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
|
@ -64,18 +65,22 @@ func ExpectValidClientID(s op.Storage) {
|
|||
func(_ context.Context, id string) (op.Client, error) {
|
||||
var appType op.ApplicationType
|
||||
var authMethod op.AuthMethod
|
||||
var accessTokenType op.AccessTokenType
|
||||
switch id {
|
||||
case "web_client":
|
||||
appType = op.ApplicationTypeWeb
|
||||
authMethod = op.AuthMethodBasic
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
case "native_client":
|
||||
appType = op.ApplicationTypeNative
|
||||
authMethod = op.AuthMethodNone
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
case "useragent_client":
|
||||
appType = op.ApplicationTypeUserAgent
|
||||
authMethod = op.AuthMethodBasic
|
||||
accessTokenType = op.AccessTokenTypeJWT
|
||||
}
|
||||
return &ConfClient{id: id, appType: appType, authMethod: authMethod}, nil
|
||||
return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -95,9 +100,10 @@ func ExpectSigningKey(s op.Storage) {
|
|||
}
|
||||
|
||||
type ConfClient struct {
|
||||
id string
|
||||
appType op.ApplicationType
|
||||
authMethod op.AuthMethod
|
||||
id string
|
||||
appType op.ApplicationType
|
||||
authMethod op.AuthMethod
|
||||
accessTokenType op.AccessTokenType
|
||||
}
|
||||
|
||||
func (c *ConfClient) RedirectURIs() []string {
|
||||
|
@ -124,3 +130,13 @@ func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
|||
func (c *ConfClient) GetID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *ConfClient) AccessTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) IDTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
|
||||
return c.accessTokenType
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
type Signer interface {
|
||||
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
|
||||
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
|
||||
SignatureAlgorithm() jose.SignatureAlgorithm
|
||||
}
|
||||
|
||||
|
@ -56,6 +57,14 @@ func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error)
|
|||
return s.Sign(payload)
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Sign(payload)
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) Sign(payload []byte) (string, error) {
|
||||
result, err := s.signer.Sign(payload)
|
||||
if err != nil {
|
||||
|
|
|
@ -1,17 +1,64 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) {
|
||||
var err error
|
||||
accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes())
|
||||
exp := time.Duration(5 * time.Minute)
|
||||
return accessToken, uint64(exp.Seconds()), err
|
||||
type TokenCreator interface {
|
||||
Issuer() string
|
||||
Signer() Signer
|
||||
Storage() Storage
|
||||
Crypto() Crypto
|
||||
}
|
||||
|
||||
func CreateTokenResponse(authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) {
|
||||
var accessToken string
|
||||
if createAccessToken {
|
||||
var err error
|
||||
accessToken, err = CreateAccessToken(authReq, client, creator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
idToken, err := CreateIDToken(creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Signer())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exp := uint64(client.AccessTokenLifetime().Seconds())
|
||||
return &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: exp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CreateAccessToken(authReq AuthRequest, client Client, creator TokenCreator) (string, error) {
|
||||
if client.AccessTokenType() == AccessTokenTypeJWT {
|
||||
return CreateJWT(creator.Issuer(), authReq, client, creator.Signer())
|
||||
}
|
||||
return CreateBearerToken(authReq, creator.Crypto())
|
||||
}
|
||||
|
||||
func CreateBearerToken(authReq AuthRequest, crypto Crypto) (string, error) {
|
||||
return crypto.Encrypt(authReq.GetID())
|
||||
}
|
||||
|
||||
func CreateJWT(issuer string, authReq AuthRequest, client Client, signer Signer) (string, error) {
|
||||
now := time.Now().UTC()
|
||||
nbf := now
|
||||
exp := now.Add(client.AccessTokenLifetime())
|
||||
claims := &oidc.AccessTokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: authReq.GetSubject(),
|
||||
Audiences: authReq.GetAudience(),
|
||||
Expiration: exp,
|
||||
IssuedAt: now,
|
||||
NotBefore: nbf,
|
||||
}
|
||||
return signer.SignAccessToken(claims)
|
||||
}
|
||||
|
||||
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
|
||||
|
|
|
@ -31,35 +31,21 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
|
||||
return
|
||||
}
|
||||
|
||||
authReq, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer())
|
||||
resp, err := CreateTokenResponse(authReq, client, exchanger, true, tokenReq.Code)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: exp,
|
||||
}
|
||||
utils.MarshalJSON(w, resp)
|
||||
}
|
||||
|
||||
|
@ -82,18 +68,18 @@ func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.Ac
|
|||
return tokenReq, nil
|
||||
}
|
||||
|
||||
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if client.GetID() != authReq.GetClientID() {
|
||||
return nil, ErrInvalidRequest("invalid auth code")
|
||||
return nil, nil, ErrInvalidRequest("invalid auth code")
|
||||
}
|
||||
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
|
||||
return nil, ErrInvalidRequest("redirect_uri does no correspond")
|
||||
return nil, nil, ErrInvalidRequest("redirect_uri does no correspond")
|
||||
}
|
||||
return authReq, nil
|
||||
return authReq, client, nil
|
||||
}
|
||||
|
||||
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue