code exchange fixes
This commit is contained in:
parent
85814fb69a
commit
20a90c71d9
9 changed files with 107 additions and 36 deletions
|
@ -143,9 +143,9 @@ func (s *AuthStorage) GetSigningKey() (*jose.SigningKey, error) {
|
||||||
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
|
func (s *AuthStorage) GetKey() (*rsa.PrivateKey, error) {
|
||||||
return s.key, nil
|
return s.key, nil
|
||||||
}
|
}
|
||||||
func (s *AuthStorage) GetKeySet() (jose.JSONWebKeySet, error) {
|
func (s *AuthStorage) GetKeySet() (*jose.JSONWebKeySet, error) {
|
||||||
pubkey := s.key.Public()
|
pubkey := s.key.Public()
|
||||||
return jose.JSONWebKeySet{
|
return &jose.JSONWebKeySet{
|
||||||
Keys: []jose.JSONWebKey{
|
Keys: []jose.JSONWebKey{
|
||||||
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
|
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
|
||||||
},
|
},
|
||||||
|
|
26
pkg/oidc/code_challenge.go
Normal file
26
pkg/oidc/code_challenge.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
|
||||||
|
"github.com/caos/oidc/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CodeChallengeMethodPlain CodeChallengeMethod = "plain"
|
||||||
|
CodeChallengeMethodS256 CodeChallengeMethod = "S256"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CodeChallengeMethod string
|
||||||
|
|
||||||
|
type CodeChallenge struct {
|
||||||
|
Challenge string
|
||||||
|
Method CodeChallengeMethod
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CodeChallenge) Verify(codeVerifier string) bool {
|
||||||
|
if c.Method == CodeChallengeMethodS256 {
|
||||||
|
codeVerifier = utils.HashString(sha256.New(), codeVerifier)
|
||||||
|
}
|
||||||
|
return codeVerifier == c.Challenge
|
||||||
|
}
|
|
@ -1,14 +1,10 @@
|
||||||
package oidc
|
package oidc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/sha512"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/caos/oidc/pkg/utils"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
)
|
)
|
||||||
|
@ -94,25 +90,10 @@ type Tokens struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||||
hash, err := getHashAlgorithm(sigAlgorithm)
|
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
hash.Write([]byte(claim)) // hash documents that Write will never return an error
|
return utils.HashString(hash, claim), nil
|
||||||
sum := hash.Sum(nil)[:hash.Size()/2]
|
|
||||||
return base64.RawURLEncoding.EncodeToString(sum), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
|
|
||||||
switch sigAlgorithm {
|
|
||||||
case jose.RS256, jose.ES256, jose.PS256:
|
|
||||||
return sha256.New(), nil
|
|
||||||
case jose.RS384, jose.ES384, jose.PS384:
|
|
||||||
return sha512.New384(), nil
|
|
||||||
case jose.RS512, jose.ES512, jose.PS512:
|
|
||||||
return sha512.New(), nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
|
GetID() string
|
||||||
RedirectURIs() []string
|
RedirectURIs() []string
|
||||||
ApplicationType() ApplicationType
|
ApplicationType() ApplicationType
|
||||||
LoginURL(string) string
|
LoginURL(string) string
|
||||||
|
|
|
@ -47,6 +47,20 @@ func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetID mocks base method
|
||||||
|
func (m *MockClient) GetID() string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetID")
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID indicates an expected call of GetID
|
||||||
|
func (mr *MockClientMockRecorder) GetID() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
|
||||||
|
}
|
||||||
|
|
||||||
// LoginURL mocks base method
|
// LoginURL mocks base method
|
||||||
func (m *MockClient) LoginURL(arg0 string) string {
|
func (m *MockClient) LoginURL(arg0 string) string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -36,18 +36,18 @@ func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthRequestByCode mocks base method
|
// AuthRequestByCode mocks base method
|
||||||
func (m *MockStorage) AuthRequestByCode(arg0 op.Client, arg1, arg2 string) (op.AuthRequest, error) {
|
func (m *MockStorage) AuthRequestByCode(arg0 string) (op.AuthRequest, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1, arg2)
|
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0)
|
||||||
ret0, _ := ret[0].(op.AuthRequest)
|
ret0, _ := ret[0].(op.AuthRequest)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthRequestByCode indicates an expected call of AuthRequestByCode
|
// AuthRequestByCode indicates an expected call of AuthRequestByCode
|
||||||
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call {
|
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1, arg2)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthRequestByID mocks base method
|
// AuthRequestByID mocks base method
|
||||||
|
@ -140,10 +140,10 @@ func (mr *MockStorageMockRecorder) GetClientByClientID(arg0 interface{}) *gomock
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKeySet mocks base method
|
// GetKeySet mocks base method
|
||||||
func (m *MockStorage) GetKeySet() (go_jose_v2.JSONWebKeySet, error) {
|
func (m *MockStorage) GetKeySet() (*go_jose_v2.JSONWebKeySet, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetKeySet")
|
ret := m.ctrl.Call(m, "GetKeySet")
|
||||||
ret0, _ := ret[0].(go_jose_v2.JSONWebKeySet)
|
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,11 +11,11 @@ import (
|
||||||
type AuthStorage interface {
|
type AuthStorage interface {
|
||||||
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
|
CreateAuthRequest(*oidc.AuthRequest) (AuthRequest, error)
|
||||||
AuthRequestByID(string) (AuthRequest, error)
|
AuthRequestByID(string) (AuthRequest, error)
|
||||||
AuthRequestByCode(Client, string, string) (AuthRequest, error)
|
AuthRequestByCode(string) (AuthRequest, error)
|
||||||
DeleteAuthRequestAndCode(string, string) error
|
DeleteAuthRequestAndCode(string, string) error
|
||||||
|
|
||||||
GetSigningKey() (*jose.SigningKey, error)
|
GetSigningKey() (*jose.SigningKey, error)
|
||||||
GetKeySet() (jose.JSONWebKeySet, error)
|
GetKeySet() (*jose.JSONWebKeySet, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OPStorage interface {
|
type OPStorage interface {
|
||||||
|
@ -38,6 +38,7 @@ type AuthRequest interface {
|
||||||
GetAuthTime() time.Time
|
GetAuthTime() time.Time
|
||||||
GetClientID() string
|
GetClientID() string
|
||||||
GetCode() string
|
GetCode() string
|
||||||
|
GetCodeChallenge() *oidc.CodeChallenge
|
||||||
GetNonce() string
|
GetNonce() string
|
||||||
GetRedirectURI() string
|
GetRedirectURI() string
|
||||||
GetResponseType() oidc.ResponseType
|
GetResponseType() oidc.ResponseType
|
||||||
|
|
|
@ -39,12 +39,17 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := AuthorizeClient(r, tokenReq, exchanger)
|
authReq, err := exchanger.Storage().AuthRequestByCode(tokenReq.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authReq, err := exchanger.Storage().AuthRequestByCode(client, tokenReq.Code, tokenReq.RedirectURI)
|
client, err := AuthorizeClient(r, tokenReq, authReq, exchanger)
|
||||||
|
if err != nil {
|
||||||
|
ExchangeRequestError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = ValidateAccessTokenRequest(tokenReq, client, authReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ExchangeRequestError(w, r, err)
|
ExchangeRequestError(w, r, err)
|
||||||
return
|
return
|
||||||
|
@ -74,7 +79,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||||
utils.MarshalJSON(w, resp)
|
utils.MarshalJSON(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (Client, error) {
|
func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, authReq AuthRequest, exchanger Exchanger) (Client, error) {
|
||||||
if tokenReq.ClientID == "" {
|
if tokenReq.ClientID == "" {
|
||||||
if !exchanger.AuthMethodBasicSupported() {
|
if !exchanger.AuthMethodBasicSupported() {
|
||||||
return nil, errors.New("basic not supported")
|
return nil, errors.New("basic not supported")
|
||||||
|
@ -92,11 +97,24 @@ func AuthorizeClient(r *http.Request, tokenReq *oidc.AccessTokenRequest, exchang
|
||||||
return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret)
|
return exchanger.Storage().AuthorizeClientIDSecret(tokenReq.ClientID, tokenReq.ClientSecret)
|
||||||
}
|
}
|
||||||
if tokenReq.CodeVerifier != "" {
|
if tokenReq.CodeVerifier != "" {
|
||||||
return exchanger.Storage().AuthorizeClientIDCodeVerifier(tokenReq.ClientID, tokenReq.CodeVerifier)
|
if !authReq.GetCodeChallenge().Verify(tokenReq.CodeVerifier) {
|
||||||
|
return nil, ErrInvalidRequest("code_challenge invalid")
|
||||||
|
}
|
||||||
|
return exchanger.Storage().GetClientByClientID(tokenReq.ClientID)
|
||||||
}
|
}
|
||||||
return nil, errors.New("Unimplemented") //TODO: impl
|
return nil, errors.New("Unimplemented") //TODO: impl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ValidateAccessTokenRequest(tokenReq *oidc.AccessTokenRequest, client Client, authReq AuthRequest) error {
|
||||||
|
if client.GetID() != authReq.GetClientID() {
|
||||||
|
return ErrInvalidRequest("invalid auth code")
|
||||||
|
}
|
||||||
|
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
|
||||||
|
return ErrInvalidRequest("redirect_uri does no correspond")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
|
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
|
||||||
return nil, errors.New("Unimplemented") //TODO: impl
|
return nil, errors.New("Unimplemented") //TODO: impl
|
||||||
}
|
}
|
||||||
|
|
30
pkg/utils/hash.go
Normal file
30
pkg/utils/hash.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/sha512"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
|
||||||
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
|
||||||
|
switch sigAlgorithm {
|
||||||
|
case jose.RS256, jose.ES256, jose.PS256:
|
||||||
|
return sha256.New(), nil
|
||||||
|
case jose.RS384, jose.ES384, jose.PS384:
|
||||||
|
return sha512.New384(), nil
|
||||||
|
case jose.RS512, jose.ES512, jose.PS512:
|
||||||
|
return sha512.New(), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HashString(hash hash.Hash, s string) string {
|
||||||
|
hash.Write([]byte(s)) // hash documents that Write will never return an error
|
||||||
|
sum := hash.Sum(nil)[:hash.Size()/2]
|
||||||
|
return base64.RawURLEncoding.EncodeToString(sum)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue