diff --git a/example/client/app/app.go b/example/client/app/app.go index d835959..6db1597 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -12,13 +12,13 @@ import ( "github.com/sirupsen/logrus" "github.com/caos/oidc/pkg/client/rp" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) var ( - callbackPath string = "/auth/callback" - key []byte = []byte("test1234test1234") + callbackPath = "/auth/callback" + key = []byte("test1234test1234") ) func main() { @@ -30,7 +30,7 @@ func main() { scopes := strings.Split(os.Getenv("SCOPES"), " ") redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) - cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) + cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) options := []rp.Option{ rp.WithCookieHandler(cookieHandler), diff --git a/example/client/github/github.go b/example/client/github/github.go index 35c7723..45f16c1 100644 --- a/example/client/github/github.go +++ b/example/client/github/github.go @@ -12,12 +12,12 @@ import ( "github.com/caos/oidc/pkg/client/rp" "github.com/caos/oidc/pkg/client/rp/cli" - "github.com/caos/oidc/pkg/utils" + "github.com/caos/oidc/pkg/http" ) var ( - callbackPath string = "/orbctl/github/callback" - key []byte = []byte("test1234test1234") + callbackPath = "/orbctl/github/callback" + key = []byte("test1234test1234") ) func main() { @@ -34,7 +34,7 @@ func main() { } ctx := context.Background() - cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) + cookieHandler := http.NewCookieHandler(key, key, http.WithUnsecure()) relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, rp.WithCookieHandler(cookieHandler)) if err != nil { fmt.Printf("error creating relaying party: %v", err) diff --git a/example/client/service/service.go b/example/client/service/service.go index 34d959d..818b481 100644 --- a/example/client/service/service.go +++ b/example/client/service/service.go @@ -17,7 +17,7 @@ import ( ) var ( - client *http.Client = http.DefaultClient + client = http.DefaultClient ) func main() { diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go index 775f757..570e8a5 100644 --- a/example/internal/mock/storage.go +++ b/example/internal/mock/storage.go @@ -32,6 +32,7 @@ func NewAuthStorage() op.Storage { type AuthRequest struct { ID string ResponseType oidc.ResponseType + ResponseMode oidc.ResponseMode RedirectURI string Nonce string ClientID string @@ -88,6 +89,10 @@ func (a *AuthRequest) GetResponseType() oidc.ResponseType { return a.ResponseType } +func (a *AuthRequest) GetResponseMode() oidc.ResponseMode { + return a.ResponseMode +} + func (a *AuthRequest) GetScopes() []string { return []string{ "openid", @@ -170,6 +175,11 @@ func (s *AuthStorage) TokenRequestByRefreshToken(ctx context.Context, refreshTok func (s *AuthStorage) TerminateSession(_ context.Context, userID, clientID string) error { return nil } + +func (s *AuthStorage) RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error { + return nil +} + func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey) { keyCh <- jose.SigningKey{Algorithm: jose.RS256, Key: s.key} } @@ -289,7 +299,7 @@ func (c *ConfClient) AuthMethod() oidc.AuthMethod { } func (c *ConfClient) IDTokenLifetime() time.Duration { - return time.Duration(5 * time.Minute) + return 5 * time.Minute } func (c *ConfClient) AccessTokenType() op.AccessTokenType { return c.accessTokenType diff --git a/pkg/client/client.go b/pkg/client/client.go index fa64b70..1828d1d 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -10,12 +10,13 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" + "github.com/caos/oidc/pkg/crypto" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) var ( - Encoder = func() utils.Encoder { + Encoder = func() httphelper.Encoder { e := schema.NewEncoder() e.RegisterEncoder(oidc.SpaceDelimitedArray{}, func(value reflect.Value) string { return value.Interface().(oidc.SpaceDelimitedArray).Encode() @@ -32,7 +33,7 @@ func Discover(issuer string, httpClient *http.Client) (*oidc.DiscoveryConfigurat return nil, err } discoveryConfig := new(oidc.DiscoveryConfiguration) - err = utils.HttpRequest(httpClient, req, &discoveryConfig) + err = httphelper.HttpRequest(httpClient, req, &discoveryConfig) if err != nil { return nil, err } @@ -52,12 +53,12 @@ func CallTokenEndpoint(request interface{}, caller tokenEndpointCaller) (newToke } func callTokenEndpoint(request interface{}, authFn interface{}, caller tokenEndpointCaller) (newToken *oauth2.Token, err error) { - req, err := utils.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) + req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } tokenRes := new(oidc.AccessTokenResponse) - if err := utils.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil { + if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil { return nil, err } return &oauth2.Token{ @@ -69,7 +70,7 @@ func callTokenEndpoint(request interface{}, authFn interface{}, caller tokenEndp } func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) { - privateKey, err := utils.BytesToPrivateKey(key) + privateKey, err := crypto.BytesToPrivateKey(key) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) func SignedJWTProfileAssertion(clientID string, audience []string, expiration time.Duration, signer jose.Signer) (string, error) { iat := time.Now() exp := iat.Add(expiration) - return utils.Sign(&oidc.JWTTokenRequest{ + return crypto.Sign(&oidc.JWTTokenRequest{ Issuer: clientID, Subject: clientID, Audience: audience, diff --git a/pkg/client/jwt_profile.go b/pkg/client/jwt_profile.go index 8095588..e120541 100644 --- a/pkg/client/jwt_profile.go +++ b/pkg/client/jwt_profile.go @@ -1,17 +1,16 @@ package client import ( - "context" "net/url" "golang.org/x/oauth2" + "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) //JWTProfileExchange handles the oauth2 jwt profile exchange -func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller tokenEndpointCaller) (*oauth2.Token, error) { +func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller tokenEndpointCaller) (*oauth2.Token, error) { return CallTokenEndpoint(jwtProfileGrantRequest, caller) } @@ -22,7 +21,7 @@ func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption { } } -func ClientAssertionFormAuthorization(assertion string) utils.FormAuthorization { +func ClientAssertionFormAuthorization(assertion string) http.FormAuthorization { return func(values url.Values) { values.Set("client_assertion", assertion) values.Set("client_assertion_type", oidc.ClientAssertionTypeJWTAssertion) diff --git a/pkg/client/profile/jwt_profile.go b/pkg/client/profile/jwt_profile.go index 46a0fe9..6b7db2c 100644 --- a/pkg/client/profile/jwt_profile.go +++ b/pkg/client/profile/jwt_profile.go @@ -89,5 +89,5 @@ func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) { if err != nil { return nil, err } - return client.JWTProfileExchange(nil, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) + return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) } diff --git a/pkg/utils/browser.go b/pkg/client/rp/cli/browser.go similarity index 96% rename from pkg/utils/browser.go rename to pkg/client/rp/cli/browser.go index dca75e4..1948427 100644 --- a/pkg/utils/browser.go +++ b/pkg/client/rp/cli/browser.go @@ -1,4 +1,4 @@ -package utils +package cli import ( "fmt" diff --git a/pkg/client/rp/cli/cli.go b/pkg/client/rp/cli/cli.go index 89566eb..aba1546 100644 --- a/pkg/client/rp/cli/cli.go +++ b/pkg/client/rp/cli/cli.go @@ -5,8 +5,8 @@ import ( "net/http" "github.com/caos/oidc/pkg/client/rp" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) const ( @@ -28,9 +28,9 @@ func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, p http.Handle(loginPath, rp.AuthURLHandler(stateProvider, relyingParty)) http.Handle(callbackPath, rp.CodeExchangeHandler(callback, relyingParty)) - utils.StartServer(codeflowCtx, ":"+port) + httphelper.StartServer(codeflowCtx, ":"+port) - utils.OpenBrowser("http://localhost:" + port + loginPath) + OpenBrowser("http://localhost:" + port + loginPath) return <-tokenChan } diff --git a/pkg/client/rp/delegation.go b/pkg/client/rp/delegation.go index 3ae6bb6..73edd96 100644 --- a/pkg/client/rp/delegation.go +++ b/pkg/client/rp/delegation.go @@ -5,8 +5,8 @@ import ( ) //DelegationTokenRequest is an implementation of TokenExchangeRequest -//it exchanges a "urn:ietf:params:oauth:token-type:access_token" with an optional -//"urn:ietf:params:oauth:token-type:access_token" actor token for a +//it exchanges an "urn:ietf:params:oauth:token-type:access_token" with an optional +//"urn:ietf:params:oauth:token-type:access_token" actor token for an //"urn:ietf:params:oauth:token-type:access_token" delegation token func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest { return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...) diff --git a/pkg/client/rp/jwks.go b/pkg/client/rp/jwks.go index 4062ab4..78f9580 100644 --- a/pkg/client/rp/jwks.go +++ b/pkg/client/rp/jwks.go @@ -7,9 +7,9 @@ import ( "net/http" "sync" - "github.com/caos/oidc/pkg/utils" "gopkg.in/square/go-jose.v2" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" ) @@ -207,7 +207,7 @@ func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, } keySet := new(jsonWebKeySet) - if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil { + if err = httphelper.HttpRequest(r.httpClient, req, keySet); err != nil { return nil, fmt.Errorf("oidc: failed to get keys: %v", err) } return keySet.Keys, nil diff --git a/pkg/client/rp/relaying_party.go b/pkg/client/rp/relaying_party.go index b9b568d..23c37fc 100644 --- a/pkg/client/rp/relaying_party.go +++ b/pkg/client/rp/relaying_party.go @@ -13,8 +13,8 @@ import ( "gopkg.in/square/go-jose.v2" "github.com/caos/oidc/pkg/client" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) const ( @@ -39,7 +39,7 @@ type RelyingParty interface { IsPKCE() bool //CookieHandler returns a http cookie handler used for various state transfer cookies - CookieHandler() *utils.CookieHandler + CookieHandler() *httphelper.CookieHandler //HttpClient returns a http client used for calls to the openid provider, e.g. calling token endpoint HttpClient() *http.Client @@ -76,7 +76,7 @@ type relyingParty struct { pkce bool httpClient *http.Client - cookieHandler *utils.CookieHandler + cookieHandler *httphelper.CookieHandler errorHandler func(http.ResponseWriter, *http.Request, string, string, string) idTokenVerifier IDTokenVerifier @@ -96,7 +96,7 @@ func (rp *relyingParty) IsPKCE() bool { return rp.pkce } -func (rp *relyingParty) CookieHandler() *utils.CookieHandler { +func (rp *relyingParty) CookieHandler() *httphelper.CookieHandler { return rp.cookieHandler } @@ -136,7 +136,7 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingParty, error) { rp := &relyingParty{ oauthConfig: config, - httpClient: utils.DefaultHTTPClient, + httpClient: httphelper.DefaultHTTPClient, oauth2Only: true, } @@ -161,7 +161,7 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco RedirectURL: redirectURI, Scopes: scopes, }, - httpClient: utils.DefaultHTTPClient, + httpClient: httphelper.DefaultHTTPClient, oauth2Only: false, } @@ -181,11 +181,11 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco return rp, nil } -//Option is the type for providing dynamic options to the DefaultRP +//Option is the type for providing dynamic options to the relyingParty type Option func(*relyingParty) error //WithCookieHandler set a `CookieHandler` for securing the various redirects -func WithCookieHandler(cookieHandler *utils.CookieHandler) Option { +func WithCookieHandler(cookieHandler *httphelper.CookieHandler) Option { return func(rp *relyingParty) error { rp.cookieHandler = cookieHandler return nil @@ -195,7 +195,7 @@ func WithCookieHandler(cookieHandler *utils.CookieHandler) Option { //WithPKCE sets the RP to use PKCE (oauth2 code challenge) //it also sets a `CookieHandler` for securing the various redirects //and exchanging the code challenge -func WithPKCE(cookieHandler *utils.CookieHandler) Option { +func WithPKCE(cookieHandler *httphelper.CookieHandler) Option { return func(rp *relyingParty) error { rp.pkce = true rp.cookieHandler = cookieHandler @@ -246,7 +246,7 @@ func Discover(issuer string, httpClient *http.Client) (Endpoints, error) { return Endpoints{}, err } discoveryConfig := new(oidc.DiscoveryConfiguration) - err = utils.HttpRequest(httpClient, req, &discoveryConfig) + err = httphelper.HttpRequest(httpClient, req, &discoveryConfig) if err != nil { return Endpoints{}, err } @@ -395,7 +395,7 @@ func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, } req.Header.Set("authorization", tokenType+" "+token) userinfo := oidc.NewUserInfo() - if err := utils.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil { + if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil { return nil, err } if userinfo.GetSubject() != subject { diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go index 551fe88..224442f 100644 --- a/pkg/client/rs/resource_server.go +++ b/pkg/client/rs/resource_server.go @@ -7,8 +7,8 @@ import ( "time" "github.com/caos/oidc/pkg/client" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) type ResourceServer interface { @@ -39,7 +39,7 @@ func (r *resourceServer) AuthFn() (interface{}, error) { func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { authorizer := func() (interface{}, error) { - return utils.AuthorizeBasic(clientID, clientSecret), nil + return httphelper.AuthorizeBasic(clientID, clientSecret), nil } return newResourceServer(issuer, authorizer, option...) } @@ -61,7 +61,7 @@ func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, opt func newResourceServer(issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) { rs := &resourceServer{ issuer: issuer, - httpClient: utils.DefaultHTTPClient, + httpClient: httphelper.DefaultHTTPClient, } for _, optFunc := range options { optFunc(rs) @@ -111,12 +111,12 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.Intr if err != nil { return nil, err } - req, err := utils.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) + req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) if err != nil { return nil, err } resp := oidc.NewIntrospectionResponse() - if err := utils.HttpRequest(rp.HttpClient(), req, resp); err != nil { + if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil { return nil, err } return resp, nil diff --git a/pkg/utils/crypto.go b/pkg/crypto/crypto.go similarity index 91% rename from pkg/utils/crypto.go rename to pkg/crypto/crypto.go index 3ca4963..488d8a4 100644 --- a/pkg/utils/crypto.go +++ b/pkg/crypto/crypto.go @@ -1,4 +1,4 @@ -package utils +package crypto import ( "crypto/aes" @@ -9,6 +9,10 @@ import ( "io" ) +var ( + ErrCipherTextBlockSize = errors.New("ciphertext block size is too short") +) + func EncryptAES(data string, key string) (string, error) { encrypted, err := EncryptBytesAES([]byte(data), key) if err != nil { @@ -55,8 +59,7 @@ func DecryptBytesAES(cipherText []byte, key string) ([]byte, error) { } if len(cipherText) < aes.BlockSize { - err = errors.New("Ciphertext block size is too short!") - return nil, err + return nil, ErrCipherTextBlockSize } iv := cipherText[:aes.BlockSize] cipherText = cipherText[aes.BlockSize:] diff --git a/pkg/utils/hash.go b/pkg/crypto/hash.go similarity index 80% rename from pkg/utils/hash.go rename to pkg/crypto/hash.go index 5dae03c..6529249 100644 --- a/pkg/utils/hash.go +++ b/pkg/crypto/hash.go @@ -1,15 +1,20 @@ -package utils +package crypto import ( "crypto/sha256" "crypto/sha512" "encoding/base64" + "errors" "fmt" "hash" "gopkg.in/square/go-jose.v2" ) +var ( + ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm") +) + func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { switch sigAlgorithm { case jose.RS256, jose.ES256, jose.PS256: @@ -19,7 +24,7 @@ func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { case jose.RS512, jose.ES512, jose.PS512: return sha512.New(), nil default: - return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm) + return nil, fmt.Errorf("%w: %q", ErrUnsupportedAlgorithm, sigAlgorithm) } } diff --git a/pkg/utils/key.go b/pkg/crypto/key.go similarity index 61% rename from pkg/utils/key.go rename to pkg/crypto/key.go index 7965c85..d75d1ab 100644 --- a/pkg/utils/key.go +++ b/pkg/crypto/key.go @@ -1,4 +1,4 @@ -package utils +package crypto import ( "crypto/rsa" @@ -8,15 +8,7 @@ import ( func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) { block, _ := pem.Decode(priv) - enc := x509.IsEncryptedPEMBlock(block) b := block.Bytes - var err error - if enc { - b, err = x509.DecryptPEMBlock(block, nil) - if err != nil { - return nil, err - } - } key, err := x509.ParsePKCS1PrivateKey(b) if err != nil { return nil, err diff --git a/pkg/utils/sign.go b/pkg/crypto/sign.go similarity index 97% rename from pkg/utils/sign.go rename to pkg/crypto/sign.go index 5ebac43..a0b9cae 100644 --- a/pkg/utils/sign.go +++ b/pkg/crypto/sign.go @@ -1,4 +1,4 @@ -package utils +package crypto import ( "encoding/json" diff --git a/pkg/utils/cookie.go b/pkg/http/cookie.go similarity index 99% rename from pkg/utils/cookie.go rename to pkg/http/cookie.go index 9e73e08..62ea295 100644 --- a/pkg/utils/cookie.go +++ b/pkg/http/cookie.go @@ -1,4 +1,4 @@ -package utils +package http import ( "errors" diff --git a/pkg/utils/http.go b/pkg/http/http.go similarity index 97% rename from pkg/utils/http.go rename to pkg/http/http.go index 27f96f9..2512707 100644 --- a/pkg/utils/http.go +++ b/pkg/http/http.go @@ -1,4 +1,4 @@ -package utils +package http import ( "context" @@ -14,7 +14,7 @@ import ( var ( DefaultHTTPClient = &http.Client{ - Timeout: time.Duration(30 * time.Second), + Timeout: 30 * time.Second, } ) diff --git a/pkg/utils/marshal.go b/pkg/http/marshal.go similarity index 70% rename from pkg/utils/marshal.go rename to pkg/http/marshal.go index 8c04588..794a28a 100644 --- a/pkg/utils/marshal.go +++ b/pkg/http/marshal.go @@ -1,24 +1,26 @@ -package utils +package http import ( "bytes" "encoding/json" "fmt" "net/http" - - "github.com/sirupsen/logrus" + "reflect" ) func MarshalJSON(w http.ResponseWriter, i interface{}) { - b, err := json.Marshal(i) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + MarshalJSONWithStatus(w, i, http.StatusOK) +} + +func MarshalJSONWithStatus(w http.ResponseWriter, i interface{}, status int) { + w.Header().Set("content-type", "application/json") + w.WriteHeader(status) + if i == nil || (reflect.ValueOf(i).Kind() == reflect.Ptr && reflect.ValueOf(i).IsNil()) { return } - w.Header().Set("content-type", "application/json") - _, err = w.Write(b) + err := json.NewEncoder(w).Encode(i) if err != nil { - logrus.Error("error writing response") + http.Error(w, err.Error(), http.StatusInternalServerError) } } diff --git a/pkg/utils/marshal_test.go b/pkg/http/marshal_test.go similarity index 59% rename from pkg/utils/marshal_test.go rename to pkg/http/marshal_test.go index bfc8275..3838a44 100644 --- a/pkg/utils/marshal_test.go +++ b/pkg/http/marshal_test.go @@ -1,8 +1,11 @@ -package utils +package http import ( "bytes" + "net/http/httptest" "testing" + + "github.com/stretchr/testify/assert" ) func TestConcatenateJSON(t *testing.T) { @@ -88,3 +91,66 @@ func TestConcatenateJSON(t *testing.T) { }) } } + +func TestMarshalJSONWithStatus(t *testing.T) { + type args struct { + i interface{} + status int + } + type res struct { + statusCode int + body string + } + tests := []struct { + name string + args args + res res + }{ + { + "empty ok", + args{ + nil, + 200, + }, + res{ + 200, + "", + }, + }, + { + "string ok", + args{ + "ok", + 200, + }, + res{ + 200, + `"ok" +`, + }, + }, + { + "struct ok", + args{ + struct { + Test string `json:"test"` + }{"ok"}, + 200, + }, + res{ + 200, + `{"test":"ok"} +`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + MarshalJSONWithStatus(w, tt.args.i, tt.args.status) + assert.Equal(t, tt.res.statusCode, w.Result().StatusCode) + assert.Equal(t, "application/json", w.Header().Get("content-type")) + assert.Equal(t, tt.res.body, w.Body.String()) + }) + } +} diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index 79d0c1e..e6cfe58 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -42,6 +42,9 @@ const ( DisplayTouch Display = "touch" DisplayWAP Display = "wap" + ResponseModeQuery ResponseMode = "query" + ResponseModeFragment ResponseMode = "fragment" + //PromptNone (`none`) disallows the Authorization Server to display any authentication or consent user interface pages. //An error (login_required, interaction_required, ...) will be returned if the user is not already authenticated or consent is needed PromptNone = "none" @@ -59,27 +62,28 @@ const ( //AuthRequest according to: //https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest type AuthRequest struct { - ID string - Scopes SpaceDelimitedArray `schema:"scope"` - ResponseType ResponseType `schema:"response_type"` - ClientID string `schema:"client_id"` - RedirectURI string `schema:"redirect_uri"` //TODO: type + Scopes SpaceDelimitedArray `json:"scope" schema:"scope"` + ResponseType ResponseType `json:"response_type" schema:"response_type"` + ClientID string `json:"client_id" schema:"client_id"` + RedirectURI string `json:"redirect_uri" schema:"redirect_uri"` - State string `schema:"state"` + State string `json:"state" schema:"state"` + Nonce string `json:"nonce" schema:"nonce"` - // ResponseMode TODO: ? + ResponseMode ResponseMode `json:"response_mode" schema:"response_mode"` + Display Display `json:"display" schema:"display"` + Prompt SpaceDelimitedArray `json:"prompt" schema:"prompt"` + MaxAge *uint `json:"max_age" schema:"max_age"` + UILocales Locales `json:"ui_locales" schema:"ui_locales"` + IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"` + LoginHint string `json:"login_hint" schema:"login_hint"` + ACRValues []string `json:"acr_values" schema:"acr_values"` - Nonce string `schema:"nonce"` - Display Display `schema:"display"` - Prompt SpaceDelimitedArray `schema:"prompt"` - MaxAge *uint `schema:"max_age"` - UILocales Locales `schema:"ui_locales"` - IDTokenHint string `schema:"id_token_hint"` - LoginHint string `schema:"login_hint"` - ACRValues []string `schema:"acr_values"` + CodeChallenge string `json:"code_challenge" schema:"code_challenge"` + CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"` - CodeChallenge string `schema:"code_challenge"` - CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"` + //RequestParam enables OIDC requests to be passed in a single, self-contained parameter (as JWT, called Request Object) + RequestParam string `schema:"request"` } //GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface diff --git a/pkg/oidc/code_challenge.go b/pkg/oidc/code_challenge.go index 9c4c8a3..4e82feb 100644 --- a/pkg/oidc/code_challenge.go +++ b/pkg/oidc/code_challenge.go @@ -3,7 +3,7 @@ package oidc import ( "crypto/sha256" - "github.com/caos/oidc/pkg/utils" + "github.com/caos/oidc/pkg/crypto" ) const ( @@ -19,7 +19,7 @@ type CodeChallenge struct { } func NewSHACodeChallenge(code string) string { - return utils.HashString(sha256.New(), code, false) + return crypto.HashString(sha256.New(), code, false) } func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool { diff --git a/pkg/oidc/discovery.go b/pkg/oidc/discovery.go index acab578..1a92f8d 100644 --- a/pkg/oidc/discovery.go +++ b/pkg/oidc/discovery.go @@ -9,49 +9,144 @@ const ( ) type DiscoveryConfiguration struct { - Issuer string `json:"issuer,omitempty"` - AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"` - TokenEndpoint string `json:"token_endpoint,omitempty"` - IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` - UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` - RevocationEndpoint string `json:"revocation_endpoint,omitempty"` - EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` - CheckSessionIframe string `json:"check_session_iframe,omitempty"` - JwksURI string `json:"jwks_uri,omitempty"` - ScopesSupported []string `json:"scopes_supported,omitempty"` - ResponseTypesSupported []string `json:"response_types_supported,omitempty"` - ResponseModesSupported []string `json:"response_modes_supported,omitempty"` - GrantTypesSupported []GrantType `json:"grant_types_supported,omitempty"` - ACRValuesSupported []string `json:"acr_values_supported,omitempty"` - SubjectTypesSupported []string `json:"subject_types_supported,omitempty"` - IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` - IDTokenEncryptionAlgValuesSupported []string `json:"id_token_encryption_alg_values_supported,omitempty"` - IDTokenEncryptionEncValuesSupported []string `json:"id_token_encryption_enc_values_supported,omitempty"` - UserinfoSigningAlgValuesSupported []string `json:"userinfo_signing_alg_values_supported,omitempty"` - UserinfoEncryptionAlgValuesSupported []string `json:"userinfo_encryption_alg_values_supported,omitempty"` - UserinfoEncryptionEncValuesSupported []string `json:"userinfo_encryption_enc_values_supported,omitempty"` - RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported,omitempty"` - RequestObjectEncryptionAlgValuesSupported []string `json:"request_object_encryption_alg_values_supported,omitempty"` - RequestObjectEncryptionEncValuesSupported []string `json:"request_object_encryption_enc_values_supported,omitempty"` - TokenEndpointAuthMethodsSupported []AuthMethod `json:"token_endpoint_auth_methods_supported,omitempty"` - TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` - RevocationEndpointAuthMethodsSupported []AuthMethod `json:"revocation_endpoint_auth_methods_supported,omitempty"` - RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"` - IntrospectionEndpointAuthMethodsSupported []AuthMethod `json:"introspection_endpoint_auth_methods_supported,omitempty"` - IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"` - DisplayValuesSupported []Display `json:"display_values_supported,omitempty"` - ClaimTypesSupported []string `json:"claim_types_supported,omitempty"` - ClaimsSupported []string `json:"claims_supported,omitempty"` - ClaimsParameterSupported bool `json:"claims_parameter_supported,omitempty"` - CodeChallengeMethodsSupported []CodeChallengeMethod `json:"code_challenge_methods_supported,omitempty"` - ServiceDocumentation string `json:"service_documentation,omitempty"` - ClaimsLocalesSupported []language.Tag `json:"claims_locales_supported,omitempty"` - UILocalesSupported []language.Tag `json:"ui_locales_supported,omitempty"` - RequestParameterSupported bool `json:"request_parameter_supported,omitempty"` - RequestURIParameterSupported bool `json:"request_uri_parameter_supported"` //no omitempty because: If omitted, the default value is true - RequireRequestURIRegistration bool `json:"require_request_uri_registration,omitempty"` - OPPolicyURI string `json:"op_policy_uri,omitempty"` - OPTermsOfServiceURI string `json:"op_tos_uri,omitempty"` + //Issuer is the identifier of the OP and is used in the tokens as `iss` claim. + Issuer string `json:"issuer,omitempty"` + + //AuthorizationEndpoint is the URL of the OAuth 2.0 Authorization Endpoint where all user interactive login start + AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"` + + //TokenEndpoint is the URL of the OAuth 2.0 Token Endpoint where all tokens are issued, except when using Implicit Flow + TokenEndpoint string `json:"token_endpoint,omitempty"` + + //IntrospectionEndpoint is the URL of the OAuth 2.0 Introspection Endpoint. + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + + //UserinfoEndpoint is the URL where an access_token can be used to retrieve the Userinfo. + UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` + + //RevocationEndpoint is the URL of the OAuth 2.0 Revocation Endpoint. + RevocationEndpoint string `json:"revocation_endpoint,omitempty"` + + //EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP. + EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` + + //CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client. + CheckSessionIframe string `json:"check_session_iframe,omitempty"` + + //JwksURI is the URL of the JSON Web Key Set. This site contains the signing keys that RPs can use to validate the signature. + //It may also contain the OP's encryption keys that RPs can use to encrypt request to the OP. + JwksURI string `json:"jwks_uri,omitempty"` + + //RegistrationEndpoint is the URL for the Dynamic Client Registration. + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + + //ScopesSupported lists an array of supported scopes. This list must not include every supported scope by the OP. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + //ResponseTypesSupported contains a list of the OAuth 2.0 response_type values that the OP supports (code, id_token, token id_token, ...). + ResponseTypesSupported []string `json:"response_types_supported,omitempty"` + + //ResponseModesSupported contains a list of the OAuth 2.0 response_mode values that the OP supports. If omitted, the default value is ["query", "fragment"]. + ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + + //GrantTypesSupported contains a list of the OAuth 2.0 grant_type values that the OP supports. If omitted, the default value is ["authorization_code", "implicit"]. + GrantTypesSupported []GrantType `json:"grant_types_supported,omitempty"` + + //ACRValuesSupported contains a list of Authentication Context Class References that the OP supports. + ACRValuesSupported []string `json:"acr_values_supported,omitempty"` + + //SubjectTypesSupported contains a list of Subject Identifier types that the OP supports (pairwise, public). + SubjectTypesSupported []string `json:"subject_types_supported,omitempty"` + + //IDTokenSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the OP for the ID Token. + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` + + //IDTokenEncryptionAlgValuesSupported contains a list of JWE encryption algorithms (alg values) supported by the OP for the ID Token. + IDTokenEncryptionAlgValuesSupported []string `json:"id_token_encryption_alg_values_supported,omitempty"` + + //IDTokenEncryptionEncValuesSupported contains a list of JWE encryption algorithms (enc values) supported by the OP for the ID Token. + IDTokenEncryptionEncValuesSupported []string `json:"id_token_encryption_enc_values_supported,omitempty"` + + //UserinfoSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the OP for UserInfo Endpoint. + UserinfoSigningAlgValuesSupported []string `json:"userinfo_signing_alg_values_supported,omitempty"` + + //UserinfoEncryptionAlgValuesSupported contains a list of JWE encryption algorithms (alg values) supported by the OP for the UserInfo Endpoint. + UserinfoEncryptionAlgValuesSupported []string `json:"userinfo_encryption_alg_values_supported,omitempty"` + + //UserinfoEncryptionEncValuesSupported contains a list of JWE encryption algorithms (enc values) supported by the OP for the UserInfo Endpoint. + UserinfoEncryptionEncValuesSupported []string `json:"userinfo_encryption_enc_values_supported,omitempty"` + + //RequestObjectSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the OP for Request Objects. + //These algorithms are used both then the Request Object is passed by value (using the request parameter) and when it is passed by reference (using the request_uri parameter). + RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported,omitempty"` + + //RequestObjectEncryptionAlgValuesSupported contains a list of JWE encryption algorithms (alg values) supported by the OP for Request Objects. + //These algorithms are used both when the Request Object is passed by value and by reference. + RequestObjectEncryptionAlgValuesSupported []string `json:"request_object_encryption_alg_values_supported,omitempty"` + + //RequestObjectEncryptionEncValuesSupported contains a list of JWE encryption algorithms (enc values) supported by the OP for Request Objects. + //These algorithms are used both when the Request Object is passed by value and by reference. + RequestObjectEncryptionEncValuesSupported []string `json:"request_object_encryption_enc_values_supported,omitempty"` + + //TokenEndpointAuthMethodsSupported contains a list of Client Authentication methods supported by the Token Endpoint. If omitted, the default is client_secret_basic. + TokenEndpointAuthMethodsSupported []AuthMethod `json:"token_endpoint_auth_methods_supported,omitempty"` + + //TokenEndpointAuthSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the Token Endpoint + //for the signature of the JWT used to authenticate the Client by private_key_jwt and client_secret_jwt. + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + + //RevocationEndpointAuthMethodsSupported contains a list of Client Authentication methods supported by the Revocation Endpoint. If omitted, the default is client_secret_basic. + RevocationEndpointAuthMethodsSupported []AuthMethod `json:"revocation_endpoint_auth_methods_supported,omitempty"` + + //RevocationEndpointAuthSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the Revocation Endpoint + //for the signature of the JWT used to authenticate the Client by private_key_jwt and client_secret_jwt. + RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"` + + //IntrospectionEndpointAuthMethodsSupported contains a list of Client Authentication methods supported by the Introspection Endpoint. + IntrospectionEndpointAuthMethodsSupported []AuthMethod `json:"introspection_endpoint_auth_methods_supported,omitempty"` + + //IntrospectionEndpointAuthSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the Revocation Endpoint + //for the signature of the JWT used to authenticate the Client by private_key_jwt and client_secret_jwt. + IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"` + + //DisplayValuesSupported contains a list of display parameter values that the OP supports (page, popup, touch, wap). + DisplayValuesSupported []Display `json:"display_values_supported,omitempty"` + + //ClaimTypesSupported contains a list of Claim Types that the OP supports (normal, aggregated, distributed). If omitted, the default is normal Claims. + ClaimTypesSupported []string `json:"claim_types_supported,omitempty"` + + //ClaimsSupported contains a list of Claim Names the OP may be able to supply values for. This list might not be exhaustive. + ClaimsSupported []string `json:"claims_supported,omitempty"` + + //ClaimsParameterSupported specifies whether the OP supports use of the `claims` parameter. If omitted, the default is false. + ClaimsParameterSupported bool `json:"claims_parameter_supported,omitempty"` + + //CodeChallengeMethodsSupported contains a list of Proof Key for Code Exchange (PKCE) code challenge methods supported by the OP. + CodeChallengeMethodsSupported []CodeChallengeMethod `json:"code_challenge_methods_supported,omitempty"` + + //ServiceDocumentation is a URL where developers can get information about the OP and its usage. + ServiceDocumentation string `json:"service_documentation,omitempty"` + + //ClaimsLocalesSupported contains a list of BCP47 language tag values that the OP supports for values of Claims returned. + ClaimsLocalesSupported []language.Tag `json:"claims_locales_supported,omitempty"` + + //UILocalesSupported contains a list of BCP47 language tag values that the OP supports for the user interface. + UILocalesSupported []language.Tag `json:"ui_locales_supported,omitempty"` + + //RequestParameterSupported specifies whether the OP supports use of the `request` parameter. If omitted, the default value is false. + RequestParameterSupported bool `json:"request_parameter_supported,omitempty"` + + //RequestURIParameterSupported specifies whether the OP supports use of the `request_uri` parameter. If omitted, the default value is true. (therefore no omitempty) + RequestURIParameterSupported bool `json:"request_uri_parameter_supported"` + + //RequireRequestURIRegistration specifies whether the OP requires any `request_uri` to be pre-registered using the request_uris registration parameter. If omitted, the default value is false. + RequireRequestURIRegistration bool `json:"require_request_uri_registration,omitempty"` + + //OPPolicyURI is a URL the OP provides to the person registering the Client to read about the OP's requirements on how the RP can use the data provided by the OP. + OPPolicyURI string `json:"op_policy_uri,omitempty"` + + //OPTermsOfServiceURI is a URL the OpenID Provider provides to the person registering the Client to read about OpenID Provider's terms of service. + OPTermsOfServiceURI string `json:"op_tos_uri,omitempty"` } type AuthMethod string diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go new file mode 100644 index 0000000..5797a59 --- /dev/null +++ b/pkg/oidc/error.go @@ -0,0 +1,139 @@ +package oidc + +import ( + "errors" + "fmt" +) + +type errorType string + +const ( + InvalidRequest errorType = "invalid_request" + InvalidScope errorType = "invalid_scope" + InvalidClient errorType = "invalid_client" + InvalidGrant errorType = "invalid_grant" + UnauthorizedClient errorType = "unauthorized_client" + UnsupportedGrantType errorType = "unsupported_grant_type" + ServerError errorType = "server_error" + InteractionRequired errorType = "interaction_required" + LoginRequired errorType = "login_required" + RequestNotSupported errorType = "request_not_supported" +) + +var ( + ErrInvalidRequest = func() *Error { + return &Error{ + ErrorType: InvalidRequest, + } + } + ErrInvalidRequestRedirectURI = func() *Error { + return &Error{ + ErrorType: InvalidRequest, + redirectDisabled: true, + } + } + ErrInvalidScope = func() *Error { + return &Error{ + ErrorType: InvalidScope, + } + } + ErrInvalidClient = func() *Error { + return &Error{ + ErrorType: InvalidClient, + } + } + ErrInvalidGrant = func() *Error { + return &Error{ + ErrorType: InvalidGrant, + } + } + ErrUnauthorizedClient = func() *Error { + return &Error{ + ErrorType: UnauthorizedClient, + } + } + ErrUnsupportedGrantType = func() *Error { + return &Error{ + ErrorType: UnsupportedGrantType, + } + } + ErrServerError = func() *Error { + return &Error{ + ErrorType: ServerError, + } + } + ErrInteractionRequired = func() *Error { + return &Error{ + ErrorType: InteractionRequired, + } + } + ErrLoginRequired = func() *Error { + return &Error{ + ErrorType: LoginRequired, + } + } + ErrRequestNotSupported = func() *Error { + return &Error{ + ErrorType: RequestNotSupported, + } + } +) + +type Error struct { + Parent error `json:"-" schema:"-"` + ErrorType errorType `json:"error" schema:"error"` + Description string `json:"error_description,omitempty" schema:"error_description,omitempty"` + State string `json:"state,omitempty" schema:"state,omitempty"` + redirectDisabled bool `schema:"-"` +} + +func (e *Error) Error() string { + message := "ErrorType=" + string(e.ErrorType) + if e.Description != "" { + message += " Description=" + e.Description + } + if e.Parent != nil { + message += " Parent=" + e.Parent.Error() + } + return message +} + +func (e *Error) Unwrap() error { + return e.Parent +} + +func (e *Error) Is(target error) bool { + t, ok := target.(*Error) + if !ok { + return false + } + return e.ErrorType == t.ErrorType && + (e.Description == t.Description || t.Description == "") && + (e.State == t.State || t.State == "") +} + +func (e *Error) WithParent(err error) *Error { + e.Parent = err + return e +} + +func (e *Error) WithDescription(desc string, args ...interface{}) *Error { + e.Description = fmt.Sprintf(desc, args...) + return e +} + +func (e *Error) IsRedirectDisabled() bool { + return e.redirectDisabled +} + +// DefaultToServerError checks if the error is an Error +// if not the provided error will be wrapped into a ServerError +func DefaultToServerError(err error, description string) *Error { + oauth := new(Error) + if ok := errors.As(err, &oauth); !ok { + oauth.ErrorType = ServerError + oauth.Description = description + oauth.Parent = err + } + return oauth +} diff --git a/pkg/oidc/introspection.go b/pkg/oidc/introspection.go index 8dd1987..6ac2986 100644 --- a/pkg/oidc/introspection.go +++ b/pkg/oidc/introspection.go @@ -42,181 +42,181 @@ type introspectionResponse struct { claims map[string]interface{} } -func (u *introspectionResponse) IsActive() bool { - return u.Active +func (i *introspectionResponse) IsActive() bool { + return i.Active } -func (u *introspectionResponse) SetScopes(scope []string) { - u.Scope = scope +func (i *introspectionResponse) SetScopes(scope []string) { + i.Scope = scope } -func (u *introspectionResponse) SetClientID(id string) { - u.ClientID = id +func (i *introspectionResponse) SetClientID(id string) { + i.ClientID = id } -func (u *introspectionResponse) GetSubject() string { - return u.Subject +func (i *introspectionResponse) GetSubject() string { + return i.Subject } -func (u *introspectionResponse) GetName() string { - return u.Name +func (i *introspectionResponse) GetName() string { + return i.Name } -func (u *introspectionResponse) GetGivenName() string { - return u.GivenName +func (i *introspectionResponse) GetGivenName() string { + return i.GivenName } -func (u *introspectionResponse) GetFamilyName() string { - return u.FamilyName +func (i *introspectionResponse) GetFamilyName() string { + return i.FamilyName } -func (u *introspectionResponse) GetMiddleName() string { - return u.MiddleName +func (i *introspectionResponse) GetMiddleName() string { + return i.MiddleName } -func (u *introspectionResponse) GetNickname() string { - return u.Nickname +func (i *introspectionResponse) GetNickname() string { + return i.Nickname } -func (u *introspectionResponse) GetProfile() string { - return u.Profile +func (i *introspectionResponse) GetProfile() string { + return i.Profile } -func (u *introspectionResponse) GetPicture() string { - return u.Picture +func (i *introspectionResponse) GetPicture() string { + return i.Picture } -func (u *introspectionResponse) GetWebsite() string { - return u.Website +func (i *introspectionResponse) GetWebsite() string { + return i.Website } -func (u *introspectionResponse) GetGender() Gender { - return u.Gender +func (i *introspectionResponse) GetGender() Gender { + return i.Gender } -func (u *introspectionResponse) GetBirthdate() string { - return u.Birthdate +func (i *introspectionResponse) GetBirthdate() string { + return i.Birthdate } -func (u *introspectionResponse) GetZoneinfo() string { - return u.Zoneinfo +func (i *introspectionResponse) GetZoneinfo() string { + return i.Zoneinfo } -func (u *introspectionResponse) GetLocale() language.Tag { - return u.Locale +func (i *introspectionResponse) GetLocale() language.Tag { + return i.Locale } -func (u *introspectionResponse) GetPreferredUsername() string { - return u.PreferredUsername +func (i *introspectionResponse) GetPreferredUsername() string { + return i.PreferredUsername } -func (u *introspectionResponse) GetEmail() string { - return u.Email +func (i *introspectionResponse) GetEmail() string { + return i.Email } -func (u *introspectionResponse) IsEmailVerified() bool { - return bool(u.EmailVerified) +func (i *introspectionResponse) IsEmailVerified() bool { + return bool(i.EmailVerified) } -func (u *introspectionResponse) GetPhoneNumber() string { - return u.PhoneNumber +func (i *introspectionResponse) GetPhoneNumber() string { + return i.PhoneNumber } -func (u *introspectionResponse) IsPhoneNumberVerified() bool { - return u.PhoneNumberVerified +func (i *introspectionResponse) IsPhoneNumberVerified() bool { + return i.PhoneNumberVerified } -func (u *introspectionResponse) GetAddress() UserInfoAddress { - return u.Address +func (i *introspectionResponse) GetAddress() UserInfoAddress { + return i.Address } -func (u *introspectionResponse) GetClaim(key string) interface{} { - return u.claims[key] +func (i *introspectionResponse) GetClaim(key string) interface{} { + return i.claims[key] } -func (u *introspectionResponse) SetActive(active bool) { - u.Active = active +func (i *introspectionResponse) SetActive(active bool) { + i.Active = active } -func (u *introspectionResponse) SetSubject(sub string) { - u.Subject = sub +func (i *introspectionResponse) SetSubject(sub string) { + i.Subject = sub } -func (u *introspectionResponse) SetName(name string) { - u.Name = name +func (i *introspectionResponse) SetName(name string) { + i.Name = name } -func (u *introspectionResponse) SetGivenName(name string) { - u.GivenName = name +func (i *introspectionResponse) SetGivenName(name string) { + i.GivenName = name } -func (u *introspectionResponse) SetFamilyName(name string) { - u.FamilyName = name +func (i *introspectionResponse) SetFamilyName(name string) { + i.FamilyName = name } -func (u *introspectionResponse) SetMiddleName(name string) { - u.MiddleName = name +func (i *introspectionResponse) SetMiddleName(name string) { + i.MiddleName = name } -func (u *introspectionResponse) SetNickname(name string) { - u.Nickname = name +func (i *introspectionResponse) SetNickname(name string) { + i.Nickname = name } -func (u *introspectionResponse) SetUpdatedAt(date time.Time) { - u.UpdatedAt = Time(date) +func (i *introspectionResponse) SetUpdatedAt(date time.Time) { + i.UpdatedAt = Time(date) } -func (u *introspectionResponse) SetProfile(profile string) { - u.Profile = profile +func (i *introspectionResponse) SetProfile(profile string) { + i.Profile = profile } -func (u *introspectionResponse) SetPicture(picture string) { - u.Picture = picture +func (i *introspectionResponse) SetPicture(picture string) { + i.Picture = picture } -func (u *introspectionResponse) SetWebsite(website string) { - u.Website = website +func (i *introspectionResponse) SetWebsite(website string) { + i.Website = website } -func (u *introspectionResponse) SetGender(gender Gender) { - u.Gender = gender +func (i *introspectionResponse) SetGender(gender Gender) { + i.Gender = gender } -func (u *introspectionResponse) SetBirthdate(birthdate string) { - u.Birthdate = birthdate +func (i *introspectionResponse) SetBirthdate(birthdate string) { + i.Birthdate = birthdate } -func (u *introspectionResponse) SetZoneinfo(zoneInfo string) { - u.Zoneinfo = zoneInfo +func (i *introspectionResponse) SetZoneinfo(zoneInfo string) { + i.Zoneinfo = zoneInfo } -func (u *introspectionResponse) SetLocale(locale language.Tag) { - u.Locale = locale +func (i *introspectionResponse) SetLocale(locale language.Tag) { + i.Locale = locale } -func (u *introspectionResponse) SetPreferredUsername(name string) { - u.PreferredUsername = name +func (i *introspectionResponse) SetPreferredUsername(name string) { + i.PreferredUsername = name } -func (u *introspectionResponse) SetEmail(email string, verified bool) { - u.Email = email - u.EmailVerified = boolString(verified) +func (i *introspectionResponse) SetEmail(email string, verified bool) { + i.Email = email + i.EmailVerified = boolString(verified) } -func (u *introspectionResponse) SetPhone(phone string, verified bool) { - u.PhoneNumber = phone - u.PhoneNumberVerified = verified +func (i *introspectionResponse) SetPhone(phone string, verified bool) { + i.PhoneNumber = phone + i.PhoneNumberVerified = verified } -func (u *introspectionResponse) SetAddress(address UserInfoAddress) { - u.Address = address +func (i *introspectionResponse) SetAddress(address UserInfoAddress) { + i.Address = address } -func (u *introspectionResponse) AppendClaims(key string, value interface{}) { - if u.claims == nil { - u.claims = make(map[string]interface{}) +func (i *introspectionResponse) AppendClaims(key string, value interface{}) { + if i.claims == nil { + i.claims = make(map[string]interface{}) } - u.claims[key] = value + i.claims[key] = value } func (i *introspectionResponse) MarshalJSON() ([]byte, error) { diff --git a/pkg/oidc/revocation.go b/pkg/oidc/revocation.go new file mode 100644 index 0000000..0a56c61 --- /dev/null +++ b/pkg/oidc/revocation.go @@ -0,0 +1,6 @@ +package oidc + +type RevocationRequest struct { + Token string `schema:"token"` + TokenTypeHint string `schema:"token_type_hint"` +} diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index f753120..e34543e 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -1,10 +1,7 @@ package oidc import ( - "crypto/rsa" - "crypto/x509" "encoding/json" - "encoding/pem" "fmt" "io/ioutil" "time" @@ -12,7 +9,8 @@ import ( "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" - "github.com/caos/oidc/pkg/utils" + "github.com/caos/oidc/pkg/crypto" + "github.com/caos/oidc/pkg/http" ) const ( @@ -188,7 +186,7 @@ func (a *accessTokenClaims) MarshalJSON() ([]byte, error) { if err != nil { return nil, err } - return utils.ConcatenateJSON(b, info) + return http.ConcatenateJSON(b, info) } func (a *accessTokenClaims) UnmarshalJSON(data []byte) error { @@ -325,7 +323,7 @@ func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm { return t.signatureAlg } -//SetSignatureAlgorithm implements the IDTokenClaims interface +//SetAccessTokenHash implements the IDTokenClaims interface func (t *idTokenClaims) SetAccessTokenHash(hash string) { t.AccessTokenHash = hash } @@ -375,7 +373,7 @@ func (t *idTokenClaims) MarshalJSON() ([]byte, error) { if err != nil { return nil, err } - return utils.ConcatenateJSON(b, info) + return http.ConcatenateJSON(b, info) } func (t *idTokenClaims) UnmarshalJSON(data []byte) error { @@ -572,12 +570,12 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, } func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { - hash, err := utils.GetHashAlgorithm(sigAlgorithm) + hash, err := crypto.GetHashAlgorithm(sigAlgorithm) if err != nil { return "", err } - return utils.HashString(hash, claim, true), nil + return crypto.HashString(hash, claim, true), nil } func AppendClientIDToAudience(clientID string, audience []string) []string { @@ -590,7 +588,7 @@ func AppendClientIDToAudience(clientID string, audience []string) []string { } func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error) { - privateKey, err := bytesToPrivateKey(assertion.GetPrivateKey()) + privateKey, err := crypto.BytesToPrivateKey(assertion.GetPrivateKey()) if err != nil { return "", err } @@ -613,21 +611,3 @@ func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error } return signedAssertion.CompactSerialize() } - -func bytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) { - block, _ := pem.Decode(priv) - enc := x509.IsEncryptedPEMBlock(block) - b := block.Bytes - var err error - if enc { - b, err = x509.DecryptPEMBlock(block, nil) - if err != nil { - return nil, err - } - } - key, err := x509.ParsePKCS1PrivateKey(b) - if err != nil { - return nil, err - } - return key, nil -} diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 6f9f1af..f260f32 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -12,7 +12,7 @@ const ( //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow GrantTypeCode GrantType = "authorization_code" - //GrantTypeCode defines the grant_type `refresh_token` used for the Token Request in the Refresh Token Flow + //GrantTypeRefreshToken defines the grant_type `refresh_token` used for the Token Request in the Refresh Token Flow GrantTypeRefreshToken GrantType = "refresh_token" //GrantTypeBearer defines the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant @@ -183,7 +183,7 @@ func (j *JWTTokenRequest) GetSubject() string { return j.Subject } -//GetSubject implements the TokenRequest interface +//GetScopes implements the TokenRequest interface func (j *JWTTokenRequest) GetScopes() []string { return j.Scopes } diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index e72d67c..b6a75f4 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -6,6 +6,7 @@ import ( "time" "golang.org/x/text/language" + "gopkg.in/square/go-jose.v2" ) type Audience []string @@ -66,6 +67,8 @@ type Prompt SpaceDelimitedArray type ResponseType string +type ResponseMode string + func (s SpaceDelimitedArray) Encode() string { return strings.Join(s, " ") } @@ -106,3 +109,16 @@ func (t *Time) UnmarshalJSON(data []byte) error { func (t *Time) MarshalJSON() ([]byte, error) { return json.Marshal(time.Time(*t).UTC().Unix()) } + +type RequestObject struct { + Issuer string `json:"iss"` + Audience Audience `json:"aud"` + AuthRequest +} + +func (r *RequestObject) GetIssuer() string { + return r.Issuer +} + +func (r *RequestObject) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) { +} diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go index 2272421..afc2ad0 100644 --- a/pkg/oidc/userinfo.go +++ b/pkg/oidc/userinfo.go @@ -339,20 +339,20 @@ func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, fo } } -func (i *userinfo) MarshalJSON() ([]byte, error) { +func (u *userinfo) MarshalJSON() ([]byte, error) { type Alias userinfo a := &struct { *Alias Locale interface{} `json:"locale,omitempty"` UpdatedAt int64 `json:"updated_at,omitempty"` }{ - Alias: (*Alias)(i), + Alias: (*Alias)(u), } - if !i.Locale.IsRoot() { - a.Locale = i.Locale + if !u.Locale.IsRoot() { + a.Locale = u.Locale } - if !time.Time(i.UpdatedAt).IsZero() { - a.UpdatedAt = time.Time(i.UpdatedAt).Unix() + if !time.Time(u.UpdatedAt).IsZero() { + a.UpdatedAt = time.Time(u.UpdatedAt).Unix() } b, err := json.Marshal(a) @@ -360,34 +360,34 @@ func (i *userinfo) MarshalJSON() ([]byte, error) { return nil, err } - if len(i.claims) == 0 { + if len(u.claims) == 0 { return b, nil } - err = json.Unmarshal(b, &i.claims) + err = json.Unmarshal(b, &u.claims) if err != nil { - return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims) + return nil, fmt.Errorf("jws: invalid map of custom claims %v", u.claims) } - return json.Marshal(i.claims) + return json.Marshal(u.claims) } -func (i *userinfo) UnmarshalJSON(data []byte) error { +func (u *userinfo) UnmarshalJSON(data []byte) error { type Alias userinfo a := &struct { Address *userInfoAddress `json:"address,omitempty"` *Alias UpdatedAt int64 `json:"update_at,omitempty"` }{ - Alias: (*Alias)(i), + Alias: (*Alias)(u), } if err := json.Unmarshal(data, &a); err != nil { return err } - i.Address = a.Address - i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC()) + u.Address = a.Address + u.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC()) - if err := json.Unmarshal(data, &i.claims); err != nil { + if err := json.Unmarshal(data, &u.claims); err != nil { return err } diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index 4284d17..9f5335d 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -12,7 +12,7 @@ import ( "gopkg.in/square/go-jose.v2" - "github.com/caos/oidc/pkg/utils" + str "github.com/caos/oidc/pkg/strings" ) type Claims interface { @@ -25,6 +25,10 @@ type Claims interface { GetAuthenticationContextClassReference() string GetAuthTime() time.Time GetAuthorizedParty() string + ClaimsSignature +} + +type ClaimsSignature interface { SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) } @@ -61,10 +65,10 @@ type Verifier interface { type ACRVerifier func(string) error //DefaultACRVerifier implements `ACRVerifier` returning an error -//if non of the provided values matches the acr claim +//if none of the provided values matches the acr claim func DefaultACRVerifier(possibleValues []string) ACRVerifier { return func(acr string) error { - if !utils.Contains(possibleValues, acr) { + if !str.Contains(possibleValues, acr) { return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr) } return nil @@ -103,7 +107,7 @@ func CheckIssuer(claims Claims, issuer string) error { } func CheckAudience(claims Claims, clientID string) error { - if !utils.Contains(claims.GetAudience(), clientID) { + if !str.Contains(claims.GetAudience(), clientID) { return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID) } @@ -123,7 +127,7 @@ func CheckAuthorizedParty(claims Claims, clientID string) error { return nil } -func CheckSignature(ctx context.Context, token string, payload []byte, claims Claims, supportedSigAlgs []string, set KeySet) error { +func CheckSignature(ctx context.Context, token string, payload []byte, claims ClaimsSignature, supportedSigAlgs []string, set KeySet) error { jws, err := jose.ParseSigned(token) if err != nil { return ErrParse @@ -138,7 +142,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl if len(supportedSigAlgs) == 0 { supportedSigAlgs = []string{"RS256"} } - if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) { + if !str.Contains(supportedSigAlgs, sig.Header.Algorithm) { return fmt.Errorf("%w: id token signed with unsupported algorithm, expected %q got %q", ErrSignatureUnsupportedAlg, supportedSigAlgs, sig.Header.Algorithm) } diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index fce681f..909b8b0 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -2,7 +2,6 @@ package op import ( "context" - "fmt" "net" "net/http" "net/url" @@ -11,8 +10,9 @@ import ( "github.com/gorilla/mux" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" + str "github.com/caos/oidc/pkg/strings" ) type AuthRequest interface { @@ -26,6 +26,7 @@ type AuthRequest interface { GetNonce() string GetRedirectURI() string GetResponseType() oidc.ResponseType + GetResponseMode() oidc.ResponseMode GetScopes() []string GetState() string GetSubject() string @@ -34,16 +35,17 @@ type AuthRequest interface { type Authorizer interface { Storage() Storage - Decoder() utils.Decoder - Encoder() utils.Encoder + Decoder() httphelper.Decoder + Encoder() httphelper.Encoder Signer() Signer IDTokenHintVerifier() IDTokenHintVerifier Crypto() Crypto Issuer() string + RequestObjectSupported() bool } //AuthorizeValidator is an extension of Authorizer interface -//implementing it's own validation mechanism for the auth request +//implementing its own validation mechanism for the auth request type AuthorizeValidator interface { Authorizer ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) @@ -69,6 +71,13 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } + if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { + authReq, err = ParseRequestObject(r.Context(), authReq, authorizer.Storage(), authorizer.Issuer()) + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + } validation := ValidateAuthRequest if validater, ok := authorizer.(AuthorizeValidator); ok { validation = validater.ValidateAuthRequest @@ -78,33 +87,114 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } + if authReq.RequestParam != "" { + AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder()) + return + } req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder()) return } client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID()) if err != nil { - AuthRequestError(w, r, req, err, authorizer.Encoder()) + AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder()) return } RedirectToLogin(req.GetID(), client, w, r) } -//ParseAuthorizeRequest parsed the http request into a oidc.AuthRequest -func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) { +//ParseAuthorizeRequest parsed the http request into an oidc.AuthRequest +func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.AuthRequest, error) { err := r.ParseForm() if err != nil { - return nil, ErrInvalidRequest("cannot parse form") + return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) } authReq := new(oidc.AuthRequest) err = decoder.Decode(authReq, r.Form) if err != nil { - return nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)) + return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse auth request").WithParent(err) } return authReq, nil } +//ParseRequestObject parse the `request` parameter, validates the token including the signature +//and copies the token claims into the auth request +func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) { + requestObject := new(oidc.RequestObject) + payload, err := oidc.ParseToken(authReq.RequestParam, requestObject) + if err != nil { + return nil, err + } + + if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { + return authReq, oidc.ErrInvalidRequest() + } + if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { + return authReq, oidc.ErrInvalidRequest() + } + if requestObject.Issuer != requestObject.ClientID { + return authReq, oidc.ErrInvalidRequest() + } + if !str.Contains(requestObject.Audience, issuer) { + return authReq, oidc.ErrInvalidRequest() + } + keySet := &jwtProfileKeySet{storage, requestObject.Issuer} + if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { + return authReq, err + } + CopyRequestObjectToAuthRequest(authReq, requestObject) + return authReq, nil +} + +//CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request +//and clears the `RequestParam` of the auth request +func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oidc.RequestObject) { + if str.Contains(authReq.Scopes, oidc.ScopeOpenID) && len(requestObject.Scopes) > 0 { + authReq.Scopes = requestObject.Scopes + } + if requestObject.RedirectURI != "" { + authReq.RedirectURI = requestObject.RedirectURI + } + if requestObject.State != "" { + authReq.State = requestObject.State + } + if requestObject.ResponseMode != "" { + authReq.ResponseMode = requestObject.ResponseMode + } + if requestObject.Nonce != "" { + authReq.Nonce = requestObject.Nonce + } + if requestObject.Display != "" { + authReq.Display = requestObject.Display + } + if len(requestObject.Prompt) > 0 { + authReq.Prompt = requestObject.Prompt + } + if requestObject.MaxAge != nil { + authReq.MaxAge = requestObject.MaxAge + } + if len(requestObject.UILocales) > 0 { + authReq.UILocales = requestObject.UILocales + } + if requestObject.IDTokenHint != "" { + authReq.IDTokenHint = requestObject.IDTokenHint + } + if requestObject.LoginHint != "" { + authReq.LoginHint = requestObject.LoginHint + } + if len(requestObject.ACRValues) > 0 { + authReq.ACRValues = requestObject.ACRValues + } + if requestObject.CodeChallenge != "" { + authReq.CodeChallenge = requestObject.CodeChallenge + } + if requestObject.CodeChallengeMethod != "" { + authReq.CodeChallengeMethod = requestObject.CodeChallengeMethod + } + authReq.RequestParam = "" +} + //ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) { authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) @@ -113,7 +203,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage } client, err := storage.GetClientByClientID(ctx, authReq.ClientID) if err != nil { - return "", ErrServerError(err.Error()) + return "", oidc.DefaultToServerError(err, "unable to retrieve client by id") } authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes) if err != nil { @@ -132,7 +222,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) { for _, prompt := range prompts { if prompt == oidc.PromptNone && len(prompts) > 1 { - return nil, ErrInvalidRequest("The prompt parameter `none` must only be used as a single value") + return nil, oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value") } if prompt == oidc.PromptLogin { maxAge = oidc.NewMaxAge(0) @@ -144,7 +234,9 @@ func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) //ValidateAuthReqScopes validates the passed scopes func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { if len(scopes) == 0 { - return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") + return nil, oidc.ErrInvalidRequest(). + WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " + + "If you have any questions, you may contact the administrator of the application.") } openID := false for i := len(scopes) - 1; i >= 0; i-- { @@ -165,7 +257,9 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { } } if !openID { - return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.") + return nil, oidc.ErrInvalidScope().WithDescription("The scope openid is missing in your request. " + + "Please ensure the scope openid is added to the request. " + + "If you have any questions, you may contact the administrator of the application.") } return scopes, nil @@ -174,19 +268,23 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { //ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error { if uri == "" { - return ErrInvalidRequestRedirectURI("The redirect_uri is missing in the request. Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") + return oidc.ErrInvalidRequestRedirectURI().WithDescription("The redirect_uri is missing in the request. " + + "Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") } if strings.HasPrefix(uri, "https://") { - if !utils.Contains(client.RedirectURIs(), uri) { - return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") + if !str.Contains(client.RedirectURIs(), uri) { + return oidc.ErrInvalidRequestRedirectURI(). + WithDescription("The requested redirect_uri is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application.") } return nil } if client.ApplicationType() == ApplicationTypeNative { return validateAuthReqRedirectURINative(client, uri, responseType) } - if !utils.Contains(client.RedirectURIs(), uri) { - return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") + if !str.Contains(client.RedirectURIs(), uri) { + return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application.") } if strings.HasPrefix(uri, "http://") { if client.DevMode() { @@ -195,23 +293,27 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res if responseType == oidc.ResponseTypeCode && IsConfidentialType(client) { return nil } - return ErrInvalidRequest("This client's redirect_uri is http and is not allowed. If you have any questions, you may contact the administrator of the application.") + return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is http and is not allowed. " + + "If you have any questions, you may contact the administrator of the application.") } - return ErrInvalidRequest("This client's redirect_uri is using a custom schema and is not allowed. If you have any questions, you may contact the administrator of the application.") + return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is using a custom schema and is not allowed. " + + "If you have any questions, you may contact the administrator of the application.") } //ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error { parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri) isCustomSchema := !strings.HasPrefix(uri, "http://") - if utils.Contains(client.RedirectURIs(), uri) { + if str.Contains(client.RedirectURIs(), uri) { if isLoopback || isCustomSchema { return nil } - return ErrInvalidRequest("This client's redirect_uri is http and is not allowed. If you have any questions, you may contact the administrator of the application.") + return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is http and is not allowed. " + + "If you have any questions, you may contact the administrator of the application.") } if !isLoopback { - return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") + return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application.") } for _, uri := range client.RedirectURIs() { redirectURI, ok := HTTPLoopbackOrLocalhost(uri) @@ -219,7 +321,8 @@ func validateAuthReqRedirectURINative(client Client, uri string, responseType oi return nil } } - return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") + return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration." + + " If you have any questions, you may contact the administrator of the application.") } func equalURI(url1, url2 *url.URL) bool { @@ -241,10 +344,12 @@ func HTTPLoopbackOrLocalhost(rawurl string) (*url.URL, bool) { //ValidateAuthReqResponseType validates the passed response_type to the registered response types func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error { if responseType == "" { - return ErrInvalidRequest("The response type is missing in your request. If you have any questions, you may contact the administrator of the application.") + return oidc.ErrInvalidRequest().WithDescription("The response type is missing in your request. " + + "If you have any questions, you may contact the administrator of the application.") } if !ContainsResponseType(client.ResponseTypes(), responseType) { - return ErrInvalidRequest("The requested response type is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") + return oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application.") } return nil } @@ -257,7 +362,8 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie } claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier) if err != nil { - return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.") + return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " + + "If you have any questions, you may contact the administrator of the application.") } return claims.GetSubject(), nil } @@ -279,7 +385,9 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author return } if !authReq.Done() { - AuthRequestError(w, r, authReq, ErrInteractionRequired("Unfortunately, the user may is not logged in and/or additional interaction is required."), authorizer.Encoder()) + AuthRequestError(w, r, authReq, + oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."), + authorizer.Encoder()) return } AuthResponse(authReq, authorizer, w, r) @@ -306,9 +414,17 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } - callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code) - if authReq.GetState() != "" { - callback = callback + "&state=" + authReq.GetState() + codeResponse := struct { + code string + state string + }{ + code: code, + state: authReq.GetState(), + } + callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return } http.Redirect(w, r, callback, http.StatusFound) } @@ -321,12 +437,11 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } - params, err := utils.URLEncodeResponse(resp, authorizer.Encoder()) + callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder()) if err != nil { AuthRequestError(w, r, authReq, err, authorizer.Encoder()) return } - callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params) http.Redirect(w, r, callback, http.StatusFound) } @@ -346,3 +461,22 @@ func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Sto func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) { return crypto.Encrypt(authReq.GetID()) } + +//AuthResponseURL encodes the authorization response (successful and error) and sets it as query or fragment values +//depending on the response_mode and response_type +func AuthResponseURL(redirectURI string, responseType oidc.ResponseType, responseMode oidc.ResponseMode, response interface{}, encoder httphelper.Encoder) (string, error) { + params, err := httphelper.URLEncodeResponse(response, encoder) + if err != nil { + return "", oidc.ErrServerError().WithParent(err) + } + if responseMode == oidc.ResponseModeQuery { + return redirectURI + "?" + params, nil + } + if responseMode == oidc.ResponseModeFragment { + return redirectURI + "#" + params, nil + } + if responseType == "" || responseType == oidc.ResponseTypeCode { + return redirectURI + "?" + params, nil + } + return redirectURI + "#" + params, nil +} diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 40e1a8a..7259ec7 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -1,6 +1,8 @@ package op_test import ( + "context" + "errors" "net/http" "net/http/httptest" "net/url" @@ -11,10 +13,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/op" "github.com/caos/oidc/pkg/op/mock" - "github.com/caos/oidc/pkg/utils" ) // @@ -75,7 +77,7 @@ import ( func TestParseAuthorizeRequest(t *testing.T) { type args struct { r *http.Request - decoder utils.Decoder + decoder httphelper.Decoder } type res struct { want *oidc.AuthRequest @@ -101,7 +103,7 @@ func TestParseAuthorizeRequest(t *testing.T) { "decoding error", args{ &http.Request{URL: &url.URL{RawQuery: "unknown=value"}}, - func() utils.Decoder { + func() httphelper.Decoder { decoder := schema.NewDecoder() decoder.IgnoreUnknownKeys(false) return decoder @@ -116,7 +118,7 @@ func TestParseAuthorizeRequest(t *testing.T) { "parsing ok", args{ &http.Request{URL: &url.URL{RawQuery: "scope=openid"}}, - func() utils.Decoder { + func() httphelper.Decoder { decoder := schema.NewDecoder() decoder.IgnoreUnknownKeys(false) return decoder @@ -150,44 +152,138 @@ func TestValidateAuthRequest(t *testing.T) { tests := []struct { name string args args - wantErr bool + wantErr error }{ - //TODO: - // { - // "oauth2 spec" - // } { "scope missing fails", args{&oidc.AuthRequest{}, mock.NewMockStorageExpectValidClientID(t), nil}, - true, + oidc.ErrInvalidRequest(), }, { "scope openid missing fails", args{&oidc.AuthRequest{Scopes: []string{"profile"}}, mock.NewMockStorageExpectValidClientID(t), nil}, - true, + oidc.ErrInvalidScope(), }, { "response_type missing fails", args{&oidc.AuthRequest{Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil}, - true, + oidc.ErrInvalidRequest(), }, { "client_id missing fails", args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, mock.NewMockStorageExpectValidClientID(t), nil}, - true, + oidc.ErrInvalidRequest(), }, { "redirect_uri missing fails", args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, mock.NewMockStorageExpectValidClientID(t), nil}, - true, + oidc.ErrInvalidRequest(), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage, tt.args.verifier) - if (err != nil) != tt.wantErr { - t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr) + _, err := op.ValidateAuthRequest(context.TODO(), tt.args.authRequest, tt.args.storage, tt.args.verifier) + if tt.wantErr == nil && err != nil { + t.Errorf("ValidateAuthRequest() unexpected error = %v", err) } + if tt.wantErr != nil && !errors.Is(err, tt.wantErr) { + t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateAuthReqPrompt(t *testing.T) { + type args struct { + prompts []string + maxAge *uint + } + type res struct { + maxAge *uint + err error + } + tests := []struct { + name string + args args + res res + }{ + { + "no prompts and maxAge, ok", + args{ + nil, + nil, + }, + res{ + nil, + nil, + }, + }, + { + "no prompts but maxAge, ok", + args{ + nil, + oidc.NewMaxAge(10), + }, + res{ + oidc.NewMaxAge(10), + nil, + }, + }, + { + "prompt none, ok", + args{ + []string{"none"}, + oidc.NewMaxAge(10), + }, + res{ + oidc.NewMaxAge(10), + nil, + }, + }, + { + "prompt none with others, err", + args{ + []string{"none", "login"}, + oidc.NewMaxAge(10), + }, + res{ + nil, + oidc.ErrInvalidRequest(), + }, + }, + { + "prompt login, ok", + args{ + []string{"login"}, + nil, + }, + res{ + oidc.NewMaxAge(0), + nil, + }, + }, + { + "prompt login with maxAge, ok", + args{ + []string{"login"}, + oidc.NewMaxAge(10), + }, + res{ + oidc.NewMaxAge(0), + nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + maxAge, err := op.ValidateAuthReqPrompt(tt.args.prompts, tt.args.maxAge) + if tt.res.err == nil && err != nil { + t.Errorf("ValidateAuthRequest() unexpected error = %v", err) + } + if tt.res.err != nil && !errors.Is(err, tt.res.err) { + t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.res.err) + } + assert.Equal(t, tt.res.maxAge, maxAge) }) } } @@ -465,6 +561,80 @@ func TestValidateAuthReqRedirectURI(t *testing.T) { } } +func TestLoopbackOrLocalhost(t *testing.T) { + type args struct { + url string + } + tests := []struct { + name string + args args + want bool + }{ + { + "not parsable, false", + args{url: string('\n')}, + false, + }, + { + "not http, false", + args{url: "localhost/test"}, + false, + }, + { + "not http, false", + args{url: "http://localhost.com/test"}, + false, + }, + { + "v4 no port ok", + args{url: "http://127.0.0.1/test"}, + true, + }, + { + "v6 short no port ok", + args{url: "http://[::1]/test"}, + true, + }, + { + "v6 long no port ok", + args{url: "http://[0:0:0:0:0:0:0:1]/test"}, + true, + }, + { + "locahost no port ok", + args{url: "http://localhost/test"}, + true, + }, + { + "v4 with port ok", + args{url: "http://127.0.0.1:4200/test"}, + true, + }, + { + "v6 short with port ok", + args{url: "http://[::1]:4200/test"}, + true, + }, + { + "v6 long with port ok", + args{url: "http://[0:0:0:0:0:0:0:1]:4200/test"}, + true, + }, + { + "localhost with port ok", + args{url: "http://localhost:4200/test"}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, got := op.HTTPLoopbackOrLocalhost(tt.args.url); got != tt.want { + t.Errorf("loopbackOrLocalhost() = %v, want %v", got, tt.want) + } + }) + } +} + func TestValidateAuthReqResponseType(t *testing.T) { type args struct { responseType oidc.ResponseType @@ -534,100 +704,122 @@ func TestRedirectToLogin(t *testing.T) { } } -func TestAuthorizeCallback(t *testing.T) { +func TestAuthResponseURL(t *testing.T) { type args struct { - w http.ResponseWriter - r *http.Request - authorizer op.Authorizer + redirectURI string + responseType oidc.ResponseType + responseMode oidc.ResponseMode + response interface{} + encoder httphelper.Encoder } - tests := []struct { - name string - args args - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer) - }) - } -} - -func TestAuthResponse(t *testing.T) { - type args struct { - authReq op.AuthRequest - authorizer op.Authorizer - w http.ResponseWriter - r *http.Request - } - tests := []struct { - name string - args args - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r) - }) - } -} - -func Test_LoopbackOrLocalhost(t *testing.T) { - type args struct { + type res struct { url string + err error } tests := []struct { name string args args - want bool + res res }{ { - "v4 no port ok", - args{url: "http://127.0.0.1/test"}, - true, + "encoding error", + args{ + "uri", + oidc.ResponseTypeCode, + "", + map[string]interface{}{"test": "test"}, + &mockEncoder{ + errors.New("error encoding"), + }, + }, + res{ + "", + oidc.ErrServerError(), + }, }, { - "v6 short no port ok", - args{url: "http://[::1]/test"}, - true, + "response mode query", + args{ + "uri", + oidc.ResponseTypeIDToken, + oidc.ResponseModeQuery, + map[string][]string{"test": {"test"}}, + &mockEncoder{}, + }, + res{ + "uri?test=test", + nil, + }, }, { - "v6 long no port ok", - args{url: "http://[0:0:0:0:0:0:0:1]/test"}, - true, + "response mode fragment", + args{ + "uri", + oidc.ResponseTypeCode, + oidc.ResponseModeFragment, + map[string][]string{"test": {"test"}}, + &mockEncoder{}, + }, + res{ + "uri#test=test", + nil, + }, }, { - "locahost no port ok", - args{url: "http://localhost/test"}, - true, + "response type code", + args{ + "uri", + oidc.ResponseTypeCode, + "", + map[string][]string{"test": {"test"}}, + &mockEncoder{}, + }, + res{ + "uri?test=test", + nil, + }, }, { - "v4 with port ok", - args{url: "http://127.0.0.1:4200/test"}, - true, - }, - { - "v6 short with port ok", - args{url: "http://[::1]:4200/test"}, - true, - }, - { - "v6 long with port ok", - args{url: "http://[0:0:0:0:0:0:0:1]:4200/test"}, - true, - }, - { - "localhost with port ok", - args{url: "http://localhost:4200/test"}, - true, + "response type id token", + args{ + "uri", + oidc.ResponseTypeIDToken, + "", + map[string][]string{"test": {"test"}}, + &mockEncoder{}, + }, + res{ + "uri#test=test", + nil, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if _, got := op.HTTPLoopbackOrLocalhost(tt.args.url); got != tt.want { - t.Errorf("loopbackOrLocalhost() = %v, want %v", got, tt.want) + got, err := op.AuthResponseURL(tt.args.redirectURI, tt.args.responseType, tt.args.responseMode, tt.args.response, tt.args.encoder) + if tt.res.err == nil && err != nil { + t.Errorf("ValidateAuthRequest() unexpected error = %v", err) + } + if tt.res.err != nil && !errors.Is(err, tt.res.err) { + t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.res.err) + } + if got != tt.res.url { + t.Errorf("AuthResponseURL() got = %v, want %v", got, tt.res.url) } }) } } + +type mockEncoder struct { + err error +} + +func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error { + if m.err != nil { + return m.err + } + for s, strings := range src.(map[string][]string) { + dst[s] = strings + } + return nil +} diff --git a/pkg/op/config.go b/pkg/op/config.go index 39c84c8..527e134 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -16,15 +16,23 @@ type Configuration interface { TokenEndpoint() Endpoint IntrospectionEndpoint() Endpoint UserinfoEndpoint() Endpoint + RevocationEndpoint() Endpoint EndSessionEndpoint() Endpoint KeysEndpoint() Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool AuthMethodPrivateKeyJWTSupported() bool + TokenEndpointSigningAlgorithmsSupported() []string GrantTypeRefreshTokenSupported() bool GrantTypeTokenExchangeSupported() bool GrantTypeJWTAuthorizationSupported() bool + IntrospectionAuthMethodPrivateKeyJWTSupported() bool + IntrospectionEndpointSigningAlgorithmsSupported() []string + RevocationAuthMethodPrivateKeyJWTSupported() bool + RevocationEndpointSigningAlgorithmsSupported() []string + RequestObjectSupported() bool + RequestObjectSigningAlgorithmsSupported() []string SupportedUILocales() []language.Tag } diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go index e140074..5029df8 100644 --- a/pkg/op/config_test.go +++ b/pkg/op/config_test.go @@ -61,6 +61,7 @@ func TestValidateIssuer(t *testing.T) { }, } //ensure env is not set + //nolint:errcheck os.Unsetenv(OidcDevMode) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -86,6 +87,7 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) { false, }, } + //nolint:errcheck os.Setenv(OidcDevMode, "true") for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/op/crypto.go b/pkg/op/crypto.go index e95157d..e9dd67b 100644 --- a/pkg/op/crypto.go +++ b/pkg/op/crypto.go @@ -1,7 +1,7 @@ package op import ( - "github.com/caos/oidc/pkg/utils" + "github.com/caos/oidc/pkg/crypto" ) type Crypto interface { @@ -18,9 +18,9 @@ func NewAESCrypto(key [32]byte) Crypto { } func (c *aesCrypto) Encrypt(s string) (string, error) { - return utils.EncryptAES(s, c.key) + return crypto.EncryptAES(s, c.key) } func (c *aesCrypto) Decrypt(s string) (string, error) { - return utils.DecryptAES(s, c.key) + return crypto.DecryptAES(s, c.key) } diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 807aa20..955d0fa 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -3,8 +3,8 @@ package op import ( "net/http" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http.Request) { @@ -14,28 +14,35 @@ func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http } func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { - utils.MarshalJSON(w, config) + httphelper.MarshalJSON(w, config) } func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration { return &oidc.DiscoveryConfiguration{ - Issuer: c.Issuer(), - AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()), - TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()), - IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()), - UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()), - EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()), - JwksURI: c.KeysEndpoint().Absolute(c.Issuer()), - ScopesSupported: Scopes(c), - ResponseTypesSupported: ResponseTypes(c), - GrantTypesSupported: GrantTypes(c), - SubjectTypesSupported: SubjectTypes(c), - IDTokenSigningAlgValuesSupported: SigAlgorithms(s), - TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(c), - IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(c), - ClaimsSupported: SupportedClaims(c), - CodeChallengeMethodsSupported: CodeChallengeMethods(c), - UILocalesSupported: c.SupportedUILocales(), + Issuer: c.Issuer(), + AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()), + TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()), + IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()), + UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()), + RevocationEndpoint: c.RevocationEndpoint().Absolute(c.Issuer()), + EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()), + JwksURI: c.KeysEndpoint().Absolute(c.Issuer()), + ScopesSupported: Scopes(c), + ResponseTypesSupported: ResponseTypes(c), + GrantTypesSupported: GrantTypes(c), + SubjectTypesSupported: SubjectTypes(c), + IDTokenSigningAlgValuesSupported: SigAlgorithms(s), + RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(c), + TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(c), + TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(c), + IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(c), + IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(c), + RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(c), + RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(c), + ClaimsSupported: SupportedClaims(c), + CodeChallengeMethodsSupported: CodeChallengeMethods(c), + UILocalesSupported: c.SupportedUILocales(), + RequestParameterSupported: c.RequestObjectSupported(), } } @@ -45,6 +52,7 @@ var DefaultSupportedScopes = []string{ oidc.ScopeEmail, oidc.ScopePhone, oidc.ScopeAddress, + oidc.ScopeOfflineAccess, } func Scopes(c Configuration) []string { @@ -127,6 +135,13 @@ func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod { return authMethods } +func TokenSigAlgorithms(c Configuration) []string { + if !c.AuthMethodPrivateKeyJWTSupported() { + return nil + } + return c.TokenEndpointSigningAlgorithmsSupported() +} + func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod { authMethods := []oidc.AuthMethod{ oidc.AuthMethodBasic, @@ -137,6 +152,20 @@ func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod { return authMethods } +func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod { + authMethods := []oidc.AuthMethod{ + oidc.AuthMethodNone, + oidc.AuthMethodBasic, + } + if c.AuthMethodPostSupported() { + authMethods = append(authMethods, oidc.AuthMethodPost) + } + if c.AuthMethodPrivateKeyJWTSupported() { + authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT) + } + return authMethods +} + func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod { codeMethods := make([]oidc.CodeChallengeMethod, 0, 1) if c.CodeMethodS256Supported() { @@ -144,3 +173,24 @@ func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod { } return codeMethods } + +func IntrospectionSigAlgorithms(c Configuration) []string { + if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() { + return nil + } + return c.IntrospectionEndpointSigningAlgorithmsSupported() +} + +func RevocationSigAlgorithms(c Configuration) []string { + if !c.RevocationAuthMethodPrivateKeyJWTSupported() { + return nil + } + return c.RevocationEndpointSigningAlgorithmsSupported() +} + +func RequestObjectSigAlgorithms(c Configuration) []string { + if !c.RequestObjectSupported() { + return nil + } + return c.RequestObjectSigningAlgorithmsSupported() +} diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 4d97a01..1f0663d 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -37,7 +37,10 @@ func TestDiscover(t *testing.T) { op.Discover(tt.args.w, tt.args.config) rec := tt.args.w.(*httptest.ResponseRecorder) require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, `{"issuer":"https://issuer.com","request_uri_parameter_supported":false}`, rec.Body.String()) + require.Equal(t, + `{"issuer":"https://issuer.com","request_uri_parameter_supported":false} +`, + rec.Body.String()) }) } } diff --git a/pkg/op/error.go b/pkg/op/error.go index 06935c9..ea8d368 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -1,105 +1,46 @@ package op import ( - "fmt" "net/http" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) -const ( - InvalidRequest errorType = "invalid_request" - InvalidRequestURI errorType = "invalid_request_uri" - InteractionRequired errorType = "interaction_required" - ServerError errorType = "server_error" -) - -var ( - ErrInvalidRequest = func(description string) *OAuthError { - return &OAuthError{ - ErrorType: InvalidRequest, - Description: description, - } - } - ErrInvalidRequestRedirectURI = func(description string) *OAuthError { - return &OAuthError{ - ErrorType: InvalidRequestURI, - Description: description, - redirectDisabled: true, - } - } - ErrInteractionRequired = func(description string) *OAuthError { - return &OAuthError{ - ErrorType: InteractionRequired, - Description: description, - } - } - ErrServerError = func(description string) *OAuthError { - return &OAuthError{ - ErrorType: ServerError, - Description: description, - } - } -) - -type errorType string - type ErrAuthRequest interface { GetRedirectURI() string GetResponseType() oidc.ResponseType GetState() string } -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder utils.Encoder) { +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) { if authReq == nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - e, ok := err.(*OAuthError) - if !ok { - e = new(OAuthError) - e.ErrorType = ServerError - e.Description = err.Error() - } - e.State = authReq.GetState() - if authReq.GetRedirectURI() == "" || e.redirectDisabled { + e := oidc.DefaultToServerError(err, err.Error()) + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { http.Error(w, e.Description, http.StatusBadRequest) return } - params, err := utils.URLEncodeResponse(e, encoder) + e.State = authReq.GetState() + var responseMode oidc.ResponseMode + if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { + responseMode = rm.GetResponseMode() + } + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - url := authReq.GetRedirectURI() - responseType := authReq.GetResponseType() - if responseType == "" || responseType == oidc.ResponseTypeCode { - url += "?" + params - } else { - url += "#" + params - } http.Redirect(w, r, url, http.StatusFound) } func RequestError(w http.ResponseWriter, r *http.Request, err error) { - e, ok := err.(*OAuthError) - if !ok { - e = new(OAuthError) - e.ErrorType = ServerError - e.Description = err.Error() + e := oidc.DefaultToServerError(err, err.Error()) + status := http.StatusBadRequest + if e.ErrorType == oidc.InvalidClient { + status = 401 } - w.WriteHeader(http.StatusBadRequest) - utils.MarshalJSON(w, e) -} - -type OAuthError struct { - ErrorType errorType `json:"error" schema:"error"` - Description string `json:"error_description,omitempty" schema:"error_description,omitempty"` - State string `json:"state,omitempty" schema:"state,omitempty"` - redirectDisabled bool `json:"-" schema:"-"` -} - -func (e *OAuthError) Error() string { - return fmt.Sprintf("%s: %s", e.ErrorType, e.Description) + httphelper.MarshalJSONWithStatus(w, e, status) } diff --git a/pkg/op/keys.go b/pkg/op/keys.go index c4b11d4..e637066 100644 --- a/pkg/op/keys.go +++ b/pkg/op/keys.go @@ -6,7 +6,7 @@ import ( "gopkg.in/square/go-jose.v2" - "github.com/caos/oidc/pkg/utils" + httphelper "github.com/caos/oidc/pkg/http" ) type KeyProvider interface { @@ -22,9 +22,8 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) { func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { keySet, err := k.GetKeySet(r.Context()) if err != nil { - w.WriteHeader(http.StatusInternalServerError) - utils.MarshalJSON(w, err) + httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError) return } - utils.MarshalJSON(w, keySet) + httphelper.MarshalJSON(w, keySet) } diff --git a/pkg/op/keys_test.go b/pkg/op/keys_test.go new file mode 100644 index 0000000..bf60a3e --- /dev/null +++ b/pkg/op/keys_test.go @@ -0,0 +1,100 @@ +package op_test + +import ( + "crypto/rsa" + "math/big" + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/op" + "github.com/caos/oidc/pkg/op/mock" +) + +func TestKeys(t *testing.T) { + type args struct { + k op.KeyProvider + } + type res struct { + statusCode int + contentType string + body string + } + tests := []struct { + name string + args args + res res + }{ + { + name: "error", + args: args{ + k: func() op.KeyProvider { + m := mock.NewMockKeyProvider(gomock.NewController(t)) + m.EXPECT().GetKeySet(gomock.Any()).Return(nil, oidc.ErrServerError()) + return m + }(), + }, + res: res{ + statusCode: http.StatusInternalServerError, + contentType: "application/json", + body: `{"error":"server_error"} +`, + }, + }, + { + name: "empty list", + args: args{ + k: func() op.KeyProvider { + m := mock.NewMockKeyProvider(gomock.NewController(t)) + m.EXPECT().GetKeySet(gomock.Any()).Return(nil, nil) + return m + }(), + }, + res: res{ + statusCode: http.StatusOK, + contentType: "application/json", + }, + }, + { + name: "list", + args: args{ + k: func() op.KeyProvider { + m := mock.NewMockKeyProvider(gomock.NewController(t)) + m.EXPECT().GetKeySet(gomock.Any()).Return( + &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ + { + Key: &rsa.PublicKey{ + N: big.NewInt(1), + E: 1, + }, + KeyID: "id", + }, + }}, + nil, + ) + return m + }(), + }, + res: res{ + statusCode: http.StatusOK, + contentType: "application/json", + body: `{"keys":[{"kty":"RSA","kid":"id","n":"AQ","e":"AQ"}]} +`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + op.Keys(w, httptest.NewRequest("GET", "/keys", nil), tt.args.k) + assert.Equal(t, tt.res.statusCode, w.Result().StatusCode) + assert.Equal(t, tt.res.contentType, w.Header().Get("content-type")) + assert.Equal(t, tt.res.body, w.Body.String()) + }) + } +} diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index 69f6927..3c18022 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -7,8 +7,8 @@ package mock import ( reflect "reflect" + http "github.com/caos/oidc/pkg/http" op "github.com/caos/oidc/pkg/op" - utils "github.com/caos/oidc/pkg/utils" gomock "github.com/golang/mock/gomock" ) @@ -50,10 +50,10 @@ func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call { } // Decoder mocks base method. -func (m *MockAuthorizer) Decoder() utils.Decoder { +func (m *MockAuthorizer) Decoder() http.Decoder { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Decoder") - ret0, _ := ret[0].(utils.Decoder) + ret0, _ := ret[0].(http.Decoder) return ret0 } @@ -64,10 +64,10 @@ func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call { } // Encoder mocks base method. -func (m *MockAuthorizer) Encoder() utils.Encoder { +func (m *MockAuthorizer) Encoder() http.Encoder { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Encoder") - ret0, _ := ret[0].(utils.Encoder) + ret0, _ := ret[0].(http.Encoder) return ret0 } @@ -105,6 +105,20 @@ func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer)) } +// RequestObjectSupported mocks base method. +func (m *MockAuthorizer) RequestObjectSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RequestObjectSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// RequestObjectSupported indicates an expected call of RequestObjectSupported. +func (mr *MockAuthorizerMockRecorder) RequestObjectSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockAuthorizer)(nil).RequestObjectSupported)) +} + // Signer mocks base method. func (m *MockAuthorizer) Signer() op.Signer { m.ctrl.T.Helper() diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index 01c2c8d..3eb4542 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -147,6 +147,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeTokenExchangeSupported() *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeTokenExchangeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeTokenExchangeSupported)) } +// IntrospectionAuthMethodPrivateKeyJWTSupported mocks base method. +func (m *MockConfiguration) IntrospectionAuthMethodPrivateKeyJWTSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IntrospectionAuthMethodPrivateKeyJWTSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IntrospectionAuthMethodPrivateKeyJWTSupported indicates an expected call of IntrospectionAuthMethodPrivateKeyJWTSupported. +func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionAuthMethodPrivateKeyJWTSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionAuthMethodPrivateKeyJWTSupported)) +} + // IntrospectionEndpoint mocks base method. func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint { m.ctrl.T.Helper() @@ -161,6 +175,20 @@ func (mr *MockConfigurationMockRecorder) IntrospectionEndpoint() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpoint", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpoint)) } +// IntrospectionEndpointSigningAlgorithmsSupported mocks base method. +func (m *MockConfiguration) IntrospectionEndpointSigningAlgorithmsSupported() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IntrospectionEndpointSigningAlgorithmsSupported") + ret0, _ := ret[0].([]string) + return ret0 +} + +// IntrospectionEndpointSigningAlgorithmsSupported indicates an expected call of IntrospectionEndpointSigningAlgorithmsSupported. +func (mr *MockConfigurationMockRecorder) IntrospectionEndpointSigningAlgorithmsSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpointSigningAlgorithmsSupported)) +} + // Issuer mocks base method. func (m *MockConfiguration) Issuer() string { m.ctrl.T.Helper() @@ -189,6 +217,76 @@ func (mr *MockConfigurationMockRecorder) KeysEndpoint() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeysEndpoint", reflect.TypeOf((*MockConfiguration)(nil).KeysEndpoint)) } +// RequestObjectSigningAlgorithmsSupported mocks base method. +func (m *MockConfiguration) RequestObjectSigningAlgorithmsSupported() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RequestObjectSigningAlgorithmsSupported") + ret0, _ := ret[0].([]string) + return ret0 +} + +// RequestObjectSigningAlgorithmsSupported indicates an expected call of RequestObjectSigningAlgorithmsSupported. +func (mr *MockConfigurationMockRecorder) RequestObjectSigningAlgorithmsSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).RequestObjectSigningAlgorithmsSupported)) +} + +// RequestObjectSupported mocks base method. +func (m *MockConfiguration) RequestObjectSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RequestObjectSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// RequestObjectSupported indicates an expected call of RequestObjectSupported. +func (mr *MockConfigurationMockRecorder) RequestObjectSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockConfiguration)(nil).RequestObjectSupported)) +} + +// RevocationAuthMethodPrivateKeyJWTSupported mocks base method. +func (m *MockConfiguration) RevocationAuthMethodPrivateKeyJWTSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevocationAuthMethodPrivateKeyJWTSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// RevocationAuthMethodPrivateKeyJWTSupported indicates an expected call of RevocationAuthMethodPrivateKeyJWTSupported. +func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationAuthMethodPrivateKeyJWTSupported", reflect.TypeOf((*MockConfiguration)(nil).RevocationAuthMethodPrivateKeyJWTSupported)) +} + +// RevocationEndpoint mocks base method. +func (m *MockConfiguration) RevocationEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevocationEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// RevocationEndpoint indicates an expected call of RevocationEndpoint. +func (mr *MockConfigurationMockRecorder) RevocationEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).RevocationEndpoint)) +} + +// RevocationEndpointSigningAlgorithmsSupported mocks base method. +func (m *MockConfiguration) RevocationEndpointSigningAlgorithmsSupported() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevocationEndpointSigningAlgorithmsSupported") + ret0, _ := ret[0].([]string) + return ret0 +} + +// RevocationEndpointSigningAlgorithmsSupported indicates an expected call of RevocationEndpointSigningAlgorithmsSupported. +func (mr *MockConfigurationMockRecorder) RevocationEndpointSigningAlgorithmsSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).RevocationEndpointSigningAlgorithmsSupported)) +} + // SupportedUILocales mocks base method. func (m *MockConfiguration) SupportedUILocales() []language.Tag { m.ctrl.T.Helper() @@ -217,6 +315,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpoint() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpoint", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpoint)) } +// TokenEndpointSigningAlgorithmsSupported mocks base method. +func (m *MockConfiguration) TokenEndpointSigningAlgorithmsSupported() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TokenEndpointSigningAlgorithmsSupported") + ret0, _ := ret[0].([]string) + return ret0 +} + +// TokenEndpointSigningAlgorithmsSupported indicates an expected call of TokenEndpointSigningAlgorithmsSupported. +func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported)) +} + // UserinfoEndpoint mocks base method. func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { m.ctrl.T.Helper() diff --git a/pkg/op/mock/generate.go b/pkg/op/mock/generate.go index beb3132..4dd020e 100644 --- a/pkg/op/mock/generate.go +++ b/pkg/op/mock/generate.go @@ -5,3 +5,4 @@ package mock //go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client //go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration //go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer +//go:generate mockgen -package mock -destination ./key.mock.go github.com/caos/oidc/pkg/op KeyProvider diff --git a/pkg/op/mock/key.mock.go b/pkg/op/mock/key.mock.go new file mode 100644 index 0000000..37e0677 --- /dev/null +++ b/pkg/op/mock/key.mock.go @@ -0,0 +1,51 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/caos/oidc/pkg/op (interfaces: KeyProvider) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + jose "gopkg.in/square/go-jose.v2" +) + +// MockKeyProvider is a mock of KeyProvider interface. +type MockKeyProvider struct { + ctrl *gomock.Controller + recorder *MockKeyProviderMockRecorder +} + +// MockKeyProviderMockRecorder is the mock recorder for MockKeyProvider. +type MockKeyProviderMockRecorder struct { + mock *MockKeyProvider +} + +// NewMockKeyProvider creates a new mock instance. +func NewMockKeyProvider(ctrl *gomock.Controller) *MockKeyProvider { + mock := &MockKeyProvider{ctrl: ctrl} + mock.recorder = &MockKeyProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockKeyProvider) EXPECT() *MockKeyProviderMockRecorder { + return m.recorder +} + +// GetKeySet mocks base method. +func (m *MockKeyProvider) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKeySet", arg0) + ret0, _ := ret[0].(*jose.JSONWebKeySet) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKeySet indicates an expected call of GetKeySet. +func (mr *MockKeyProviderMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockKeyProvider)(nil).GetKeySet), arg0) +} diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 4b44f2b..0763230 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -230,6 +230,20 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0) } +// RevokeToken mocks base method. +func (m *MockStorage) RevokeToken(arg0 context.Context, arg1, arg2, arg3 string) *oidc.Error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*oidc.Error) + return ret0 +} + +// RevokeToken indicates an expected call of RevokeToken. +func (mr *MockStorageMockRecorder) RevokeToken(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeToken", reflect.TypeOf((*MockStorage)(nil).RevokeToken), arg0, arg1, arg2, arg3) +} + // SaveAuthCode mocks base method. func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() diff --git a/pkg/op/op.go b/pkg/op/op.go index 3841227..8a5be26 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -12,8 +12,8 @@ import ( "golang.org/x/text/language" "gopkg.in/square/go-jose.v2" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) const ( @@ -23,6 +23,7 @@ const ( defaultTokenEndpoint = "oauth/token" defaultIntrospectEndpoint = "oauth/introspect" defaultUserinfoEndpoint = "userinfo" + defaultRevocationEndpoint = "revoke" defaultEndSessionEndpoint = "end_session" defaultKeysEndpoint = "keys" ) @@ -33,6 +34,7 @@ var ( Token: NewEndpoint(defaultTokenEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint), Userinfo: NewEndpoint(defaultUserinfoEndpoint), + Revocation: NewEndpoint(defaultRevocationEndpoint), EndSession: NewEndpoint(defaultEndSessionEndpoint), JwksURI: NewEndpoint(defaultKeysEndpoint), } @@ -41,8 +43,8 @@ var ( type OpenIDProvider interface { Configuration Storage() Storage - Decoder() utils.Decoder - Encoder() utils.Encoder + Decoder() httphelper.Decoder + Encoder() httphelper.Encoder IDTokenHintVerifier() IDTokenHintVerifier AccessTokenVerifier() AccessTokenVerifier Crypto() Crypto @@ -74,6 +76,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router router.Handle(o.TokenEndpoint().Relative(), intercept(tokenHandler(o))) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) + router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o)) router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o))) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) return router @@ -84,8 +87,10 @@ type Config struct { CryptoKey [32]byte DefaultLogoutRedirectURI string CodeMethodS256 bool + AuthMethodPost bool AuthMethodPrivateKeyJWT bool GrantTypeRefreshToken bool + RequestObjectSupported bool SupportedUILocales []language.Tag } @@ -94,6 +99,7 @@ type endpoints struct { Token Endpoint Introspection Endpoint Userinfo Endpoint + Revocation Endpoint EndSession Endpoint CheckSessionIframe Endpoint JwksURI Endpoint @@ -148,7 +154,6 @@ type openidProvider struct { decoder *schema.Decoder encoder *schema.Encoder interceptors []HttpInterceptor - retry func(int) (bool, int) timer <-chan time.Time } @@ -172,6 +177,10 @@ func (o *openidProvider) UserinfoEndpoint() Endpoint { return o.endpoints.Userinfo } +func (o *openidProvider) RevocationEndpoint() Endpoint { + return o.endpoints.Revocation +} + func (o *openidProvider) EndSessionEndpoint() Endpoint { return o.endpoints.EndSession } @@ -181,7 +190,7 @@ func (o *openidProvider) KeysEndpoint() Endpoint { } func (o *openidProvider) AuthMethodPostSupported() bool { - return true //todo: config + return o.config.AuthMethodPost } func (o *openidProvider) CodeMethodS256Supported() bool { @@ -192,6 +201,10 @@ func (o *openidProvider) AuthMethodPrivateKeyJWTSupported() bool { return o.config.AuthMethodPrivateKeyJWT } +func (o *openidProvider) TokenEndpointSigningAlgorithmsSupported() []string { + return []string{"RS256"} +} + func (o *openidProvider) GrantTypeRefreshTokenSupported() bool { return o.config.GrantTypeRefreshToken } @@ -204,6 +217,30 @@ func (o *openidProvider) GrantTypeJWTAuthorizationSupported() bool { return true } +func (o *openidProvider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool { + return true +} + +func (o *openidProvider) IntrospectionEndpointSigningAlgorithmsSupported() []string { + return []string{"RS256"} +} + +func (o *openidProvider) RevocationAuthMethodPrivateKeyJWTSupported() bool { + return true +} + +func (o *openidProvider) RevocationEndpointSigningAlgorithmsSupported() []string { + return []string{"RS256"} +} + +func (o *openidProvider) RequestObjectSupported() bool { + return o.config.RequestObjectSupported +} + +func (o *openidProvider) RequestObjectSigningAlgorithmsSupported() []string { + return []string{"RS256"} +} + func (o *openidProvider) SupportedUILocales() []language.Tag { return o.config.SupportedUILocales } @@ -212,11 +249,11 @@ func (o *openidProvider) Storage() Storage { return o.storage } -func (o *openidProvider) Decoder() utils.Decoder { +func (o *openidProvider) Decoder() httphelper.Decoder { return o.decoder } -func (o *openidProvider) Encoder() utils.Encoder { +func (o *openidProvider) Encoder() httphelper.Encoder { return o.encoder } @@ -332,6 +369,16 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { } } +func WithCustomRevocationEndpoint(endpoint Endpoint) Option { + return func(o *openidProvider) error { + if err := endpoint.Validate(); err != nil { + return err + } + o.endpoints.Revocation = endpoint + return nil + } +} + func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { return func(o *openidProvider) error { if err := endpoint.Validate(); err != nil { @@ -352,11 +399,12 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndpoints(auth, token, userInfo, endSession, keys Endpoint) Option { +func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { return func(o *openidProvider) error { o.endpoints.Authorization = auth o.endpoints.Token = token o.endpoints.Userinfo = userInfo + o.endpoints.Revocation = revocation o.endpoints.EndSession = endSession o.endpoints.JwksURI = keys return nil diff --git a/pkg/op/probes.go b/pkg/op/probes.go index c6bb748..b6fdde2 100644 --- a/pkg/op/probes.go +++ b/pkg/op/probes.go @@ -5,7 +5,7 @@ import ( "errors" "net/http" - "github.com/caos/oidc/pkg/utils" + httphelper "github.com/caos/oidc/pkg/http" ) type ProbesFn func(context.Context) error @@ -49,7 +49,7 @@ func ReadyStorage(s Storage) ProbesFn { } func ok(w http.ResponseWriter) { - utils.MarshalJSON(w, status{"ok"}) + httphelper.MarshalJSON(w, status{"ok"}) } type status struct { diff --git a/pkg/op/session.go b/pkg/op/session.go index 4d75098..1f9290e 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -4,12 +4,12 @@ import ( "context" "net/http" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) type SessionEnder interface { - Decoder() utils.Decoder + Decoder() httphelper.Decoder Storage() Storage IDTokenHintVerifier() IDTokenHintVerifier DefaultLogoutRedirectURI() string @@ -38,21 +38,21 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { } err = ender.Storage().TerminateSession(r.Context(), session.UserID, clientID) if err != nil { - RequestError(w, r, ErrServerError("error terminating session")) + RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session")) return } http.Redirect(w, r, session.RedirectURI, http.StatusFound) } -func ParseEndSessionRequest(r *http.Request, decoder utils.Decoder) (*oidc.EndSessionRequest, error) { +func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.EndSessionRequest, error) { err := r.ParseForm() if err != nil { - return nil, ErrInvalidRequest("error parsing form") + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) } req := new(oidc.EndSessionRequest) err = decoder.Decode(req, r.Form) if err != nil { - return nil, ErrInvalidRequest("error decoding form") + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) } return req, nil } @@ -64,12 +64,12 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, } claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier()) if err != nil { - return nil, ErrInvalidRequest("id_token_hint invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err) } session.UserID = claims.GetSubject() session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty()) if err != nil { - return nil, ErrServerError("") + return nil, oidc.DefaultToServerError(err, "") } if req.PostLogoutRedirectURI == "" { session.RedirectURI = ender.DefaultLogoutRedirectURI() @@ -81,5 +81,5 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, return session, nil } } - return nil, ErrInvalidRequest("post_logout_redirect_uri invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("post_logout_redirect_uri invalid") } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index ca9ae7c..94c2a33 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -20,7 +20,8 @@ type AuthStorage interface { CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (RefreshTokenRequest, error) - TerminateSession(context.Context, string, string) error + TerminateSession(ctx context.Context, userID string, clientID string) error + RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error GetSigningKey(context.Context, chan<- jose.SigningKey) GetKeySet(context.Context) (*jose.JSONWebKeySet, error) diff --git a/pkg/op/token.go b/pkg/op/token.go index a587f8a..3e97360 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -4,8 +4,9 @@ import ( "context" "time" + "github.com/caos/oidc/pkg/crypto" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" + "github.com/caos/oidc/pkg/strings" ) type TokenCreator interface { @@ -64,7 +65,7 @@ func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storag func needsRefreshToken(tokenRequest TokenRequest, client Client) bool { switch req := tokenRequest.(type) { case AuthRequest: - return utils.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken) + return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken) case RefreshTokenRequest: return true default: @@ -104,7 +105,7 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex } claims.SetPrivateClaims(privateClaims) } - return utils.Sign(claims, signer.Signer()) + return crypto.Sign(claims, signer.Signer()) } type IDTokenRequest interface { @@ -151,7 +152,7 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v claims.SetCodeHash(codeHash) } - return utils.Sign(claims, signer.Signer()) + return crypto.Sign(claims, signer.Signer()) } func removeUserinfoScopes(scopes []string) []string { @@ -167,5 +168,5 @@ func removeUserinfoScopes(scopes []string) []string { newScopeList = append(newScopeList, scope) } } - return scopes + return newScopeList } diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index fa941df..7b5873c 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -2,11 +2,10 @@ package op import ( "context" - "errors" "net/http" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) //CodeExchange handles the OAuth 2.0 authorization_code grant, including @@ -17,7 +16,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { RequestError(w, r, err) } if tokenReq.Code == "" { - RequestError(w, r, ErrInvalidRequest("code missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing")) return } authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) @@ -30,11 +29,11 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { RequestError(w, r, err) return } - utils.MarshalJSON(w, resp) + httphelper.MarshalJSON(w, resp) } //ParseAccessTokenRequest parsed the http request into a oidc.AccessTokenRequest -func ParseAccessTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.AccessTokenRequest, error) { +func ParseAccessTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.AccessTokenRequest, error) { request := new(oidc.AccessTokenRequest) err := ParseAuthenticatedTokenRequest(r, decoder, request) if err != nil { @@ -51,13 +50,13 @@ func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenR return nil, nil, err } if client.GetID() != authReq.GetClientID() { - return nil, nil, ErrInvalidRequest("invalid auth code") + return nil, nil, oidc.ErrInvalidGrant() } if !ValidateGrantType(client, oidc.GrantTypeCode) { - return nil, nil, ErrInvalidRequest("invalid_grant") + return nil, nil, oidc.ErrUnauthorizedClient() } if tokenReq.RedirectURI != authReq.GetRedirectURI() { - return nil, nil, ErrInvalidRequest("redirect_uri does not correspond") + return nil, nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond") } return authReq, client, nil } @@ -68,7 +67,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger) if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() { - return nil, nil, errors.New("auth_method private_key_jwt not supported") + return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported") } client, err = AuthorizePrivateJWTKey(ctx, tokenReq.ClientAssertion, jwtExchanger) if err != nil { @@ -79,10 +78,10 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, } client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID) if err != nil { - return nil, nil, err + return nil, nil, oidc.ErrInvalidClient().WithParent(err) } if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT { - return nil, nil, errors.New("invalid_grant") + return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") } if client.AuthMethod() == oidc.AuthMethodNone { request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) @@ -93,9 +92,12 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, return request, client, err } if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { - return nil, nil, errors.New("auth_method post not supported") + return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") } err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()) + if err != nil { + return nil, nil, err + } request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) return request, client, err } @@ -104,7 +106,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, func AuthRequestByCode(ctx context.Context, storage Storage, code string) (AuthRequest, error) { authReq, err := storage.AuthRequestByCode(ctx, code) if err != nil { - return nil, ErrInvalidRequest("invalid code") + return nil, oidc.ErrInvalidGrant().WithDescription("invalid code").WithParent(err) } return authReq, nil } diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 8d93e0c..501f6e5 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -3,28 +3,9 @@ package op import ( "errors" "net/http" - - "github.com/caos/oidc/pkg/oidc" ) //TokenExchange will handle the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange") func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { - tokenRequest, err := ParseTokenExchangeRequest(w, r) - if err != nil { - RequestError(w, r, err) - return - } - err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage()) - if err != nil { - RequestError(w, r, err) - return - } -} - -func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) { - return nil, errors.New("Unimplemented") //TODO: impl -} - -func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error { - return errors.New("Unimplemented") //TODO: impl + RequestError(w, r, errors.New("unimplemented")) } diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index e2ae0ad..8fd9187 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -5,12 +5,12 @@ import ( "net/http" "net/url" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) type Introspector interface { - Decoder() utils.Decoder + Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage AccessTokenVerifier() AccessTokenVerifier @@ -36,16 +36,16 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto } tokenID, subject, ok := getTokenIDAndSubject(r.Context(), introspector, token) if !ok { - utils.MarshalJSON(w, response) + httphelper.MarshalJSON(w, response) return } err = introspector.Storage().SetIntrospectionFromToken(r.Context(), response, tokenID, subject, clientID) if err != nil { - utils.MarshalJSON(w, response) + httphelper.MarshalJSON(w, response) return } response.SetActive(true) - utils.MarshalJSON(w, response) + httphelper.MarshalJSON(w, response) } func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) { diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index ac3e2a1..01a1411 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -5,8 +5,8 @@ import ( "net/http" "time" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) type JWTAuthorizationGrantExchanger interface { @@ -37,18 +37,18 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizati RequestError(w, r, err) return } - utils.MarshalJSON(w, resp) + httphelper.MarshalJSON(w, resp) } -func ParseJWTProfileGrantRequest(r *http.Request, decoder utils.Decoder) (*oidc.JWTProfileGrantRequest, error) { +func ParseJWTProfileGrantRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) { err := r.ParseForm() if err != nil { - return nil, ErrInvalidRequest("error parsing form") + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) } tokenReq := new(oidc.JWTProfileGrantRequest) err = decoder.Decode(tokenReq, r.Form) if err != nil { - return nil, ErrInvalidRequest("error decoding form") + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) } return tokenReq, nil } @@ -74,6 +74,6 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea //ParseJWTProfileRequest has been renamed to ParseJWTProfileGrantRequest // //deprecated: use ParseJWTProfileGrantRequest -func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*oidc.JWTProfileGrantRequest, error) { +func ParseJWTProfileRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) { return ParseJWTProfileGrantRequest(r, decoder) } diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index debcca1..0b6d470 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -6,8 +6,9 @@ import ( "net/http" "time" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" + "github.com/caos/oidc/pkg/strings" ) type RefreshTokenRequest interface { @@ -37,11 +38,11 @@ func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exch RequestError(w, r, err) return } - utils.MarshalJSON(w, resp) + httphelper.MarshalJSON(w, resp) } //ParseRefreshTokenRequest parsed the http request into a oidc.RefreshTokenRequest -func ParseRefreshTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.RefreshTokenRequest, error) { +func ParseRefreshTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.RefreshTokenRequest, error) { request := new(oidc.RefreshTokenRequest) err := ParseAuthenticatedTokenRequest(r, decoder, request) if err != nil { @@ -54,14 +55,14 @@ func ParseRefreshTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.Ref //and returns the data representing the original auth request corresponding to the refresh_token func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (RefreshTokenRequest, Client, error) { if tokenReq.RefreshToken == "" { - return nil, nil, ErrInvalidRequest("code missing") + return nil, nil, oidc.ErrInvalidRequest().WithDescription("refresh_token missing") } request, client, err := AuthorizeRefreshClient(ctx, tokenReq, exchanger) if err != nil { return nil, nil, err } if client.GetID() != request.GetClientID() { - return nil, nil, ErrInvalidRequest("invalid auth code") + return nil, nil, oidc.ErrInvalidGrant() } if err = ValidateRefreshTokenScopes(tokenReq.Scopes, request); err != nil { return nil, nil, err @@ -77,15 +78,15 @@ func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTok return nil } for _, scope := range requestedScopes { - if !utils.Contains(authRequest.GetScopes(), scope) { - return errors.New("invalid_scope") + if !strings.Contains(authRequest.GetScopes(), scope) { + return oidc.ErrInvalidScope() } } authRequest.SetCurrentScopes(requestedScopes) return nil } -//AuthorizeCodeClient checks the authorization of the client and that the used method was the one previously registered. +//AuthorizeRefreshClient checks the authorization of the client and that the used method was the one previously registered. //It than returns the data representing the original auth request corresponding to the refresh_token func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (request RefreshTokenRequest, client Client, err error) { if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { @@ -98,7 +99,7 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ return nil, nil, err } if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) { - return nil, nil, ErrInvalidRequest("invalid_grant") + return nil, nil, oidc.ErrUnauthorizedClient() } request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken) return request, client, err @@ -108,17 +109,17 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ return nil, nil, err } if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) { - return nil, nil, ErrInvalidRequest("invalid_grant") + return nil, nil, oidc.ErrUnauthorizedClient() } if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT { - return nil, nil, errors.New("invalid_grant") + return nil, nil, oidc.ErrInvalidClient() } if client.AuthMethod() == oidc.AuthMethodNone { request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken) return request, client, err } if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { - return nil, nil, errors.New("auth_method post not supported") + return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") } if err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()); err != nil { return nil, nil, err @@ -132,7 +133,7 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) { request, err := storage.TokenRequestByRefreshToken(ctx, refreshToken) if err != nil { - return nil, ErrInvalidRequest("invalid refreshToken") + return nil, oidc.ErrInvalidGrant().WithParent(err) } return request, nil } diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index fd26f19..6732bb1 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -5,14 +5,14 @@ import ( "net/http" "net/url" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) type Exchanger interface { Issuer() string Storage() Storage - Decoder() utils.Decoder + Decoder() httphelper.Decoder Signer() Signer Crypto() Crypto AuthMethodPostSupported() bool @@ -24,7 +24,8 @@ type Exchanger interface { func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - switch r.FormValue("grant_type") { + grantType := r.FormValue("grant_type") + switch grantType { case string(oidc.GrantTypeCode): CodeExchange(w, r, exchanger) return @@ -44,14 +45,14 @@ func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Reque return } case "": - RequestError(w, r, ErrInvalidRequest("grant_type missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing")) return } - RequestError(w, r, ErrInvalidRequest("grant_type not supported")) + RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType)) } } -//authenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest +//AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest //it is implemented by oidc.AuthRequest and oidc.RefreshTokenRequest type AuthenticatedTokenRequest interface { SetClientID(string) @@ -60,48 +61,49 @@ type AuthenticatedTokenRequest interface { //ParseAuthenticatedTokenRequest parses the client_id and client_secret from the HTTP request from either //HTTP Basic Auth header or form body and sets them into the provided authenticatedTokenRequest interface -func ParseAuthenticatedTokenRequest(r *http.Request, decoder utils.Decoder, request AuthenticatedTokenRequest) error { +func ParseAuthenticatedTokenRequest(r *http.Request, decoder httphelper.Decoder, request AuthenticatedTokenRequest) error { err := r.ParseForm() if err != nil { - return ErrInvalidRequest("error parsing form") + return oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) } err = decoder.Decode(request, r.Form) if err != nil { - return ErrInvalidRequest("error decoding form") + return oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) } clientID, clientSecret, ok := r.BasicAuth() - if ok { - clientID, err = url.QueryUnescape(clientID) - if err != nil { - return ErrInvalidRequest("invalid basic auth header") - } - clientSecret, err = url.QueryUnescape(clientSecret) - if err != nil { - return ErrInvalidRequest("invalid basic auth header") - } - request.SetClientID(clientID) - request.SetClientSecret(clientSecret) + if !ok { + return nil } + clientID, err = url.QueryUnescape(clientID) + if err != nil { + return oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + clientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + request.SetClientID(clientID) + request.SetClientSecret(clientSecret) return nil } -//AuthorizeRefreshClientByClientIDSecret authorizes a client by validating the client_id and client_secret (Basic Auth and POST) +//AuthorizeClientIDSecret authorizes a client by validating the client_id and client_secret (Basic Auth and POST) func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage Storage) error { err := storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret) if err != nil { - return err //TODO: wrap? + return oidc.ErrInvalidClient().WithDescription("invalid client_id / client_secret").WithParent(err) } return nil } -//AuthorizeCodeClientByCodeChallenge authorizes a client by validating the code_verifier against the previously sent +//AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent //code_challenge of the auth request (PKCE) func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error { if tokenReq.CodeVerifier == "" { - return ErrInvalidRequest("code_challenge required") + return oidc.ErrInvalidRequest().WithDescription("code_challenge required") } if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) { - return ErrInvalidRequest("code_challenge invalid") + return oidc.ErrInvalidGrant().WithDescription("invalid code challenge") } return nil } @@ -118,7 +120,7 @@ func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchang return nil, err } if client.AuthMethod() != oidc.AuthMethodPrivateKeyJWT { - return nil, ErrInvalidRequest("invalid_client") + return nil, oidc.ErrInvalidClient() } return client, nil } diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go new file mode 100644 index 0000000..fbaf8b7 --- /dev/null +++ b/pkg/op/token_revocation.go @@ -0,0 +1,136 @@ +package op + +import ( + "context" + "net/http" + "net/url" + "strings" + + httphelper "github.com/caos/oidc/pkg/http" + "github.com/caos/oidc/pkg/oidc" +) + +type Revoker interface { + Decoder() httphelper.Decoder + Crypto() Crypto + Storage() Storage + AccessTokenVerifier() AccessTokenVerifier + AuthMethodPrivateKeyJWTSupported() bool + AuthMethodPostSupported() bool +} + +type RevokerJWTProfile interface { + Revoker + JWTProfileVerifier() JWTProfileVerifier +} + +func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + Revoke(w, r, revoker) + } +} + +func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) { + token, _, clientID, err := ParseTokenRevocationRequest(r, revoker) + if err != nil { + RevocationRequestError(w, r, err) + return + } + tokenID, subject, ok := getTokenIDAndSubjectForRevocation(r.Context(), revoker, token) + if ok { + token = tokenID + } + if err := revoker.Storage().RevokeToken(r.Context(), token, subject, clientID); err != nil { + RevocationRequestError(w, r, err) + return + } + httphelper.MarshalJSON(w, nil) +} + +func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, tokenTypeHint, clientID string, err error) { + err = r.ParseForm() + if err != nil { + return "", "", "", oidc.ErrInvalidRequest().WithDescription("unable to parse request").WithParent(err) + } + req := new(struct { + oidc.RevocationRequest + oidc.ClientAssertionParams //for auth_method private_key_jwt + ClientID string `schema:"client_id"` //for auth_method none and post + ClientSecret string `schema:"client_secret"` //for auth_method post + }) + err = revoker.Decoder().Decode(req, r.Form) + if err != nil { + return "", "", "", oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + if req.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { + revokerJWTProfile, ok := revoker.(RevokerJWTProfile) + if !ok || !revoker.AuthMethodPrivateKeyJWTSupported() { + return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported") + } + profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier()) + if err == nil { + return req.Token, req.TokenTypeHint, profile.Issuer, nil + } + return "", "", "", err + } + clientID, clientSecret, ok := r.BasicAuth() + if ok { + clientID, err = url.QueryUnescape(clientID) + if err != nil { + return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + clientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + if err = AuthorizeClientIDSecret(r.Context(), clientID, clientSecret, revoker.Storage()); err != nil { + return "", "", "", err + } + return req.Token, req.TokenTypeHint, clientID, nil + } + if req.ClientID == "" { + return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid authorization") + } + client, err := revoker.Storage().GetClientByClientID(r.Context(), req.ClientID) + if err != nil { + return "", "", "", oidc.ErrInvalidClient().WithParent(err) + } + if req.ClientSecret == "" { + if client.AuthMethod() != oidc.AuthMethodNone { + return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid authorization") + } + return req.Token, req.TokenTypeHint, req.ClientID, nil + } + if client.AuthMethod() == oidc.AuthMethodPost && !revoker.AuthMethodPostSupported() { + return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method post not supported") + } + if err = AuthorizeClientIDSecret(r.Context(), req.ClientID, req.ClientSecret, revoker.Storage()); err != nil { + return "", "", "", err + } + return req.Token, req.TokenTypeHint, req.ClientID, nil +} + +func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { + e := oidc.DefaultToServerError(err, err.Error()) + status := http.StatusBadRequest + if e.ErrorType == oidc.InvalidClient { + status = 401 + } + httphelper.MarshalJSONWithStatus(w, e, status) +} + +func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { + tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) + if err == nil { + splitToken := strings.Split(tokenIDSubject, ":") + if len(splitToken) != 2 { + return "", "", false + } + return splitToken[0], splitToken[1], true + } + accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier()) + if err != nil { + return "", "", false + } + return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true +} diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 9abf378..f07a8bc 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -6,12 +6,12 @@ import ( "net/http" "strings" + httphelper "github.com/caos/oidc/pkg/http" "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/utils" ) type UserinfoProvider interface { - Decoder() utils.Decoder + Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage AccessTokenVerifier() AccessTokenVerifier @@ -37,14 +37,13 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP info := oidc.NewUserInfo() err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin")) if err != nil { - w.WriteHeader(http.StatusForbidden) - utils.MarshalJSON(w, err) + httphelper.MarshalJSONWithStatus(w, err, http.StatusForbidden) return } - utils.MarshalJSON(w, info) + httphelper.MarshalJSON(w, info) } -func ParseUserinfoRequest(r *http.Request, decoder utils.Decoder) (string, error) { +func ParseUserinfoRequest(r *http.Request, decoder httphelper.Decoder) (string, error) { accessToken, err := getAccessToken(r) if err == nil { return accessToken, nil diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go index 05168a6..2220244 100644 --- a/pkg/op/verifier_access_token.go +++ b/pkg/op/verifier_access_token.go @@ -49,7 +49,7 @@ func (i *accessTokenVerifier) KeySet() oidc.KeySet { } func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier { - verifier := &idTokenHintVerifier{ + verifier := &accessTokenVerifier{ issuer: issuer, keySet: keySet, } diff --git a/pkg/utils/strings.go b/pkg/strings/strings.go similarity index 89% rename from pkg/utils/strings.go rename to pkg/strings/strings.go index 5ffcd37..af48cf3 100644 --- a/pkg/utils/strings.go +++ b/pkg/strings/strings.go @@ -1,4 +1,4 @@ -package utils +package strings func Contains(list []string, needle string) bool { for _, item := range list { diff --git a/pkg/utils/strings_test.go b/pkg/strings/strings_test.go similarity index 98% rename from pkg/utils/strings_test.go rename to pkg/strings/strings_test.go index 86af2af..78698d4 100644 --- a/pkg/utils/strings_test.go +++ b/pkg/strings/strings_test.go @@ -1,4 +1,4 @@ -package utils +package strings import "testing"