refactoring
This commit is contained in:
parent
a793e77679
commit
310220d38e
17 changed files with 346 additions and 149 deletions
|
@ -41,7 +41,9 @@ func (a *AuthRequest) GetACR() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetAMR() []string {
|
func (a *AuthRequest) GetAMR() []string {
|
||||||
return []string{}
|
return []string{
|
||||||
|
"password",
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetAudience() []string {
|
func (a *AuthRequest) GetAudience() []string {
|
||||||
|
@ -55,7 +57,11 @@ func (a *AuthRequest) GetAuthTime() time.Time {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetClientID() string {
|
func (a *AuthRequest) GetClientID() string {
|
||||||
return ""
|
return a.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequest) GetCode() string {
|
||||||
|
return "code"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetID() string {
|
func (a *AuthRequest) GetID() string {
|
||||||
|
@ -63,23 +69,31 @@ func (a *AuthRequest) GetID() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetNonce() string {
|
func (a *AuthRequest) GetNonce() string {
|
||||||
return ""
|
return "nonce"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetRedirectURI() string {
|
func (a *AuthRequest) GetRedirectURI() string {
|
||||||
return ""
|
return "http://localhost:5556/auth/callback"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetResponseType() oidc.ResponseType {
|
func (a *AuthRequest) GetResponseType() oidc.ResponseType {
|
||||||
return a.ResponseType
|
return a.ResponseType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequest) GetScopes() []string {
|
||||||
|
return []string{
|
||||||
|
"openid",
|
||||||
|
"profile",
|
||||||
|
"email",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetState() string {
|
func (a *AuthRequest) GetState() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) GetSubject() string {
|
func (a *AuthRequest) GetSubject() string {
|
||||||
return ""
|
return "sub"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||||
|
@ -132,11 +146,14 @@ func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
|
||||||
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
|
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
|
||||||
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
|
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
|
||||||
}
|
}
|
||||||
|
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
|
||||||
|
return s.key, nil
|
||||||
|
}
|
||||||
func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) {
|
func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) {
|
||||||
pubkey := s.key.Public()
|
pubkey := s.key.Public()
|
||||||
return jose.JSONWebKeySet{
|
return jose.JSONWebKeySet{
|
||||||
Keys: []jose.JSONWebKey{
|
Keys: []jose.JSONWebKey{
|
||||||
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256"},
|
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -151,6 +168,9 @@ func (c *ConfClient) RedirectURIs() []string {
|
||||||
"http://localhost:9999/callback",
|
"http://localhost:9999/callback",
|
||||||
"http://localhost:5556/auth/callback",
|
"http://localhost:5556/auth/callback",
|
||||||
"custom://callback",
|
"custom://callback",
|
||||||
|
"https://localhost:8443/test/a/instructions-example/callback",
|
||||||
|
"https://op.certification.openid.net:62054/authz_cb",
|
||||||
|
"https://op.certification.openid.net:62054/authz_post",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,8 @@ const (
|
||||||
PromptSelectAccount = "select_account"
|
PromptSelectAccount = "select_account"
|
||||||
|
|
||||||
GrantTypeCode GrantType = "authorization_code"
|
GrantTypeCode GrantType = "authorization_code"
|
||||||
|
|
||||||
|
BearerToken = "Bearer"
|
||||||
)
|
)
|
||||||
|
|
||||||
var displayValues = map[string]Display{
|
var displayValues = map[string]Display{
|
||||||
|
|
|
@ -14,17 +14,18 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDTokenClaims struct {
|
type IDTokenClaims struct {
|
||||||
Issuer string `json:"iss,omitempty"`
|
Issuer string
|
||||||
Subject string `json:"sub,omitempty"`
|
Subject string
|
||||||
Audiences []string `json:"aud,omitempty"`
|
Audiences []string
|
||||||
Expiration time.Time `json:"exp,omitempty"`
|
Expiration time.Time
|
||||||
IssuedAt time.Time `json:"iat,omitempty"`
|
IssuedAt time.Time
|
||||||
AuthTime time.Time `json:"auth_time,omitempty"`
|
AuthTime time.Time
|
||||||
Nonce string `json:"nonce,omitempty"`
|
Nonce string
|
||||||
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
AuthenticationContextClassReference string
|
||||||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
AuthenticationMethodsReferences []string
|
||||||
AuthorizedParty string `json:"azp,omitempty"`
|
AuthorizedParty string
|
||||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
AccessTokenHash string
|
||||||
|
CodeHash string
|
||||||
|
|
||||||
Signature jose.SignatureAlgorithm //TODO: ???
|
Signature jose.SignatureAlgorithm //TODO: ???
|
||||||
}
|
}
|
||||||
|
@ -46,6 +47,7 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
|
||||||
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
|
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
|
||||||
t.AuthorizedParty = i.AuthorizedParty
|
t.AuthorizedParty = i.AuthorizedParty
|
||||||
t.AccessTokenHash = i.AccessTokenHash
|
t.AccessTokenHash = i.AccessTokenHash
|
||||||
|
t.CodeHash = i.CodeHash
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,6 +65,7 @@ func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
|
||||||
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
||||||
AuthorizedParty: t.AuthorizedParty,
|
AuthorizedParty: t.AuthorizedParty,
|
||||||
AccessTokenHash: t.AccessTokenHash,
|
AccessTokenHash: t.AccessTokenHash,
|
||||||
|
CodeHash: t.CodeHash,
|
||||||
}
|
}
|
||||||
return json.Marshal(j)
|
return json.Marshal(j)
|
||||||
}
|
}
|
||||||
|
@ -81,21 +84,23 @@ type jsonIDToken struct {
|
||||||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||||
AuthorizedParty string `json:"azp,omitempty"`
|
AuthorizedParty string `json:"azp,omitempty"`
|
||||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||||
|
CodeHash string `json:"c_hash,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tokens struct {
|
type Tokens struct {
|
||||||
*oauth2.Token
|
*oauth2.Token
|
||||||
IDTokenClaims *IDTokenClaims
|
IDTokenClaims *IDTokenClaims
|
||||||
|
IDToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
func AccessTokenHash(accessToken string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||||
tokenHash, err := getHashAlgorithm(sigAlgorithm)
|
hash, err := getHashAlgorithm(sigAlgorithm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
|
hash.Write([]byte(claim)) // hash documents that Write will never return an error
|
||||||
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
|
sum := hash.Sum(nil)[:hash.Size()/2]
|
||||||
return base64.RawURLEncoding.EncodeToString(sum), nil
|
return base64.RawURLEncoding.EncodeToString(sum), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,6 @@ package op
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -19,6 +18,7 @@ type Authorizer interface {
|
||||||
Decoder() *schema.Decoder
|
Decoder() *schema.Decoder
|
||||||
Encoder() *schema.Encoder
|
Encoder() *schema.Encoder
|
||||||
Signer() Signer
|
Signer() Signer
|
||||||
|
Issuer() string
|
||||||
// ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
|
// ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ type ValidationAuthorizer interface {
|
||||||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"))
|
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder())
|
||||||
// AuthRequestError(w, r, nil, )
|
// AuthRequestError(w, r, nil, )
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||||
|
|
||||||
err = authorizer.Decoder().Decode(authReq, r.Form)
|
err = authorizer.Decoder().Decode(authReq, r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)))
|
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,19 +54,19 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||||
validation = validater.ValidateAuthRequest
|
validation = validater.ValidateAuthRequest
|
||||||
}
|
}
|
||||||
if err := validation(authReq, authorizer.Storage()); err != nil {
|
if err := validation(authReq, authorizer.Storage()); err != nil {
|
||||||
AuthRequestError(w, r, authReq, err)
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := authorizer.Storage().CreateAuthRequest(authReq)
|
req, err := authorizer.Storage().CreateAuthRequest(authReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, authReq, err)
|
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
|
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, req, err)
|
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
RedirectToLogin(req.GetID(), client, w, r)
|
RedirectToLogin(req.GetID(), client, w, r)
|
||||||
|
@ -100,7 +100,7 @@ func ValidateAuthReqScopes(scopes []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage Storage) error {
|
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
|
||||||
if uri == "" {
|
if uri == "" {
|
||||||
return ErrInvalidRequest("redirect_uri must not be empty")
|
return ErrInvalidRequest("redirect_uri must not be empty")
|
||||||
}
|
}
|
||||||
|
@ -144,7 +144,7 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
||||||
|
|
||||||
authReq, err := authorizer.Storage().AuthRequestByID(id)
|
authReq, err := authorizer.Storage().AuthRequestByID(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
AuthRequestError(w, r, nil, err)
|
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
AuthResponse(authReq, authorizer, w, r)
|
AuthResponse(authReq, authorizer, w, r)
|
||||||
|
@ -153,29 +153,32 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
||||||
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||||
var callback string
|
var callback string
|
||||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||||
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test")
|
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), authReq.GetCode())
|
||||||
} else {
|
} else {
|
||||||
var accessToken string
|
var accessToken string
|
||||||
var err error
|
var err error
|
||||||
|
var exp uint64
|
||||||
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
|
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
|
||||||
accessToken, err = CreateAccessToken()
|
accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
idToken, err := CreateIDToken("", authReq, time.Duration(0), accessToken, authorizer.Signer())
|
idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
||||||
}
|
}
|
||||||
resp := &oidc.AccessTokenResponse{
|
resp := &oidc.AccessTokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
TokenType: "Bearer",
|
TokenType: oidc.BearerToken,
|
||||||
|
ExpiresIn: exp,
|
||||||
}
|
}
|
||||||
values := make(map[string][]string)
|
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
|
||||||
authorizer.Encoder().Encode(resp, values)
|
if err != nil {
|
||||||
v := url.Values(values)
|
|
||||||
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode())
|
}
|
||||||
|
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, callback, http.StatusFound)
|
http.Redirect(w, r, callback, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
|
@ -148,7 +148,7 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
|
||||||
uri string
|
uri string
|
||||||
clientID string
|
clientID string
|
||||||
responseType oidc.ResponseType
|
responseType oidc.ResponseType
|
||||||
storage op.Storage
|
storage op.OPStorage
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -18,6 +18,8 @@ const (
|
||||||
|
|
||||||
authMethodBasic = "client_secret_basic"
|
authMethodBasic = "client_secret_basic"
|
||||||
authMethodPost = "client_secret_post"
|
authMethodPost = "client_secret_post"
|
||||||
|
|
||||||
|
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -28,7 +30,6 @@ var (
|
||||||
Userinfo: defaultUserinfoEndpoint,
|
Userinfo: defaultUserinfoEndpoint,
|
||||||
JwksURI: defaultKeysEndpoint,
|
JwksURI: defaultKeysEndpoint,
|
||||||
}
|
}
|
||||||
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type DefaultOP struct {
|
type DefaultOP struct {
|
||||||
|
@ -250,5 +251,5 @@ func (p *DefaultOP) handleTokenExchange(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
|
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("ok"))
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati
|
||||||
// ClaimsSupported: oidc.SupportedClaims,
|
// ClaimsSupported: oidc.SupportedClaims,
|
||||||
IDTokenSigningAlgValuesSupported: sigAlgorithms(s),
|
IDTokenSigningAlgValuesSupported: sigAlgorithms(s),
|
||||||
SubjectTypesSupported: subjectTypes(c),
|
SubjectTypesSupported: subjectTypes(c),
|
||||||
TokenEndpointAuthMethodsSupported: authMethods(c),
|
TokenEndpointAuthMethodsSupported: authMethods(c.AuthMethodBasicSupported(), c.AuthMethodPostSupported()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,12 +68,14 @@ func subjectTypes(c Configuration) []string {
|
||||||
return []string{"public"} //TODO: config
|
return []string{"public"} //TODO: config
|
||||||
}
|
}
|
||||||
|
|
||||||
func authMethods(c Configuration) []string {
|
func authMethods(basic, post bool) []string {
|
||||||
authMethods := make([]string, 0, 2)
|
authMethods := make([]string, 0, 2)
|
||||||
if c.AuthMethodBasicSupported() {
|
if basic {
|
||||||
|
// if c.AuthMethodBasicSupported() {
|
||||||
authMethods = append(authMethods, authMethodBasic)
|
authMethods = append(authMethods, authMethodBasic)
|
||||||
}
|
}
|
||||||
if c.AuthMethodPostSupported() {
|
if post {
|
||||||
|
// if c.AuthMethodPostSupported() {
|
||||||
authMethods = append(authMethods, authMethodPost)
|
authMethods = append(authMethods, authMethodPost)
|
||||||
}
|
}
|
||||||
return authMethods
|
return authMethods
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package op_test
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -9,7 +9,6 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
"github.com/caos/oidc/pkg/op"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDiscover(t *testing.T) {
|
func TestDiscover(t *testing.T) {
|
||||||
|
@ -31,7 +30,7 @@ func TestDiscover(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
op.Discover(tt.args.w, tt.args.config)
|
Discover(tt.args.w, tt.args.config)
|
||||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||||
require.Equal(t, http.StatusOK, rec.Code)
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
|
require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
|
||||||
|
@ -41,8 +40,8 @@ func TestDiscover(t *testing.T) {
|
||||||
|
|
||||||
func TestCreateDiscoveryConfig(t *testing.T) {
|
func TestCreateDiscoveryConfig(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
c op.Configuration
|
c Configuration
|
||||||
s op.Signer
|
s Signer
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -53,9 +52,156 @@ func TestCreateDiscoveryConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
if got := CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||||
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
|
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_scopes(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
c Configuration
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("scopes() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_responseTypes(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
c Configuration
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := responseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_grantTypes(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
c Configuration
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := grantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// func Test_sigAlgorithms(t *testing.T) {
|
||||||
|
// type args struct {
|
||||||
|
// s Signer
|
||||||
|
// }
|
||||||
|
// tests := []struct {
|
||||||
|
// name string
|
||||||
|
// args args
|
||||||
|
// want []string
|
||||||
|
// }{
|
||||||
|
// {
|
||||||
|
// "",
|
||||||
|
// args{},
|
||||||
|
// []string{"RS256"},
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
// for _, tt := range tests {
|
||||||
|
// t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// if got := sigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
// t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func Test_subjectTypes(t *testing.T) {
|
||||||
|
// type args struct {
|
||||||
|
// c Configuration
|
||||||
|
// }
|
||||||
|
// tests := []struct {
|
||||||
|
// name string
|
||||||
|
// args args
|
||||||
|
// want []string
|
||||||
|
// }{
|
||||||
|
// {
|
||||||
|
// "none",
|
||||||
|
// args{func()}
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// for _, tt := range tests {
|
||||||
|
// t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// if got := subjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
// t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
func Test_authMethods(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
basic bool
|
||||||
|
post bool
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"none",
|
||||||
|
args{false, false},
|
||||||
|
[]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"basic",
|
||||||
|
args{true, false},
|
||||||
|
[]string{authMethodBasic},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"post",
|
||||||
|
args{false, true},
|
||||||
|
[]string{authMethodPost},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"basic and post",
|
||||||
|
args{true, true},
|
||||||
|
[]string{authMethodBasic, authMethodPost},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := authMethods(tt.args.basic, tt.args.post); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("authMethods() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
100
pkg/op/error.go
100
pkg/op/error.go
|
@ -1,8 +1,10 @@
|
||||||
package op
|
package op
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
|
"github.com/gorilla/schema"
|
||||||
|
|
||||||
"github.com/caos/oidc/pkg/oidc"
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
"github.com/caos/oidc/pkg/utils"
|
"github.com/caos/oidc/pkg/utils"
|
||||||
|
@ -13,6 +15,21 @@ const (
|
||||||
ServerError errorType = "server_error"
|
ServerError errorType = "server_error"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidRequest = func(description string) *OAuthError {
|
||||||
|
return &OAuthError{
|
||||||
|
ErrorType: InvalidRequest,
|
||||||
|
Description: description,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ErrServerError = func(description string) *OAuthError {
|
||||||
|
return &OAuthError{
|
||||||
|
ErrorType: ServerError,
|
||||||
|
Description: description,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
type errorType string
|
type errorType string
|
||||||
|
|
||||||
type ErrAuthRequest interface {
|
type ErrAuthRequest interface {
|
||||||
|
@ -21,7 +38,7 @@ type ErrAuthRequest interface {
|
||||||
GetState() string
|
GetState() string
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error) {
|
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
|
||||||
if authReq == nil {
|
if authReq == nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
@ -30,27 +47,23 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
e, ok := err.(*OAuthError)
|
||||||
|
if !ok {
|
||||||
|
e = new(OAuthError)
|
||||||
|
e.ErrorType = ServerError
|
||||||
|
e.Description = err.Error()
|
||||||
|
}
|
||||||
|
e.state = authReq.GetState()
|
||||||
|
params, err := utils.URLEncodeResponse(e, encoder)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
url := authReq.GetRedirectURI()
|
url := authReq.GetRedirectURI()
|
||||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||||
url += "?"
|
url += "?" + params
|
||||||
} else {
|
} else {
|
||||||
url += "#"
|
url += "#" + params
|
||||||
}
|
|
||||||
var errorType errorType
|
|
||||||
var description string
|
|
||||||
if e, ok := err.(*OAuthError); ok {
|
|
||||||
errorType = e.ErrorType
|
|
||||||
description = e.Description
|
|
||||||
} else {
|
|
||||||
errorType = ServerError
|
|
||||||
description = err.Error()
|
|
||||||
}
|
|
||||||
url += "error=" + string(errorType)
|
|
||||||
if description != "" {
|
|
||||||
url += "&error_description=" + description
|
|
||||||
}
|
|
||||||
if authReq.GetState() != "" {
|
|
||||||
url += "&state=" + authReq.GetState()
|
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, url, http.StatusFound)
|
http.Redirect(w, r, url, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
@ -67,50 +80,11 @@ func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthError struct {
|
type OAuthError struct {
|
||||||
ErrorType errorType `json:"error"`
|
ErrorType errorType `json:"error" schema:"error"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description" schema:"description"`
|
||||||
}
|
state string `json:"state" schema:"state"`
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInvalidRequest = func(description string, args ...interface{}) *OAuthError {
|
|
||||||
return &OAuthError{
|
|
||||||
ErrorType: InvalidRequest,
|
|
||||||
Description: description,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ErrServerError = func(description string, args ...interface{}) *OAuthError {
|
|
||||||
return &OAuthError{
|
|
||||||
ErrorType: ServerError,
|
|
||||||
Description: description,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest) {
|
|
||||||
if authReq == nil {
|
|
||||||
http.Error(w, e.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if authReq.GetRedirectURI() == "" {
|
|
||||||
http.Error(w, e.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
callback := authReq.GetRedirectURI()
|
|
||||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
|
||||||
callback += "?"
|
|
||||||
} else {
|
|
||||||
callback += "#"
|
|
||||||
}
|
|
||||||
callback += "error=" + string(e.ErrorType)
|
|
||||||
if e.Description != "" {
|
|
||||||
callback += "&error_description=" + url.QueryEscape(e.Description)
|
|
||||||
}
|
|
||||||
if authReq.GetState() != "" {
|
|
||||||
callback += "&state=" + authReq.GetState()
|
|
||||||
}
|
|
||||||
http.Redirect(w, r, callback, http.StatusFound)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *OAuthError) Error() string {
|
func (e *OAuthError) Error() string {
|
||||||
return ""
|
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,6 +62,20 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Issuer mocks base method
|
||||||
|
func (m *MockAuthorizer) Issuer() string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Issuer")
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Issuer indicates an expected call of Issuer
|
||||||
|
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
|
||||||
|
}
|
||||||
|
|
||||||
// Signer mocks base method
|
// Signer mocks base method
|
||||||
func (m *MockAuthorizer) Signer() op.Signer {
|
func (m *MockAuthorizer) Signer() op.Signer {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -12,14 +12,12 @@ import (
|
||||||
|
|
||||||
type OpenIDProvider interface {
|
type OpenIDProvider interface {
|
||||||
Configuration
|
Configuration
|
||||||
// Storage() Storage
|
|
||||||
HandleDiscovery(w http.ResponseWriter, r *http.Request)
|
HandleDiscovery(w http.ResponseWriter, r *http.Request)
|
||||||
HandleAuthorize(w http.ResponseWriter, r *http.Request)
|
HandleAuthorize(w http.ResponseWriter, r *http.Request)
|
||||||
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
|
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
|
||||||
HandleExchange(w http.ResponseWriter, r *http.Request)
|
HandleExchange(w http.ResponseWriter, r *http.Request)
|
||||||
HandleUserinfo(w http.ResponseWriter, r *http.Request)
|
HandleUserinfo(w http.ResponseWriter, r *http.Request)
|
||||||
HandleKeys(w http.ResponseWriter, r *http.Request)
|
HandleKeys(w http.ResponseWriter, r *http.Request)
|
||||||
// Storage() Storage
|
|
||||||
HttpHandler() *http.Server
|
HttpHandler() *http.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,9 +36,11 @@ type AuthRequest interface {
|
||||||
GetAudience() []string
|
GetAudience() []string
|
||||||
GetAuthTime() time.Time
|
GetAuthTime() time.Time
|
||||||
GetClientID() string
|
GetClientID() string
|
||||||
|
GetCode() string
|
||||||
GetNonce() string
|
GetNonce() string
|
||||||
GetRedirectURI() string
|
GetRedirectURI() string
|
||||||
GetResponseType() oidc.ResponseType
|
GetResponseType() oidc.ResponseType
|
||||||
|
GetScopes() []string
|
||||||
GetState() string
|
GetState() string
|
||||||
GetSubject() string
|
GetSubject() string
|
||||||
}
|
}
|
||||||
|
|
46
pkg/op/token.go
Normal file
46
pkg/op/token.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package op
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/caos/oidc/pkg/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) {
|
||||||
|
var err error
|
||||||
|
accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes())
|
||||||
|
exp := time.Duration(5 * time.Minute)
|
||||||
|
return accessToken, uint64(exp.Seconds()), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
|
||||||
|
var err error
|
||||||
|
exp := time.Now().UTC().Add(validity)
|
||||||
|
claims := &oidc.IDTokenClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Subject: authReq.GetSubject(),
|
||||||
|
Audiences: authReq.GetAudience(),
|
||||||
|
Expiration: exp,
|
||||||
|
IssuedAt: time.Now().UTC(),
|
||||||
|
AuthTime: authReq.GetAuthTime(),
|
||||||
|
Nonce: authReq.GetNonce(),
|
||||||
|
AuthenticationContextClassReference: authReq.GetACR(),
|
||||||
|
AuthenticationMethodsReferences: authReq.GetAMR(),
|
||||||
|
AuthorizedParty: authReq.GetClientID(),
|
||||||
|
}
|
||||||
|
if accessToken != "" {
|
||||||
|
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if code != "" {
|
||||||
|
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return signer.SignIDToken(claims)
|
||||||
|
}
|
|
@ -54,12 +54,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
accessToken, err := CreateAccessToken()
|
accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, exchanger.Signer())
|
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
|
@ -68,39 +68,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
resp := &oidc.AccessTokenResponse{
|
resp := &oidc.AccessTokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
|
TokenType: oidc.BearerToken,
|
||||||
|
ExpiresIn: exp,
|
||||||
}
|
}
|
||||||
utils.MarshalJSON(w, resp)
|
utils.MarshalJSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateAccessToken() (string, error) {
|
|
||||||
return "accessToken", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken string, signer Signer) (string, error) {
|
|
||||||
var err error
|
|
||||||
exp := time.Now().UTC().Add(validity)
|
|
||||||
claims := &oidc.IDTokenClaims{
|
|
||||||
Issuer: issuer,
|
|
||||||
Subject: authReq.GetSubject(),
|
|
||||||
Audiences: authReq.GetAudience(),
|
|
||||||
Expiration: exp,
|
|
||||||
IssuedAt: time.Now().UTC(),
|
|
||||||
AuthTime: authReq.GetAuthTime(),
|
|
||||||
Nonce: authReq.GetNonce(),
|
|
||||||
AuthenticationContextClassReference: authReq.GetACR(),
|
|
||||||
AuthenticationMethodsReferences: authReq.GetAMR(),
|
|
||||||
AuthorizedParty: authReq.GetClientID(),
|
|
||||||
}
|
|
||||||
if accessToken != "" {
|
|
||||||
claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, signer.SignatureAlgorithm())
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return signer.SignIDToken(claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
|
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
|
||||||
if tokenReq.ClientID == "" {
|
if tokenReq.ClientID == "" {
|
||||||
if !exchanger.AuthMethodBasicSupported() {
|
if !exchanger.AuthMethodBasicSupported() {
|
||||||
|
|
|
@ -64,7 +64,7 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.verifier == nil {
|
if p.verifier == nil {
|
||||||
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint
|
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
|
@ -110,6 +110,7 @@ func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
|
||||||
//handling the oauth2 code exchange, extracting and validating the id_token
|
//handling the oauth2 code exchange, extracting and validating the id_token
|
||||||
//returning it paresed together with the oauth2 tokens (access, refresh)
|
//returning it paresed together with the oauth2 tokens (access, refresh)
|
||||||
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) {
|
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) {
|
||||||
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
|
||||||
token, err := p.oauthConfig.Exchange(ctx, code)
|
token, err := p.oauthConfig.Exchange(ctx, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err //TODO: our error
|
return nil, err //TODO: our error
|
||||||
|
@ -124,7 +125,7 @@ func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc
|
||||||
return nil, err //TODO: err
|
return nil, err //TODO: err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &oidc.Tokens{Token: token, IDTokenClaims: idToken}, nil
|
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//AuthURL is the `RelayingParty` interface implementation
|
//AuthURL is the `RelayingParty` interface implementation
|
||||||
|
|
|
@ -443,7 +443,7 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor
|
||||||
return nil //TODO: return error
|
return nil //TODO: return error
|
||||||
}
|
}
|
||||||
|
|
||||||
actual, err := oidc.AccessTokenHash(accessToken, sigAlgorithm)
|
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,3 +55,13 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) {
|
||||||
|
values := make(map[string][]string)
|
||||||
|
err := encoder.Encode(resp, values)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
v := url.Values(values)
|
||||||
|
return v.Encode(), nil
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue