refactoring

This commit is contained in:
Livio Amstutz 2019-12-06 10:42:17 +01:00
parent a793e77679
commit 310220d38e
17 changed files with 346 additions and 149 deletions

View file

@ -41,7 +41,9 @@ func (a *AuthRequest) GetACR() string {
} }
func (a *AuthRequest) GetAMR() []string { func (a *AuthRequest) GetAMR() []string {
return []string{} return []string{
"password",
}
} }
func (a *AuthRequest) GetAudience() []string { func (a *AuthRequest) GetAudience() []string {
@ -55,7 +57,11 @@ func (a *AuthRequest) GetAuthTime() time.Time {
} }
func (a *AuthRequest) GetClientID() string { func (a *AuthRequest) GetClientID() string {
return "" return a.ID
}
func (a *AuthRequest) GetCode() string {
return "code"
} }
func (a *AuthRequest) GetID() string { func (a *AuthRequest) GetID() string {
@ -63,23 +69,31 @@ func (a *AuthRequest) GetID() string {
} }
func (a *AuthRequest) GetNonce() string { func (a *AuthRequest) GetNonce() string {
return "" return "nonce"
} }
func (a *AuthRequest) GetRedirectURI() string { func (a *AuthRequest) GetRedirectURI() string {
return "" return "http://localhost:5556/auth/callback"
} }
func (a *AuthRequest) GetResponseType() oidc.ResponseType { func (a *AuthRequest) GetResponseType() oidc.ResponseType {
return a.ResponseType return a.ResponseType
} }
func (a *AuthRequest) GetScopes() []string {
return []string{
"openid",
"profile",
"email",
}
}
func (a *AuthRequest) GetState() string { func (a *AuthRequest) GetState() string {
return "" return ""
} }
func (a *AuthRequest) GetSubject() string { func (a *AuthRequest) GetSubject() string {
return "" return "sub"
} }
func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) { func (s *AuthStorage) CreateAuthRequest(authReq *oidc.AuthRequest) (op.AuthRequest, error) {
@ -132,11 +146,14 @@ func (s *AuthStorage) AuthRequestByID(id string) (op.AuthRequest, error) {
func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) { func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
} }
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
return s.key, nil
}
func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) { func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) {
pubkey := s.key.Public() pubkey := s.key.Public()
return jose.JSONWebKeySet{ return jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{ Keys: []jose.JSONWebKey{
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256"}, jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
}, },
}, nil }, nil
} }
@ -151,6 +168,9 @@ func (c *ConfClient) RedirectURIs() []string {
"http://localhost:9999/callback", "http://localhost:9999/callback",
"http://localhost:5556/auth/callback", "http://localhost:5556/auth/callback",
"custom://callback", "custom://callback",
"https://localhost:8443/test/a/instructions-example/callback",
"https://op.certification.openid.net:62054/authz_cb",
"https://op.certification.openid.net:62054/authz_post",
} }
} }

View file

@ -25,6 +25,8 @@ const (
PromptSelectAccount = "select_account" PromptSelectAccount = "select_account"
GrantTypeCode GrantType = "authorization_code" GrantTypeCode GrantType = "authorization_code"
BearerToken = "Bearer"
) )
var displayValues = map[string]Display{ var displayValues = map[string]Display{

View file

@ -14,17 +14,18 @@ import (
) )
type IDTokenClaims struct { type IDTokenClaims struct {
Issuer string `json:"iss,omitempty"` Issuer string
Subject string `json:"sub,omitempty"` Subject string
Audiences []string `json:"aud,omitempty"` Audiences []string
Expiration time.Time `json:"exp,omitempty"` Expiration time.Time
IssuedAt time.Time `json:"iat,omitempty"` IssuedAt time.Time
AuthTime time.Time `json:"auth_time,omitempty"` AuthTime time.Time
Nonce string `json:"nonce,omitempty"` Nonce string
AuthenticationContextClassReference string `json:"acr,omitempty"` AuthenticationContextClassReference string
AuthenticationMethodsReferences []string `json:"amr,omitempty"` AuthenticationMethodsReferences []string
AuthorizedParty string `json:"azp,omitempty"` AuthorizedParty string
AccessTokenHash string `json:"at_hash,omitempty"` AccessTokenHash string
CodeHash string
Signature jose.SignatureAlgorithm //TODO: ??? Signature jose.SignatureAlgorithm //TODO: ???
} }
@ -46,6 +47,7 @@ func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
t.AuthorizedParty = i.AuthorizedParty t.AuthorizedParty = i.AuthorizedParty
t.AccessTokenHash = i.AccessTokenHash t.AccessTokenHash = i.AccessTokenHash
t.CodeHash = i.CodeHash
return nil return nil
} }
@ -63,6 +65,7 @@ func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
AuthorizedParty: t.AuthorizedParty, AuthorizedParty: t.AuthorizedParty,
AccessTokenHash: t.AccessTokenHash, AccessTokenHash: t.AccessTokenHash,
CodeHash: t.CodeHash,
} }
return json.Marshal(j) return json.Marshal(j)
} }
@ -81,21 +84,23 @@ type jsonIDToken struct {
AuthenticationMethodsReferences []string `json:"amr,omitempty"` AuthenticationMethodsReferences []string `json:"amr,omitempty"`
AuthorizedParty string `json:"azp,omitempty"` AuthorizedParty string `json:"azp,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"` AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
} }
type Tokens struct { type Tokens struct {
*oauth2.Token *oauth2.Token
IDTokenClaims *IDTokenClaims IDTokenClaims *IDTokenClaims
IDToken string
} }
func AccessTokenHash(accessToken string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
tokenHash, err := getHashAlgorithm(sigAlgorithm) hash, err := getHashAlgorithm(sigAlgorithm)
if err != nil { if err != nil {
return "", err return "", err
} }
tokenHash.Write([]byte(accessToken)) // hash documents that Write will never return an error hash.Write([]byte(claim)) // hash documents that Write will never return an error
sum := tokenHash.Sum(nil)[:tokenHash.Size()/2] sum := hash.Sum(nil)[:hash.Size()/2]
return base64.RawURLEncoding.EncodeToString(sum), nil return base64.RawURLEncoding.EncodeToString(sum), nil
} }

View file

@ -3,7 +3,6 @@ package op
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
@ -19,6 +18,7 @@ type Authorizer interface {
Decoder() *schema.Decoder Decoder() *schema.Decoder
Encoder() *schema.Encoder Encoder() *schema.Encoder
Signer() Signer Signer() Signer
Issuer() string
// ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) // ErrorHandler() func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error)
} }
@ -37,7 +37,7 @@ type ValidationAuthorizer interface {
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form")) AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder())
// AuthRequestError(w, r, nil, ) // AuthRequestError(w, r, nil, )
return return
} }
@ -45,7 +45,7 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err = authorizer.Decoder().Decode(authReq, r.Form) err = authorizer.Decoder().Decode(authReq, r.Form)
if err != nil { if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err))) AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
return return
} }
@ -54,19 +54,19 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
validation = validater.ValidateAuthRequest validation = validater.ValidateAuthRequest
} }
if err := validation(authReq, authorizer.Storage()); err != nil { if err := validation(authReq, authorizer.Storage()); err != nil {
AuthRequestError(w, r, authReq, err) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
req, err := authorizer.Storage().CreateAuthRequest(authReq) req, err := authorizer.Storage().CreateAuthRequest(authReq)
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
client, err := authorizer.Storage().GetClientByClientID(req.GetClientID()) client, err := authorizer.Storage().GetClientByClientID(req.GetClientID())
if err != nil { if err != nil {
AuthRequestError(w, r, req, err) AuthRequestError(w, r, req, err, authorizer.Encoder())
return return
} }
RedirectToLogin(req.GetID(), client, w, r) RedirectToLogin(req.GetID(), client, w, r)
@ -100,7 +100,7 @@ func ValidateAuthReqScopes(scopes []string) error {
return nil return nil
} }
func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage Storage) error { func ValidateAuthReqRedirectURI(uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
if uri == "" { if uri == "" {
return ErrInvalidRequest("redirect_uri must not be empty") return ErrInvalidRequest("redirect_uri must not be empty")
} }
@ -144,7 +144,7 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
authReq, err := authorizer.Storage().AuthRequestByID(id) authReq, err := authorizer.Storage().AuthRequestByID(id)
if err != nil { if err != nil {
AuthRequestError(w, r, nil, err) AuthRequestError(w, r, nil, err, authorizer.Encoder())
return return
} }
AuthResponse(authReq, authorizer, w, r) AuthResponse(authReq, authorizer, w, r)
@ -153,29 +153,32 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
var callback string var callback string
if authReq.GetResponseType() == oidc.ResponseTypeCode { if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), "test") callback = fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), authReq.GetCode())
} else { } else {
var accessToken string var accessToken string
var err error var err error
var exp uint64
if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly { if authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly {
accessToken, err = CreateAccessToken() accessToken, exp, err = CreateAccessToken(authReq, authorizer.Signer())
if err != nil { if err != nil {
} }
} }
idToken, err := CreateIDToken("", authReq, time.Duration(0), accessToken, authorizer.Signer()) idToken, err := CreateIDToken(authorizer.Issuer(), authReq, time.Duration(0), accessToken, "", authorizer.Signer())
if err != nil { if err != nil {
} }
resp := &oidc.AccessTokenResponse{ resp := &oidc.AccessTokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
IDToken: idToken, IDToken: idToken,
TokenType: "Bearer", TokenType: oidc.BearerToken,
ExpiresIn: exp,
} }
values := make(map[string][]string) params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
authorizer.Encoder().Encode(resp, values) if err != nil {
v := url.Values(values)
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), v.Encode()) }
callback = fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
} }
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)
} }

View file

@ -148,7 +148,7 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
uri string uri string
clientID string clientID string
responseType oidc.ResponseType responseType oidc.ResponseType
storage op.Storage storage op.OPStorage
} }
tests := []struct { tests := []struct {
name string name string

View file

@ -18,6 +18,8 @@ const (
authMethodBasic = "client_secret_basic" authMethodBasic = "client_secret_basic"
authMethodPost = "client_secret_post" authMethodPost = "client_secret_post"
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
) )
var ( var (
@ -28,7 +30,6 @@ var (
Userinfo: defaultUserinfoEndpoint, Userinfo: defaultUserinfoEndpoint,
JwksURI: defaultKeysEndpoint, JwksURI: defaultKeysEndpoint,
} }
DefaultIDTokenValidity = time.Duration(5 * time.Minute)
) )
type DefaultOP struct { type DefaultOP struct {
@ -250,5 +251,5 @@ func (p *DefaultOP) handleTokenExchange(w http.ResponseWriter, r *http.Request)
} }
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) { func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok"))
} }

View file

@ -27,7 +27,7 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati
// ClaimsSupported: oidc.SupportedClaims, // ClaimsSupported: oidc.SupportedClaims,
IDTokenSigningAlgValuesSupported: sigAlgorithms(s), IDTokenSigningAlgValuesSupported: sigAlgorithms(s),
SubjectTypesSupported: subjectTypes(c), SubjectTypesSupported: subjectTypes(c),
TokenEndpointAuthMethodsSupported: authMethods(c), TokenEndpointAuthMethodsSupported: authMethods(c.AuthMethodBasicSupported(), c.AuthMethodPostSupported()),
} }
} }
@ -68,12 +68,14 @@ func subjectTypes(c Configuration) []string {
return []string{"public"} //TODO: config return []string{"public"} //TODO: config
} }
func authMethods(c Configuration) []string { func authMethods(basic, post bool) []string {
authMethods := make([]string, 0, 2) authMethods := make([]string, 0, 2)
if c.AuthMethodBasicSupported() { if basic {
// if c.AuthMethodBasicSupported() {
authMethods = append(authMethods, authMethodBasic) authMethods = append(authMethods, authMethodBasic)
} }
if c.AuthMethodPostSupported() { if post {
// if c.AuthMethodPostSupported() {
authMethods = append(authMethods, authMethodPost) authMethods = append(authMethods, authMethodPost)
} }
return authMethods return authMethods

View file

@ -1,4 +1,4 @@
package op_test package op
import ( import (
"net/http" "net/http"
@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
) )
func TestDiscover(t *testing.T) { func TestDiscover(t *testing.T) {
@ -31,7 +30,7 @@ func TestDiscover(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
op.Discover(tt.args.w, tt.args.config) Discover(tt.args.w, tt.args.config)
rec := tt.args.w.(*httptest.ResponseRecorder) rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String()) require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
@ -41,8 +40,8 @@ func TestDiscover(t *testing.T) {
func TestCreateDiscoveryConfig(t *testing.T) { func TestCreateDiscoveryConfig(t *testing.T) {
type args struct { type args struct {
c op.Configuration c Configuration
s op.Signer s Signer
} }
tests := []struct { tests := []struct {
name string name string
@ -53,9 +52,156 @@ func TestCreateDiscoveryConfig(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) { if got := CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want) t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
} }
}) })
} }
} }
func Test_scopes(t *testing.T) {
type args struct {
c Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("scopes() = %v, want %v", got, tt.want)
}
})
}
}
func Test_responseTypes(t *testing.T) {
type args struct {
c Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := responseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
}
})
}
}
func Test_grantTypes(t *testing.T) {
type args struct {
c Configuration
}
tests := []struct {
name string
args args
want []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := grantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
}
})
}
}
// func Test_sigAlgorithms(t *testing.T) {
// type args struct {
// s Signer
// }
// tests := []struct {
// name string
// args args
// want []string
// }{
// {
// "",
// args{},
// []string{"RS256"},
// },
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// if got := sigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
// t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
// }
// })
// }
// }
// func Test_subjectTypes(t *testing.T) {
// type args struct {
// c Configuration
// }
// tests := []struct {
// name string
// args args
// want []string
// }{
// {
// "none",
// args{func()}
// }
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// if got := subjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
// t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
// }
// })
// }
// }
func Test_authMethods(t *testing.T) {
type args struct {
basic bool
post bool
}
tests := []struct {
name string
args args
want []string
}{
{
"none",
args{false, false},
[]string{},
},
{
"basic",
args{true, false},
[]string{authMethodBasic},
},
{
"post",
args{false, true},
[]string{authMethodPost},
},
{
"basic and post",
args{true, true},
[]string{authMethodBasic, authMethodPost},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := authMethods(tt.args.basic, tt.args.post); !reflect.DeepEqual(got, tt.want) {
t.Errorf("authMethods() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,8 +1,10 @@
package op package op
import ( import (
"fmt"
"net/http" "net/http"
"net/url"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
@ -13,6 +15,21 @@ const (
ServerError errorType = "server_error" ServerError errorType = "server_error"
) )
var (
ErrInvalidRequest = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrServerError = func(description string) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
type errorType string type errorType string
type ErrAuthRequest interface { type ErrAuthRequest interface {
@ -21,7 +38,7 @@ type ErrAuthRequest interface {
GetState() string GetState() string
} }
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error) { func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
if authReq == nil { if authReq == nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -30,27 +47,23 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
e, ok := err.(*OAuthError)
if !ok {
e = new(OAuthError)
e.ErrorType = ServerError
e.Description = err.Error()
}
e.state = authReq.GetState()
params, err := utils.URLEncodeResponse(e, encoder)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
url := authReq.GetRedirectURI() url := authReq.GetRedirectURI()
if authReq.GetResponseType() == oidc.ResponseTypeCode { if authReq.GetResponseType() == oidc.ResponseTypeCode {
url += "?" url += "?" + params
} else { } else {
url += "#" url += "#" + params
}
var errorType errorType
var description string
if e, ok := err.(*OAuthError); ok {
errorType = e.ErrorType
description = e.Description
} else {
errorType = ServerError
description = err.Error()
}
url += "error=" + string(errorType)
if description != "" {
url += "&error_description=" + description
}
if authReq.GetState() != "" {
url += "&state=" + authReq.GetState()
} }
http.Redirect(w, r, url, http.StatusFound) http.Redirect(w, r, url, http.StatusFound)
} }
@ -67,50 +80,11 @@ func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
} }
type OAuthError struct { type OAuthError struct {
ErrorType errorType `json:"error"` ErrorType errorType `json:"error" schema:"error"`
Description string `json:"description"` Description string `json:"description" schema:"description"`
} state string `json:"state" schema:"state"`
var (
ErrInvalidRequest = func(description string, args ...interface{}) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrServerError = func(description string, args ...interface{}) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
func (e *OAuthError) AuthRequestResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest) {
if authReq == nil {
http.Error(w, e.Error(), http.StatusBadRequest)
return
}
if authReq.GetRedirectURI() == "" {
http.Error(w, e.Error(), http.StatusBadRequest)
return
}
callback := authReq.GetRedirectURI()
if authReq.GetResponseType() == oidc.ResponseTypeCode {
callback += "?"
} else {
callback += "#"
}
callback += "error=" + string(e.ErrorType)
if e.Description != "" {
callback += "&error_description=" + url.QueryEscape(e.Description)
}
if authReq.GetState() != "" {
callback += "&state=" + authReq.GetState()
}
http.Redirect(w, r, callback, http.StatusFound)
} }
func (e *OAuthError) Error() string { func (e *OAuthError) Error() string {
return "" return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
} }

View file

@ -62,6 +62,20 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
} }
// Issuer mocks base method
func (m *MockAuthorizer) Issuer() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Issuer")
ret0, _ := ret[0].(string)
return ret0
}
// Issuer indicates an expected call of Issuer
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
}
// Signer mocks base method // Signer mocks base method
func (m *MockAuthorizer) Signer() op.Signer { func (m *MockAuthorizer) Signer() op.Signer {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -12,14 +12,12 @@ import (
type OpenIDProvider interface { type OpenIDProvider interface {
Configuration Configuration
// Storage() Storage
HandleDiscovery(w http.ResponseWriter, r *http.Request) HandleDiscovery(w http.ResponseWriter, r *http.Request)
HandleAuthorize(w http.ResponseWriter, r *http.Request) HandleAuthorize(w http.ResponseWriter, r *http.Request)
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
HandleExchange(w http.ResponseWriter, r *http.Request) HandleExchange(w http.ResponseWriter, r *http.Request)
HandleUserinfo(w http.ResponseWriter, r *http.Request) HandleUserinfo(w http.ResponseWriter, r *http.Request)
HandleKeys(w http.ResponseWriter, r *http.Request) HandleKeys(w http.ResponseWriter, r *http.Request)
// Storage() Storage
HttpHandler() *http.Server HttpHandler() *http.Server
} }

View file

@ -36,9 +36,11 @@ type AuthRequest interface {
GetAudience() []string GetAudience() []string
GetAuthTime() time.Time GetAuthTime() time.Time
GetClientID() string GetClientID() string
GetCode() string
GetNonce() string GetNonce() string
GetRedirectURI() string GetRedirectURI() string
GetResponseType() oidc.ResponseType GetResponseType() oidc.ResponseType
GetScopes() []string
GetState() string GetState() string
GetSubject() string GetSubject() string
} }

46
pkg/op/token.go Normal file
View file

@ -0,0 +1,46 @@
package op
import (
"fmt"
"time"
"github.com/caos/oidc/pkg/oidc"
)
func CreateAccessToken(authReq AuthRequest, signer Signer) (string, uint64, error) {
var err error
accessToken := fmt.Sprintf("%s:%s:%s:%s", authReq.GetSubject(), authReq.GetClientID(), authReq.GetAudience(), authReq.GetScopes())
exp := time.Duration(5 * time.Minute)
return accessToken, uint64(exp.Seconds()), err
}
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
var err error
exp := time.Now().UTC().Add(validity)
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
}
if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
if code != "" {
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
return signer.SignIDToken(claims)
}

View file

@ -54,12 +54,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
} }
accessToken, err := CreateAccessToken() accessToken, exp, err := CreateAccessToken(authReq, exchanger.Signer())
if err != nil { if err != nil {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
} }
idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, exchanger.Signer()) idToken, err := CreateIDToken(exchanger.Issuer(), authReq, exchanger.IDTokenValidity(), accessToken, tokenReq.Code, exchanger.Signer())
if err != nil { if err != nil {
ExchangeRequestError(w, r, err) ExchangeRequestError(w, r, err)
return return
@ -68,39 +68,12 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
resp := &oidc.AccessTokenResponse{ resp := &oidc.AccessTokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
IDToken: idToken, IDToken: idToken,
TokenType: oidc.BearerToken,
ExpiresIn: exp,
} }
utils.MarshalJSON(w, resp) utils.MarshalJSON(w, resp)
} }
func CreateAccessToken() (string, error) {
return "accessToken", nil
}
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken string, signer Signer) (string, error) {
var err error
exp := time.Now().UTC().Add(validity)
claims := &oidc.IDTokenClaims{
Issuer: issuer,
Subject: authReq.GetSubject(),
Audiences: authReq.GetAudience(),
Expiration: exp,
IssuedAt: time.Now().UTC(),
AuthTime: authReq.GetAuthTime(),
Nonce: authReq.GetNonce(),
AuthenticationContextClassReference: authReq.GetACR(),
AuthenticationMethodsReferences: authReq.GetAMR(),
AuthorizedParty: authReq.GetClientID(),
}
if accessToken != "" {
claims.AccessTokenHash, err = oidc.AccessTokenHash(accessToken, signer.SignatureAlgorithm())
if err != nil {
return "", err
}
}
return signer.SignIDToken(claims)
}
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) { func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
if tokenReq.ClientID == "" { if tokenReq.ClientID == "" {
if !exchanger.AuthMethodBasicSupported() { if !exchanger.AuthMethodBasicSupported() {

View file

@ -64,7 +64,7 @@ func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExc
} }
if p.verifier == nil { if p.verifier == nil {
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) //TODO: keys endpoint p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL))
} }
return p, nil return p, nil
@ -110,6 +110,7 @@ func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
//handling the oauth2 code exchange, extracting and validating the id_token //handling the oauth2 code exchange, extracting and validating the id_token
//returning it paresed together with the oauth2 tokens (access, refresh) //returning it paresed together with the oauth2 tokens (access, refresh)
func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) { func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc.Tokens, err error) {
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
token, err := p.oauthConfig.Exchange(ctx, code) token, err := p.oauthConfig.Exchange(ctx, code)
if err != nil { if err != nil {
return nil, err //TODO: our error return nil, err //TODO: our error
@ -124,7 +125,7 @@ func (p *DefaultRP) CodeExchange(ctx context.Context, code string) (tokens *oidc
return nil, err //TODO: err return nil, err //TODO: err
} }
return &oidc.Tokens{Token: token, IDTokenClaims: idToken}, nil return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
} }
//AuthURL is the `RelayingParty` interface implementation //AuthURL is the `RelayingParty` interface implementation

View file

@ -443,7 +443,7 @@ func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgor
return nil //TODO: return error return nil //TODO: return error
} }
actual, err := oidc.AccessTokenHash(accessToken, sigAlgorithm) actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
if err != nil { if err != nil {
return err return err
} }

View file

@ -55,3 +55,13 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e
} }
return nil return nil
} }
func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) {
values := make(map[string][]string)
err := encoder.Encode(resp, values)
if err != nil {
return "", err
}
v := url.Values(values)
return v.Encode(), nil
}