code exchange fixes

This commit is contained in:
Livio Amstutz 2019-12-12 16:04:34 +01:00
parent 85814fb69a
commit 20a90c71d9
9 changed files with 107 additions and 36 deletions

View file

@ -143,9 +143,9 @@ func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
return s.key, nil
}
func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) {
func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
pubkey := s.key.Public()
return jose.JSONWebKeySet{
return &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
},

View file

@ -0,0 +1,26 @@
package oidc
import (
"crypto/sha256"
"github.com/caos/oidc/pkg/utils"
)
const (
CodeChallengeMethodPlain CodeChallengeMethod = "plain"
CodeChallengeMethodS256 CodeChallengeMethod = "S256"
)
type CodeChallengeMethod string
type CodeChallenge struct {
Challenge string
Method CodeChallengeMethod
}
func (c *CodeChallenge) Verify(codeVerifier string) bool {
if c.Method == CodeChallengeMethodS256 {
codeVerifier = utils.HashString(sha256.New(), codeVerifier)
}
return codeVerifier == c.Challenge
}

View file

@ -1,14 +1,10 @@
package oidc
import (
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"fmt"
"hash"
"time"
"github.com/caos/oidc/pkg/utils"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
)
@ -94,25 +90,10 @@ type Tokens struct {
}
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
hash, err := getHashAlgorithm(sigAlgorithm)
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
if err != nil {
return "", err
}
hash.Write([]byte(claim)) // hash documents that Write will never return an error
sum := hash.Sum(nil)[:hash.Size()/2]
return base64.RawURLEncoding.EncodeToString(sum), nil
}
func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
switch sigAlgorithm {
case jose.RS256, jose.ES256, jose.PS256:
return sha256.New(), nil
case jose.RS384, jose.ES384, jose.PS384:
return sha512.New384(), nil
case jose.RS512, jose.ES512, jose.PS512:
return sha512.New(), nil
default:
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
}
return utils.HashString(hash, claim), nil
}

View file

@ -7,6 +7,7 @@ const (
)
type Client interface {
GetID() string
RedirectURIs() []string
ApplicationType() ApplicationType
LoginURL(string) string

View file

@ -47,6 +47,20 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
}
// GetID mocks base method
func (m *MockClient) GetID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetID")
ret0, _ := ret[0].(string)
return ret0
}
// GetID indicates an expected call of GetID
func (mr *MockClientMockRecorder) GetID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
}
// LoginURL mocks base method
func (m *MockClient) LoginURL(arg0 string) string {
m.ctrl.T.Helper()

View file

@ -36,18 +36,18 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
}
// AuthRequestByCode mocks base method
func (m *MockStorage) AuthRequestByCode(arg0 op.Client, arg1, arg2 string) (op.AuthRequest, error) {
func (m *MockStorage) AuthRequestByCode(arg0 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0)
ret0, _ := ret[0].(op.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AuthRequestByCode indicates an expected call of AuthRequestByCode
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0)
}
// AuthRequestByID mocks base method
@ -140,10 +140,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock
}
// GetKeySet mocks base method
func (m *MockStorage) GetKeySet() (go_jose_v2.JSONWebKeySet, error) {
func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet")
ret0, _ := ret[0].(go_jose_v2.JSONWebKeySet)
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
ret1, _ := ret[1].(error)
return ret0, ret1
}

View file

@ -11,11 +11,11 @@ import (
type AuthStorage interface {
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
AuthRequestByID(string) (AuthRequest, error)
AuthRequestByCode(Client, string, string) (AuthRequest, error)
AuthRequestByCode(string) (AuthRequest, error)
DeleteAuthRequestAndCode(string, string) error
GetSigningKey() (*jose.SigningKey, error)
GetKeySet() (jose.JSONWebKeySet, error)
GetKeySet() (*jose.JSONWebKeySet, error)
}
type OPStorage interface {
@ -38,6 +38,7 @@ type AuthRequest interface {
GetAuthTime() time.Time
GetClientID() string
GetCode() string
GetCodeChallenge() *oidc.CodeChallenge
GetNonce() string
GetRedirectURI() string
GetResponseType() oidc.ResponseType

View file

@ -39,12 +39,17 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
return
}
client, err := AuthorizeClient(r, tokenReq, exchanger)
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
authReq, err := exchanger.Storage().AuthRequestByCode(client, tokenReq.Code, tokenReq.RedirectURI)
client, err := AuthorizeClient(r, tokenReq, authReq, exchanger)
if err != nil {
ExchangeRequestError(w, r, err)
return
}
err = ValidateAccessTokenRequest(tokenReq, client, authReq)
if err != nil {
ExchangeRequestError(w, r, err)
return
@ -74,7 +79,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
utils.MarshalJSON(w, resp)
}
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, authReq AuthRequest, exchanger Exchanger) (Client, error) {
if tokenReq.ClientID == "" {
if !exchanger.AuthMethodBasicSupported() {
return nil, errors.New("basic not supported")
@ -92,11 +97,24 @@ func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchang
return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret)
}
if tokenReq.CodeVerifier != "" {
return exchanger.Storage().AuthorizeClientIDCodeVerifier(tokenReq.ClientID, tokenReq.CodeVerifier)
if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) {
return nil, ErrInvalidRequest("code_challenge invalid")
}
return exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
}
return nil, errors.New("Unimplemented") //TODO: impl
}
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, client Client, authReq AuthRequest) error {
if client.GetID() != authReq.GetClientID() {
return ErrInvalidRequest("invalid auth code")
}
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return ErrInvalidRequest("redirect_uri does no correspond")
}
return nil
}
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
return nil, errors.New("Unimplemented") //TODO: impl
}

30
pkg/utils/hash.go Normal file
View file

@ -0,0 +1,30 @@
package utils
import (
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"fmt"
"hash"
"gopkg.in/square/go-jose.v2"
)
func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
switch sigAlgorithm {
case jose.RS256, jose.ES256, jose.PS256:
return sha256.New(), nil
case jose.RS384, jose.ES384, jose.PS384:
return sha512.New384(), nil
case jose.RS512, jose.ES512, jose.PS512:
return sha512.New(), nil
default:
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
}
}
func HashString(hash hash.Hash, s string) string {
hash.Write([]byte(s)) // hash documents that Write will never return an error
sum := hash.Sum(nil)[:hash.Size()/2]
return base64.RawURLEncoding.EncodeToString(sum)
}