diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 385408a..8cf9e80 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -143,9 +143,9 @@ func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) { func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) { return s.key, nil } -func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) { +func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) { pubkey := s.key.Public() - return jose.JSONWebKeySet{ + return &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"}, }, diff --git a/pkg/oidc/code_challenge.go b/pkg/oidc/code_challenge.go new file mode 100644 index 0000000..9ec8fa9 --- /dev/null +++ b/pkg/oidc/code_challenge.go @@ -0,0 +1,26 @@ +package oidc + +import ( + "crypto/sha256" + + "github.com/caos/oidc/pkg/utils" +) + +const ( + CodeChallengeMethodPlain CodeChallengeMethod = "plain" + CodeChallengeMethodS256 CodeChallengeMethod = "S256" +) + +type CodeChallengeMethod string + +type CodeChallenge struct { + Challenge string + Method CodeChallengeMethod +} + +func (c *CodeChallenge) Verify(codeVerifier string) bool { + if c.Method == CodeChallengeMethodS256 { + codeVerifier = utils.HashString(sha256.New(), codeVerifier) + } + return codeVerifier == c.Challenge +} diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 248f671..c00061d 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -1,14 +1,10 @@ package oidc import ( - "crypto/sha256" - "crypto/sha512" - "encoding/base64" "encoding/json" - "fmt" - "hash" "time" + "github.com/caos/oidc/pkg/utils" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" ) @@ -94,25 +90,10 @@ type Tokens struct { } func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { - hash, err := getHashAlgorithm(sigAlgorithm) + hash, err := utils.GetHashAlgorithm(sigAlgorithm) if err != nil { return "", err } - hash.Write([]byte(claim)) // hash documents that Write will never return an error - sum := hash.Sum(nil)[:hash.Size()/2] - return base64.RawURLEncoding.EncodeToString(sum), nil -} - -func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { - switch sigAlgorithm { - case jose.RS256, jose.ES256, jose.PS256: - return sha256.New(), nil - case jose.RS384, jose.ES384, jose.PS384: - return sha512.New384(), nil - case jose.RS512, jose.ES512, jose.PS512: - return sha512.New(), nil - default: - return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm) - } + return utils.HashString(hash, claim), nil } diff --git a/pkg/op/client.go b/pkg/op/client.go index b584254..41a6d60 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -7,6 +7,7 @@ const ( ) type Client interface { + GetID() string RedirectURIs() []string ApplicationType() ApplicationType LoginURL(string) string diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index 87c3575..9757457 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -47,6 +47,20 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType)) } +// GetID mocks base method +func (m *MockClient) GetID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetID indicates an expected call of GetID +func (mr *MockClientMockRecorder) GetID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID)) +} + // LoginURL mocks base method func (m *MockClient) LoginURL(arg0 string) string { m.ctrl.T.Helper() diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 4ebcd82..22cd49b 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -36,18 +36,18 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder { } // AuthRequestByCode mocks base method -func (m *MockStorage) AuthRequestByCode(arg0 op.Client, arg1, arg2 string) (op.AuthRequest, error) { +func (m *MockStorage) AuthRequestByCode(arg0 string) (op.AuthRequest, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "AuthRequestByCode", arg0) ret0, _ := ret[0].(op.AuthRequest) ret1, _ := ret[1].(error) return ret0, ret1 } // AuthRequestByCode indicates an expected call of AuthRequestByCode -func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0) } // AuthRequestByID mocks base method @@ -140,10 +140,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock } // GetKeySet mocks base method -func (m *MockStorage) GetKeySet() (go_jose_v2.JSONWebKeySet, error) { +func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetKeySet") - ret0, _ := ret[0].(go_jose_v2.JSONWebKeySet) + ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index ef8ac2d..81b5b68 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -11,11 +11,11 @@ import ( type AuthStorage interface { CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error) AuthRequestByID(string) (AuthRequest, error) - AuthRequestByCode(Client, string, string) (AuthRequest, error) + AuthRequestByCode(string) (AuthRequest, error) DeleteAuthRequestAndCode(string, string) error GetSigningKey() (*jose.SigningKey, error) - GetKeySet() (jose.JSONWebKeySet, error) + GetKeySet() (*jose.JSONWebKeySet, error) } type OPStorage interface { @@ -38,6 +38,7 @@ type AuthRequest interface { GetAuthTime() time.Time GetClientID() string GetCode() string + GetCodeChallenge() *oidc.CodeChallenge GetNonce() string GetRedirectURI() string GetResponseType() oidc.ResponseType diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go index 00ec3e6..5ece515 100644 --- a/pkg/op/tokenrequest.go +++ b/pkg/op/tokenrequest.go @@ -39,12 +39,17 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { return } - client, err := AuthorizeClient(r, tokenReq, exchanger) + authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code) if err != nil { ExchangeRequestError(w, r, err) return } - authReq, err := exchanger.Storage().AuthRequestByCode(client, tokenReq.Code, tokenReq.RedirectURI) + client, err := AuthorizeClient(r, tokenReq, authReq, exchanger) + if err != nil { + ExchangeRequestError(w, r, err) + return + } + err = ValidateAccessTokenRequest(tokenReq, client, authReq) if err != nil { ExchangeRequestError(w, r, err) return @@ -74,7 +79,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { utils.MarshalJSON(w, resp) } -func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) { +func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, authReq AuthRequest, exchanger Exchanger) (Client, error) { if tokenReq.ClientID == "" { if !exchanger.AuthMethodBasicSupported() { return nil, errors.New("basic not supported") @@ -92,11 +97,24 @@ func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchang return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret) } if tokenReq.CodeVerifier != "" { - return exchanger.Storage().AuthorizeClientIDCodeVerifier(tokenReq.ClientID, tokenReq.CodeVerifier) + if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) { + return nil, ErrInvalidRequest("code_challenge invalid") + } + return exchanger.Storage().GetClientByClientID(tokenReq.ClientID) } return nil, errors.New("Unimplemented") //TODO: impl } +func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, client Client, authReq AuthRequest) error { + if client.GetID() != authReq.GetClientID() { + return ErrInvalidRequest("invalid auth code") + } + if tokenReq.RedirectURI != authReq.GetRedirectURI() { + return ErrInvalidRequest("redirect_uri does no correspond") + } + return nil +} + func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) { return nil, errors.New("Unimplemented") //TODO: impl } diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go new file mode 100644 index 0000000..bfdfacb --- /dev/null +++ b/pkg/utils/hash.go @@ -0,0 +1,30 @@ +package utils + +import ( + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "fmt" + "hash" + + "gopkg.in/square/go-jose.v2" +) + +func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { + switch sigAlgorithm { + case jose.RS256, jose.ES256, jose.PS256: + return sha256.New(), nil + case jose.RS384, jose.ES384, jose.PS384: + return sha512.New384(), nil + case jose.RS512, jose.ES512, jose.PS512: + return sha512.New(), nil + default: + return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm) + } +} + +func HashString(hash hash.Hash, s string) string { + hash.Write([]byte(s)) // hash documents that Write will never return an error + sum := hash.Sum(nil)[:hash.Size()/2] + return base64.RawURLEncoding.EncodeToString(sum) +}