refactoring

This commit is contained in:
Livio Amstutz 2019-12-06 10:42:17 +01:00
parent a793e77679
commit 310220d38e
17 changed files with 346 additions and 149 deletions

View file

@ -41,7 +41,9 @@ func (a *AuthRequest) GetACR() string {
}
func (a *AuthRequest) GetAMR() []string {
return []string{}
return []string{
"password",
}
}
func (a *AuthRequest) GetAudience() []string {
@ -55,7 +57,11 @@ func (a *AuthRequest) GetAuthTime() time.Time {
}
func (a *AuthRequest) GetClientID() string {
return ""
return a.ID
}
func (a *AuthRequest) GetCode() string {
return "code"
}
func (a *AuthRequest) GetID() string {
@ -63,23 +69,31 @@ func (a *AuthRequest) GetID() string {
}
func (a *AuthRequest) GetNonce() string {
return ""
return "nonce"
}
func (a *AuthRequest) GetRedirectURI() string {
return ""
return "http://localhost:5556/auth/callback"
}
func (a *AuthRequest) GetResponseType() oidc.ResponseType {
return a.ResponseType
}
func (a *AuthRequest) GetScopes() []string {
return []string{
"openid",
"profile",
"email",
}
}
func (a *AuthRequest) GetState() string {
return ""
}
func (a *AuthRequest) GetSubject() string {
return ""
return "sub"
}
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
@ -132,11 +146,14 @@ func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
}
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
return s.key, nil
}
func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) {
pubkey := s.key.Public()
return jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256"},
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
},
}, nil
}
@ -151,6 +168,9 @@ func (c *ConfClient) RedirectURIs() []string {
"http://localhost:9999/callback",
"http://localhost:5556/auth/callback",
"custom://callback",
"https://localhost:8443/test/a/instructions-example/callback",
"https://op.certification.openid.net:62054/authz_cb",
"https://op.certification.openid.net:62054/authz_post",
}
}

View file

@ -25,6 +25,8 @@ const (
PromptSelectAccount = "select_account"
GrantTypeCode GrantType = "authorization_code"
BearerToken = "Bearer"
)
var displayValues = map[string]Display{

View file

@ -14,17 +14,18 @@ import (
)
type IDTokenClaims struct {
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audiences []string `json:"aud,omitempty"`
Expiration time.Time `json:"exp,omitempty"`
IssuedAt time.Time `json:"iat,omitempty"`
AuthTime time.Time `json:"auth_time,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
Issuer string
Subject string
Audiences []string
Expiration time.Time
IssuedAt time.Time
AuthTime time.Time
Nonce string
AuthenticationContextClassReference string
AuthenticationMethodsReferences []string
AuthorizedParty string
AccessTokenHash string
CodeHash string
Signature jose.SignatureAlgorithm //TODO: ???
}
@ -46,6 +47,7 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
t.AuthorizedParty = i.AuthorizedParty
t.AccessTokenHash = i.AccessTokenHash
t.CodeHash = i.CodeHash
return nil
}
@ -63,6 +65,7 @@ func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
AuthorizedParty: t.AuthorizedParty,
AccessTokenHash: t.AccessTokenHash,
CodeHash: t.CodeHash,
}
return json.Marshal(j)
}
@ -81,21 +84,23 @@ type jsonIDToken struct {
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
}
type Tokens struct {
*oauth2.Token
IDTokenClaims *IDTokenClaims
IDToken string
}
func AccessTokenHash(accessToken string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
tokenHash, err := getHashAlgorithm(sigAlgorithm)
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
hash, err := getHashAlgorithm(sigAlgorithm)
if err != nil {
return "", err
}
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
hash.Write([]byte(claim)) // hash documents that Write will never return an error
sum := hash.Sum(nil)[:hash.Size()/2]
return base64.RawURLEncoding.EncodeToString(sum), nil
}

View file

@ -3,7 +3,6 @@ package op
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
@ -19,6 +18,7 @@ type Authorizer interface {
Decoder() *schema.Decoder
Encoder() *schema.Encoder
Signer() Signer
Issuer() string
// ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
}
@ -37,7 +37,7 @@ type ValidationAuthorizer interface {
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err := r.ParseForm()
if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"))
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder())
// AuthRequestError(w, r, nil, )
return
}
@ -45,7 +45,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err = authorizer.Decoder().Decode(authReq, r.Form)
if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)))
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
return
}
@ -54,19 +54,19 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
validation = validater.ValidateAuthRequest
}
if err := validation(authReq, authorizer.Storage()); err != nil {
AuthRequestError(w, r, authReq, err)
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
req, err := authorizer.Storage().CreateAuthRequest(authReq)
if err != nil {
AuthRequestError(w, r, authReq, err)
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, err)
AuthRequestError(w, r, req, err, authorizer.Encoder())
return
}
RedirectToLogin(req.GetID(), client, w, r)
@ -100,7 +100,7 @@ func ValidateAuthReqScopes(scopes []string) error {
return nil
}
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage Storage) error {
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
if uri == "" {
return ErrInvalidRequest("redirect_uri must not be empty")
}
@ -144,7 +144,7 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
authReq, err := authorizer.Storage().AuthRequestByID(id)
if err != nil {
AuthRequestError(w, r, nil, err)
AuthRequestError(w, r, nil, err, authorizer.Encoder())
return
}
AuthResponse(authReq, authorizer, w, r)
@ -153,29 +153,32 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
var callback string
if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test")
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), authReq.GetCode())
} else {
var accessToken string
var err error
var exp uint64
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
accessToken, err = CreateAccessToken()
accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
if err != nil {
}
}
idToken, err := CreateIDToken("", authReq, time.Duration(0), accessToken, authorizer.Signer())
idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
if err != nil {
}
resp := &oidc.AccessTokenResponse{
AccessToken: accessToken,
IDToken: idToken,
TokenType: "Bearer",
TokenType: oidc.BearerToken,
ExpiresIn: exp,
}
values := make(map[string][]string)
authorizer.Encoder().Encode(resp, values)
v := url.Values(values)
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode())
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
if err != nil {
}
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
}
http.Redirect(w, r, callback, http.StatusFound)
}

View file

@ -148,7 +148,7 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
uri string
clientID string
responseType oidc.ResponseType
storage op.Storage
storage op.OPStorage
}
tests := []struct {
name string

View file

@ -18,6 +18,8 @@ const (
authMethodBasic = "client_secret_basic"
authMethodPost = "client_secret_post"
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
)
var (
@ -28,7 +30,6 @@ var (
Userinfo: defaultUserinfoEndpoint,
JwksURI: defaultKeysEndpoint,
}
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
)
type DefaultOP struct {
@ -250,5 +251,5 @@ func (p *DefaultOP) handleTokenExchange(w http.ResponseWriter, r *http.Request)
}
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok"))
}

View file

@ -27,7 +27,7 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati
// ClaimsSupported: oidc.SupportedClaims,
IDTokenSigningAlgValuesSupported: sigAlgorithms(s),
SubjectTypesSupported: subjectTypes(c),
TokenEndpointAuthMethodsSupported: authMethods(c),
TokenEndpointAuthMethodsSupported: authMethods(c.AuthMethodBasicSupported(), c.AuthMethodPostSupported()),
}
}
@ -68,12 +68,14 @@ func subjectTypes(c Configuration) []string {
return []string{"public"} //TODO: config
}
func authMethods(c Configuration) []string {
func authMethods(basic, post bool) []string {
authMethods := make([]string, 0, 2)
if c.AuthMethodBasicSupported() {
if basic {
// if c.AuthMethodBasicSupported() {
authMethods = append(authMethods, authMethodBasic)
}
if c.AuthMethodPostSupported() {
if post {
// if c.AuthMethodPostSupported() {
authMethods = append(authMethods, authMethodPost)
}
return authMethods

View file

@ -1,4 +1,4 @@
package op_test
package op
import (
"net/http"
@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
)
func TestDiscover(t *testing.T) {
@ -31,7 +30,7 @@ func TestDiscover(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.Discover(tt.args.w, tt.args.config)
Discover(tt.args.w, tt.args.config)
rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
@ -41,8 +40,8 @@ func TestDiscover(t *testing.T) {
func TestCreateDiscoveryConfig(t *testing.T) {
type args struct {
c op.Configuration
s op.Signer
c Configuration
s Signer
}
tests := []struct {
name string
@ -53,9 +52,156 @@ func TestCreateDiscoveryConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
if got := CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
}
})
}
}
func Test_scopes(t *testing.T) {
type args struct {
c Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("scopes() = %v, want %v", got, tt.want)
}
})
}
}
func Test_responseTypes(t *testing.T) {
type args struct {
c Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := responseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
}
})
}
}
func Test_grantTypes(t *testing.T) {
type args struct {
c Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := grantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
}
})
}
}
// func Test_sigAlgorithms(t *testing.T) {
// type args struct {
// s Signer
// }
// tests := []struct {
// name string
// args args
// want []string
// }{
// {
// "",
// args{},
// []string{"RS256"},
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// if got := sigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
// t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
// }
// })
// }
// }
// func Test_subjectTypes(t *testing.T) {
// type args struct {
// c Configuration
// }
// tests := []struct {
// name string
// args args
// want []string
// }{
// {
// "none",
// args{func()}
// }
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// if got := subjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
// t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
// }
// })
// }
// }
func Test_authMethods(t *testing.T) {
type args struct {
basic bool
post bool
}
tests := []struct {
name string
args args
want []string
}{
{
"none",
args{false, false},
[]string{},
},
{
"basic",
args{true, false},
[]string{authMethodBasic},
},
{
"post",
args{false, true},
[]string{authMethodPost},
},
{
"basic and post",
args{true, true},
[]string{authMethodBasic, authMethodPost},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := authMethods(tt.args.basic, tt.args.post); !reflect.DeepEqual(got, tt.want) {
t.Errorf("authMethods() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,8 +1,10 @@
package op
import (
"fmt"
"net/http"
"net/url"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
@ -13,6 +15,21 @@ const (
ServerError errorType = "server_error"
)
var (
ErrInvalidRequest = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrServerError = func(description string) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
type errorType string
type ErrAuthRequest interface {
@ -21,7 +38,7 @@ type ErrAuthRequest interface {
GetState() string
}
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error) {
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
if authReq == nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -30,27 +47,23 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
e, ok := err.(*OAuthError)
if !ok {
e = new(OAuthError)
e.ErrorType = ServerError
e.Description = err.Error()
}
e.state = authReq.GetState()
params, err := utils.URLEncodeResponse(e, encoder)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
url := authReq.GetRedirectURI()
if authReq.GetResponseType() == oidc.ResponseTypeCode {
url += "?"
url += "?" + params
} else {
url += "#"
}
var errorType errorType
var description string
if e, ok := err.(*OAuthError); ok {
errorType = e.ErrorType
description = e.Description
} else {
errorType = ServerError
description = err.Error()
}
url += "error=" + string(errorType)
if description != "" {
url += "&error_description=" + description
}
if authReq.GetState() != "" {
url += "&state=" + authReq.GetState()
url += "#" + params
}
http.Redirect(w, r, url, http.StatusFound)
}
@ -67,50 +80,11 @@ func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
}
type OAuthError struct {
ErrorType errorType `json:"error"`
Description string `json:"description"`
}
var (
ErrInvalidRequest = func(description string, args ...interface{}) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrServerError = func(description string, args ...interface{}) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest) {
if authReq == nil {
http.Error(w, e.Error(), http.StatusBadRequest)
return
}
if authReq.GetRedirectURI() == "" {
http.Error(w, e.Error(), http.StatusBadRequest)
return
}
callback := authReq.GetRedirectURI()
if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback += "?"
} else {
callback += "#"
}
callback += "error=" + string(e.ErrorType)
if e.Description != "" {
callback += "&error_description=" + url.QueryEscape(e.Description)
}
if authReq.GetState() != "" {
callback += "&state=" + authReq.GetState()
}
http.Redirect(w, r, callback, http.StatusFound)
ErrorType errorType `json:"error" schema:"error"`
Description string `json:"description" schema:"description"`
state string `json:"state" schema:"state"`
}
func (e *OAuthError) Error() string {
return ""
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
}

View file

@ -62,6 +62,20 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
}
// Issuer mocks base method
func (m *MockAuthorizer) Issuer() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Issuer")
ret0, _ := ret[0].(string)
return ret0
}
// Issuer indicates an expected call of Issuer
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
}
// Signer mocks base method
func (m *MockAuthorizer) Signer() op.Signer {
m.ctrl.T.Helper()

View file

@ -12,14 +12,12 @@ import (
type OpenIDProvider interface {
Configuration
// Storage() Storage
HandleDiscovery(w http.ResponseWriter, r *http.Request)
HandleAuthorize(w http.ResponseWriter, r *http.Request)
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
HandleExchange(w http.ResponseWriter, r *http.Request)
HandleUserinfo(w http.ResponseWriter, r *http.Request)
HandleKeys(w http.ResponseWriter, r *http.Request)
// Storage() Storage
HttpHandler() *http.Server
}

View file

@ -36,9 +36,11 @@ type AuthRequest interface {
GetAudience() []string
GetAuthTime() time.Time
GetClientID() string
GetCode() string
GetNonce() string
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetScopes() []string
GetState() string
GetSubject() string
}

46
pkg/op/token.go Normal file
View file

@ -0,0 +1,46 @@
package op
import (
"fmt"
"time"
"github.com/caos/oidc/pkg/oidc"
)
func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) {
var err error
accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes())
exp := time.Duration(5 * time.Minute)
return accessToken, uint64(exp.Seconds()), err
}
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
var err error
exp := time.Now().UTC().Add(validity)
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
}
if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
if code != "" {
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
return signer.SignIDToken(claims)
}

View file

@ -54,12 +54,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ExchangeRequestError(w, r, err)
return
}
accessToken, err := CreateAccessToken()
accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer())
if err != nil {
ExchangeRequestError(w, r, err)
return
}
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, exchanger.Signer())
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer())
if err != nil {
ExchangeRequestError(w, r, err)
return
@ -68,39 +68,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
resp := &oidc.AccessTokenResponse{
AccessToken: accessToken,
IDToken: idToken,
TokenType: oidc.BearerToken,
ExpiresIn: exp,
}
utils.MarshalJSON(w, resp)
}
func CreateAccessToken() (string, error) {
return "accessToken", nil
}
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken string, signer Signer) (string, error) {
var err error
exp := time.Now().UTC().Add(validity)
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
}
if accessToken != "" {
claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
return signer.SignIDToken(claims)
}
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
if tokenReq.ClientID == "" {
if !exchanger.AuthMethodBasicSupported() {

View file

@ -64,7 +64,7 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc
}
if p.verifier == nil {
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL))
}
return p, nil
@ -110,6 +110,7 @@ func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
//handling the oauth2 code exchange, extracting and validating the id_token
//returning it paresed together with the oauth2 tokens (access, refresh)
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) {
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
token, err := p.oauthConfig.Exchange(ctx, code)
if err != nil {
return nil, err //TODO: our error
@ -124,7 +125,7 @@ func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc
return nil, err //TODO: err
}
return &oidc.Tokens{Token: token, IDTokenClaims: idToken}, nil
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
}
//AuthURL is the `RelayingParty` interface implementation

View file

@ -443,7 +443,7 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor
return nil //TODO: return error
}
actual, err := oidc.AccessTokenHash(accessToken, sigAlgorithm)
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
if err != nil {
return err
}

View file

@ -55,3 +55,13 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e
}
return nil
}
func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) {
values := make(map[string][]string)
err := encoder.Encode(resp, values)
if err != nil {
return "", err
}
v := url.Values(values)
return v.Encode(), nil
}