refactoring
This commit is contained in:
parent
a793e77679
commit
310220d38e
17 changed files with 346 additions and 149 deletions
|
@ -41,7 +41,9 @@ func (a *AuthRequest) GetACR() string {
|
|||
}
|
||||
|
||||
func (a *AuthRequest) GetAMR() []string {
|
||||
return []string{}
|
||||
return []string{
|
||||
"password",
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetAudience() []string {
|
||||
|
@ -55,7 +57,11 @@ func (a *AuthRequest) GetAuthTime() time.Time {
|
|||
}
|
||||
|
||||
func (a *AuthRequest) GetClientID() string {
|
||||
return ""
|
||||
return a.ID
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetCode() string {
|
||||
return "code"
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetID() string {
|
||||
|
@ -63,23 +69,31 @@ func (a *AuthRequest) GetID() string {
|
|||
}
|
||||
|
||||
func (a *AuthRequest) GetNonce() string {
|
||||
return ""
|
||||
return "nonce"
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetRedirectURI() string {
|
||||
return ""
|
||||
return "http://localhost:5556/auth/callback"
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetResponseType() oidc.ResponseType {
|
||||
return a.ResponseType
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetScopes() []string {
|
||||
return []string{
|
||||
"openid",
|
||||
"profile",
|
||||
"email",
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetState() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetSubject() string {
|
||||
return ""
|
||||
return "sub"
|
||||
}
|
||||
|
||||
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||
|
@ -132,11 +146,14 @@ func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
|
|||
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
|
||||
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
|
||||
}
|
||||
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
|
||||
return s.key, nil
|
||||
}
|
||||
func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) {
|
||||
pubkey := s.key.Public()
|
||||
return jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256"},
|
||||
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@ -151,6 +168,9 @@ func (c *ConfClient) RedirectURIs() []string {
|
|||
"http://localhost:9999/callback",
|
||||
"http://localhost:5556/auth/callback",
|
||||
"custom://callback",
|
||||
"https://localhost:8443/test/a/instructions-example/callback",
|
||||
"https://op.certification.openid.net:62054/authz_cb",
|
||||
"https://op.certification.openid.net:62054/authz_post",
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,8 @@ const (
|
|||
PromptSelectAccount = "select_account"
|
||||
|
||||
GrantTypeCode GrantType = "authorization_code"
|
||||
|
||||
BearerToken = "Bearer"
|
||||
)
|
||||
|
||||
var displayValues = map[string]Display{
|
||||
|
|
|
@ -14,17 +14,18 @@ import (
|
|||
)
|
||||
|
||||
type IDTokenClaims struct {
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Audiences []string `json:"aud,omitempty"`
|
||||
Expiration time.Time `json:"exp,omitempty"`
|
||||
IssuedAt time.Time `json:"iat,omitempty"`
|
||||
AuthTime time.Time `json:"auth_time,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
||||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||
AuthorizedParty string `json:"azp,omitempty"`
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
Issuer string
|
||||
Subject string
|
||||
Audiences []string
|
||||
Expiration time.Time
|
||||
IssuedAt time.Time
|
||||
AuthTime time.Time
|
||||
Nonce string
|
||||
AuthenticationContextClassReference string
|
||||
AuthenticationMethodsReferences []string
|
||||
AuthorizedParty string
|
||||
AccessTokenHash string
|
||||
CodeHash string
|
||||
|
||||
Signature jose.SignatureAlgorithm //TODO: ???
|
||||
}
|
||||
|
@ -46,6 +47,7 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
|
|||
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
|
||||
t.AuthorizedParty = i.AuthorizedParty
|
||||
t.AccessTokenHash = i.AccessTokenHash
|
||||
t.CodeHash = i.CodeHash
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -63,6 +65,7 @@ func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
|
|||
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
||||
AuthorizedParty: t.AuthorizedParty,
|
||||
AccessTokenHash: t.AccessTokenHash,
|
||||
CodeHash: t.CodeHash,
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
@ -81,21 +84,23 @@ type jsonIDToken struct {
|
|||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||
AuthorizedParty string `json:"azp,omitempty"`
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
CodeHash string `json:"c_hash,omitempty"`
|
||||
}
|
||||
|
||||
type Tokens struct {
|
||||
*oauth2.Token
|
||||
IDTokenClaims *IDTokenClaims
|
||||
IDToken string
|
||||
}
|
||||
|
||||
func AccessTokenHash(accessToken string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||
tokenHash, err := getHashAlgorithm(sigAlgorithm)
|
||||
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||
hash, err := getHashAlgorithm(sigAlgorithm)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error
|
||||
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2]
|
||||
hash.Write([]byte(claim)) // hash documents that Write will never return an error
|
||||
sum := hash.Sum(nil)[:hash.Size()/2]
|
||||
return base64.RawURLEncoding.EncodeToString(sum), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ package op
|
|||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -19,6 +18,7 @@ type Authorizer interface {
|
|||
Decoder() *schema.Decoder
|
||||
Encoder() *schema.Encoder
|
||||
Signer() Signer
|
||||
Issuer() string
|
||||
// ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
|
||||
}
|
||||
|
||||
|
@ -37,7 +37,7 @@ type ValidationAuthorizer interface {
|
|||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"))
|
||||
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder())
|
||||
// AuthRequestError(w, r, nil, )
|
||||
return
|
||||
}
|
||||
|
@ -45,7 +45,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
|
||||
err = authorizer.Decoder().Decode(authReq, r.Form)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)))
|
||||
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -54,19 +54,19 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
|||
validation = validater.ValidateAuthRequest
|
||||
}
|
||||
if err := validation(authReq, authorizer.Storage()); err != nil {
|
||||
AuthRequestError(w, r, authReq, err)
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
|
||||
req, err := authorizer.Storage().CreateAuthRequest(authReq)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err)
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
|
||||
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, req, err)
|
||||
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
RedirectToLogin(req.GetID(), client, w, r)
|
||||
|
@ -100,7 +100,7 @@ func ValidateAuthReqScopes(scopes []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage Storage) error {
|
||||
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
|
||||
if uri == "" {
|
||||
return ErrInvalidRequest("redirect_uri must not be empty")
|
||||
}
|
||||
|
@ -144,7 +144,7 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
|||
|
||||
authReq, err := authorizer.Storage().AuthRequestByID(id)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err)
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
AuthResponse(authReq, authorizer, w, r)
|
||||
|
@ -153,29 +153,32 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
|
|||
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||
var callback string
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test")
|
||||
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), authReq.GetCode())
|
||||
} else {
|
||||
var accessToken string
|
||||
var err error
|
||||
var exp uint64
|
||||
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
|
||||
accessToken, err = CreateAccessToken()
|
||||
accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
}
|
||||
idToken, err := CreateIDToken("", authReq, time.Duration(0), accessToken, authorizer.Signer())
|
||||
idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
resp := &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
TokenType: "Bearer",
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: exp,
|
||||
}
|
||||
values := make(map[string][]string)
|
||||
authorizer.Encoder().Encode(resp, values)
|
||||
v := url.Values(values)
|
||||
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode())
|
||||
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
|
|
@ -148,7 +148,7 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
|
|||
uri string
|
||||
clientID string
|
||||
responseType oidc.ResponseType
|
||||
storage op.Storage
|
||||
storage op.OPStorage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
@ -18,6 +18,8 @@ const (
|
|||
|
||||
authMethodBasic = "client_secret_basic"
|
||||
authMethodPost = "client_secret_post"
|
||||
|
||||
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -28,7 +30,6 @@ var (
|
|||
Userinfo: defaultUserinfoEndpoint,
|
||||
JwksURI: defaultKeysEndpoint,
|
||||
}
|
||||
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
|
||||
)
|
||||
|
||||
type DefaultOP struct {
|
||||
|
@ -250,5 +251,5 @@ func (p *DefaultOP) handleTokenExchange(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Write([]byte("ok"))
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati
|
|||
// ClaimsSupported: oidc.SupportedClaims,
|
||||
IDTokenSigningAlgValuesSupported: sigAlgorithms(s),
|
||||
SubjectTypesSupported: subjectTypes(c),
|
||||
TokenEndpointAuthMethodsSupported: authMethods(c),
|
||||
TokenEndpointAuthMethodsSupported: authMethods(c.AuthMethodBasicSupported(), c.AuthMethodPostSupported()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -68,12 +68,14 @@ func subjectTypes(c Configuration) []string {
|
|||
return []string{"public"} //TODO: config
|
||||
}
|
||||
|
||||
func authMethods(c Configuration) []string {
|
||||
func authMethods(basic, post bool) []string {
|
||||
authMethods := make([]string, 0, 2)
|
||||
if c.AuthMethodBasicSupported() {
|
||||
if basic {
|
||||
// if c.AuthMethodBasicSupported() {
|
||||
authMethods = append(authMethods, authMethodBasic)
|
||||
}
|
||||
if c.AuthMethodPostSupported() {
|
||||
if post {
|
||||
// if c.AuthMethodPostSupported() {
|
||||
authMethods = append(authMethods, authMethodPost)
|
||||
}
|
||||
return authMethods
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package op_test
|
||||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
@ -9,7 +9,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func TestDiscover(t *testing.T) {
|
||||
|
@ -31,7 +30,7 @@ func TestDiscover(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.Discover(tt.args.w, tt.args.config)
|
||||
Discover(tt.args.w, tt.args.config)
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
|
||||
|
@ -41,8 +40,8 @@ func TestDiscover(t *testing.T) {
|
|||
|
||||
func TestCreateDiscoveryConfig(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
s op.Signer
|
||||
c Configuration
|
||||
s Signer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -53,9 +52,156 @@ func TestCreateDiscoveryConfig(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||
if got := CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_scopes(t *testing.T) {
|
||||
type args struct {
|
||||
c Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("scopes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_responseTypes(t *testing.T) {
|
||||
type args struct {
|
||||
c Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := responseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_grantTypes(t *testing.T) {
|
||||
type args struct {
|
||||
c Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := grantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// func Test_sigAlgorithms(t *testing.T) {
|
||||
// type args struct {
|
||||
// s Signer
|
||||
// }
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// args args
|
||||
// want []string
|
||||
// }{
|
||||
// {
|
||||
// "",
|
||||
// args{},
|
||||
// []string{"RS256"},
|
||||
// },
|
||||
// }
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// if got := sigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
// func Test_subjectTypes(t *testing.T) {
|
||||
// type args struct {
|
||||
// c Configuration
|
||||
// }
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// args args
|
||||
// want []string
|
||||
// }{
|
||||
// {
|
||||
// "none",
|
||||
// args{func()}
|
||||
// }
|
||||
// }
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// if got := subjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
func Test_authMethods(t *testing.T) {
|
||||
type args struct {
|
||||
basic bool
|
||||
post bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"none",
|
||||
args{false, false},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"basic",
|
||||
args{true, false},
|
||||
[]string{authMethodBasic},
|
||||
},
|
||||
{
|
||||
"post",
|
||||
args{false, true},
|
||||
[]string{authMethodPost},
|
||||
},
|
||||
{
|
||||
"basic and post",
|
||||
args{true, true},
|
||||
[]string{authMethodBasic, authMethodPost},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := authMethods(tt.args.basic, tt.args.post); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("authMethods() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
100
pkg/op/error.go
100
pkg/op/error.go
|
@ -1,8 +1,10 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
|
@ -13,6 +15,21 @@ const (
|
|||
ServerError errorType = "server_error"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidRequest = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: InvalidRequest,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
ErrServerError = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: ServerError,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
type errorType string
|
||||
|
||||
type ErrAuthRequest interface {
|
||||
|
@ -21,7 +38,7 @@ type ErrAuthRequest interface {
|
|||
GetState() string
|
||||
}
|
||||
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error) {
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
|
||||
if authReq == nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
|
@ -30,27 +47,23 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
|
|||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
e, ok := err.(*OAuthError)
|
||||
if !ok {
|
||||
e = new(OAuthError)
|
||||
e.ErrorType = ServerError
|
||||
e.Description = err.Error()
|
||||
}
|
||||
e.state = authReq.GetState()
|
||||
params, err := utils.URLEncodeResponse(e, encoder)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
url := authReq.GetRedirectURI()
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
url += "?"
|
||||
url += "?" + params
|
||||
} else {
|
||||
url += "#"
|
||||
}
|
||||
var errorType errorType
|
||||
var description string
|
||||
if e, ok := err.(*OAuthError); ok {
|
||||
errorType = e.ErrorType
|
||||
description = e.Description
|
||||
} else {
|
||||
errorType = ServerError
|
||||
description = err.Error()
|
||||
}
|
||||
url += "error=" + string(errorType)
|
||||
if description != "" {
|
||||
url += "&error_description=" + description
|
||||
}
|
||||
if authReq.GetState() != "" {
|
||||
url += "&state=" + authReq.GetState()
|
||||
url += "#" + params
|
||||
}
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
@ -67,50 +80,11 @@ func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
|
|||
}
|
||||
|
||||
type OAuthError struct {
|
||||
ErrorType errorType `json:"error"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidRequest = func(description string, args ...interface{}) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: InvalidRequest,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
ErrServerError = func(description string, args ...interface{}) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: ServerError,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest) {
|
||||
if authReq == nil {
|
||||
http.Error(w, e.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if authReq.GetRedirectURI() == "" {
|
||||
http.Error(w, e.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
callback := authReq.GetRedirectURI()
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
callback += "?"
|
||||
} else {
|
||||
callback += "#"
|
||||
}
|
||||
callback += "error=" + string(e.ErrorType)
|
||||
if e.Description != "" {
|
||||
callback += "&error_description=" + url.QueryEscape(e.Description)
|
||||
}
|
||||
if authReq.GetState() != "" {
|
||||
callback += "&state=" + authReq.GetState()
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
ErrorType errorType `json:"error" schema:"error"`
|
||||
Description string `json:"description" schema:"description"`
|
||||
state string `json:"state" schema:"state"`
|
||||
}
|
||||
|
||||
func (e *OAuthError) Error() string {
|
||||
return ""
|
||||
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
|
||||
}
|
||||
|
|
|
@ -62,6 +62,20 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
|
||||
}
|
||||
|
||||
// Issuer mocks base method
|
||||
func (m *MockAuthorizer) Issuer() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Issuer")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Issuer indicates an expected call of Issuer
|
||||
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
|
||||
}
|
||||
|
||||
// Signer mocks base method
|
||||
func (m *MockAuthorizer) Signer() op.Signer {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -12,14 +12,12 @@ import (
|
|||
|
||||
type OpenIDProvider interface {
|
||||
Configuration
|
||||
// Storage() Storage
|
||||
HandleDiscovery(w http.ResponseWriter, r *http.Request)
|
||||
HandleAuthorize(w http.ResponseWriter, r *http.Request)
|
||||
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
|
||||
HandleExchange(w http.ResponseWriter, r *http.Request)
|
||||
HandleUserinfo(w http.ResponseWriter, r *http.Request)
|
||||
HandleKeys(w http.ResponseWriter, r *http.Request)
|
||||
// Storage() Storage
|
||||
HttpHandler() *http.Server
|
||||
}
|
||||
|
||||
|
|
|
@ -36,9 +36,11 @@ type AuthRequest interface {
|
|||
GetAudience() []string
|
||||
GetAuthTime() time.Time
|
||||
GetClientID() string
|
||||
GetCode() string
|
||||
GetNonce() string
|
||||
GetRedirectURI() string
|
||||
GetResponseType() oidc.ResponseType
|
||||
GetScopes() []string
|
||||
GetState() string
|
||||
GetSubject() string
|
||||
}
|
||||
|
|
46
pkg/op/token.go
Normal file
46
pkg/op/token.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) {
|
||||
var err error
|
||||
accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes())
|
||||
exp := time.Duration(5 * time.Minute)
|
||||
return accessToken, uint64(exp.Seconds()), err
|
||||
}
|
||||
|
||||
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
|
||||
var err error
|
||||
exp := time.Now().UTC().Add(validity)
|
||||
claims := &oidc.IDTokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: authReq.GetSubject(),
|
||||
Audiences: authReq.GetAudience(),
|
||||
Expiration: exp,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
AuthTime: authReq.GetAuthTime(),
|
||||
Nonce: authReq.GetNonce(),
|
||||
AuthenticationContextClassReference: authReq.GetACR(),
|
||||
AuthenticationMethodsReferences: authReq.GetAMR(),
|
||||
AuthorizedParty: authReq.GetClientID(),
|
||||
}
|
||||
if accessToken != "" {
|
||||
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if code != "" {
|
||||
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return signer.SignIDToken(claims)
|
||||
}
|
|
@ -54,12 +54,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
accessToken, err := CreateAccessToken()
|
||||
accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, exchanger.Signer())
|
||||
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
|
@ -68,39 +68,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
|||
resp := &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: exp,
|
||||
}
|
||||
utils.MarshalJSON(w, resp)
|
||||
}
|
||||
|
||||
func CreateAccessToken() (string, error) {
|
||||
return "accessToken", nil
|
||||
}
|
||||
|
||||
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken string, signer Signer) (string, error) {
|
||||
var err error
|
||||
exp := time.Now().UTC().Add(validity)
|
||||
claims := &oidc.IDTokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: authReq.GetSubject(),
|
||||
Audiences: authReq.GetAudience(),
|
||||
Expiration: exp,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
AuthTime: authReq.GetAuthTime(),
|
||||
Nonce: authReq.GetNonce(),
|
||||
AuthenticationContextClassReference: authReq.GetACR(),
|
||||
AuthenticationMethodsReferences: authReq.GetAMR(),
|
||||
AuthorizedParty: authReq.GetClientID(),
|
||||
}
|
||||
if accessToken != "" {
|
||||
claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, signer.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return signer.SignIDToken(claims)
|
||||
}
|
||||
|
||||
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
|
||||
if tokenReq.ClientID == "" {
|
||||
if !exchanger.AuthMethodBasicSupported() {
|
||||
|
|
|
@ -64,7 +64,7 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc
|
|||
}
|
||||
|
||||
if p.verifier == nil {
|
||||
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint
|
||||
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL))
|
||||
}
|
||||
|
||||
return p, nil
|
||||
|
@ -110,6 +110,7 @@ func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
|
|||
//handling the oauth2 code exchange, extracting and validating the id_token
|
||||
//returning it paresed together with the oauth2 tokens (access, refresh)
|
||||
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
|
||||
token, err := p.oauthConfig.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err //TODO: our error
|
||||
|
@ -124,7 +125,7 @@ func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc
|
|||
return nil, err //TODO: err
|
||||
}
|
||||
|
||||
return &oidc.Tokens{Token: token, IDTokenClaims: idToken}, nil
|
||||
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
|
|
|
@ -443,7 +443,7 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor
|
|||
return nil //TODO: return error
|
||||
}
|
||||
|
||||
actual, err := oidc.AccessTokenHash(accessToken, sigAlgorithm)
|
||||
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -55,3 +55,13 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) {
|
||||
values := make(map[string][]string)
|
||||
err := encoder.Encode(resp, values)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
v := url.Values(values)
|
||||
return v.Encode(), nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue