feat: Token Revocation, Request Object and OP Certification (#130)

FEATURES (and FIXES):
- support OAuth 2.0 Token Revocation [RFC 7009](https://datatracker.ietf.org/doc/html/rfc7009)
- handle request object using `request` parameter [OIDC Core 1.0 Request Object](https://openid.net/specs/openid-connect-core-1_0.html#RequestObject)
- handle response mode
- added some information to the discovery endpoint:
  - revocation_endpoint (added with token revocation) 
  - revocation_endpoint_auth_methods_supported (added with token revocation)
  - revocation_endpoint_auth_signing_alg_values_supported (added with token revocation)
  - token_endpoint_auth_signing_alg_values_supported (was missing)
  - introspection_endpoint_auth_signing_alg_values_supported (was missing)
  - request_object_signing_alg_values_supported (added with request object)
  - request_parameter_supported (added with request object)
 - fixed `removeUserinfoScopes ` now returns the scopes without "userinfo" scopes (profile, email, phone, addedd) [source diff](https://github.com/caos/oidc/pull/130/files#diff-fad50c8c0f065d4dbc49d6c6a38f09c992c8f5d651a479ba00e31b500543559eL170-R171)
- improved error handling (pkg/oidc/error.go) and fixed some wrong OAuth errors (e.g. `invalid_grant` instead of `invalid_request`)
- improved MarshalJSON and added MarshalJSONWithStatus
- removed deprecated PEM decryption from `BytesToPrivateKey`  [source diff](https://github.com/caos/oidc/pull/130/files#diff-fe246e428e399ccff599627c71764de51387b60b4df84c67de3febd0954e859bL11-L19)
- NewAccessTokenVerifier now uses correct (internal) `accessTokenVerifier` [source diff](https://github.com/caos/oidc/pull/130/files#diff-3a01c7500ead8f35448456ef231c7c22f8d291710936cac91de5edeef52ffc72L52-R52)

BREAKING CHANGE:
- move functions from `utils` package into separate packages
- added various methods to the (OP) `Configuration` interface [source diff](https://github.com/caos/oidc/pull/130/files#diff-2538e0dfc772fdc37f057aecd6fcc2943f516c24e8be794cce0e368a26d20a82R19-R32)
- added revocationEndpoint to `WithCustomEndpoints ` [source diff](https://github.com/caos/oidc/pull/130/files#diff-19ae13a743eb7cebbb96492798b1bec556673eb6236b1387e38d722900bae1c3L355-R391)
- remove unnecessary context parameter from JWTProfileExchange [source diff](https://github.com/caos/oidc/pull/130/files#diff-4ed8f6affa4a9631fa8a034b3d5752fbb6a819107141aae00029014e950f7b4cL14)
This commit is contained in:
Livio Amstutz 2021-11-02 13:21:35 +01:00 committed by GitHub
parent 763d3334e7
commit eb10752e48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
63 changed files with 1738 additions and 624 deletions

View file

@ -12,13 +12,13 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/caos/oidc/pkg/client/rp" "github.com/caos/oidc/pkg/client/rp"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
var ( var (
callbackPath string = "/auth/callback" callbackPath = "/auth/callback"
key []byte = []byte("test1234test1234") key = []byte("test1234test1234")
) )
func main() { func main() {
@ -30,7 +30,7 @@ func main() {
scopes := strings.Split(os.Getenv("SCOPES"), " ") scopes := strings.Split(os.Getenv("SCOPES"), " ")
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
options := []rp.Option{ options := []rp.Option{
rp.WithCookieHandler(cookieHandler), rp.WithCookieHandler(cookieHandler),

View file

@ -12,12 +12,12 @@ import (
"github.com/caos/oidc/pkg/client/rp" "github.com/caos/oidc/pkg/client/rp"
"github.com/caos/oidc/pkg/client/rp/cli" "github.com/caos/oidc/pkg/client/rp/cli"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/http"
) )
var ( var (
callbackPath string = "/orbctl/github/callback" callbackPath = "/orbctl/github/callback"
key []byte = []byte("test1234test1234") key = []byte("test1234test1234")
) )
func main() { func main() {
@ -34,7 +34,7 @@ func main() {
} }
ctx := context.Background() ctx := context.Background()
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) cookieHandler := http.NewCookieHandler(key, key, http.WithUnsecure())
relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, rp.WithCookieHandler(cookieHandler)) relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, rp.WithCookieHandler(cookieHandler))
if err != nil { if err != nil {
fmt.Printf("error creating relaying party: %v", err) fmt.Printf("error creating relaying party: %v", err)

View file

@ -17,7 +17,7 @@ import (
) )
var ( var (
client *http.Client = http.DefaultClient client = http.DefaultClient
) )
func main() { func main() {

View file

@ -32,6 +32,7 @@ func NewAuthStorage() op.Storage {
type AuthRequest struct { type AuthRequest struct {
ID string ID string
ResponseType oidc.ResponseType ResponseType oidc.ResponseType
ResponseMode oidc.ResponseMode
RedirectURI string RedirectURI string
Nonce string Nonce string
ClientID string ClientID string
@ -88,6 +89,10 @@ func (a *AuthRequest) GetResponseType() oidc.ResponseType {
return a.ResponseType return a.ResponseType
} }
func (a *AuthRequest) GetResponseMode() oidc.ResponseMode {
return a.ResponseMode
}
func (a *AuthRequest) GetScopes() []string { func (a *AuthRequest) GetScopes() []string {
return []string{ return []string{
"openid", "openid",
@ -170,6 +175,11 @@ func (s *AuthStorage) TokenRequestByRefreshToken(ctx context.Context, refreshTok
func (s *AuthStorage) TerminateSession(_ context.Context, userID, clientID string) error { func (s *AuthStorage) TerminateSession(_ context.Context, userID, clientID string) error {
return nil return nil
} }
func (s *AuthStorage) RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error {
return nil
}
func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey) { func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey) {
keyCh <- jose.SigningKey{Algorithm: jose.RS256, Key: s.key} keyCh <- jose.SigningKey{Algorithm: jose.RS256, Key: s.key}
} }
@ -289,7 +299,7 @@ func (c *ConfClient) AuthMethod() oidc.AuthMethod {
} }
func (c *ConfClient) IDTokenLifetime() time.Duration { func (c *ConfClient) IDTokenLifetime() time.Duration {
return time.Duration(5 * time.Minute) return 5 * time.Minute
} }
func (c *ConfClient) AccessTokenType() op.AccessTokenType { func (c *ConfClient) AccessTokenType() op.AccessTokenType {
return c.accessTokenType return c.accessTokenType

View file

@ -10,12 +10,13 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/crypto"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
var ( var (
Encoder = func() utils.Encoder { Encoder = func() httphelper.Encoder {
e := schema.NewEncoder() e := schema.NewEncoder()
e.RegisterEncoder(oidc.SpaceDelimitedArray{}, func(value reflect.Value) string { e.RegisterEncoder(oidc.SpaceDelimitedArray{}, func(value reflect.Value) string {
return value.Interface().(oidc.SpaceDelimitedArray).Encode() return value.Interface().(oidc.SpaceDelimitedArray).Encode()
@ -32,7 +33,7 @@ func Discover(issuer string, httpClient *http.Client) (*oidc.DiscoveryConfigurat
return nil, err return nil, err
} }
discoveryConfig := new(oidc.DiscoveryConfiguration) discoveryConfig := new(oidc.DiscoveryConfiguration)
err = utils.HttpRequest(httpClient, req, &discoveryConfig) err = httphelper.HttpRequest(httpClient, req, &discoveryConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -52,12 +53,12 @@ func CallTokenEndpoint(request interface{}, caller tokenEndpointCaller) (newToke
} }
func callTokenEndpoint(request interface{}, authFn interface{}, caller tokenEndpointCaller) (newToken *oauth2.Token, err error) { func callTokenEndpoint(request interface{}, authFn interface{}, caller tokenEndpointCaller) (newToken *oauth2.Token, err error) {
req, err := utils.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenRes := new(oidc.AccessTokenResponse) tokenRes := new(oidc.AccessTokenResponse)
if err := utils.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil { if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
return nil, err return nil, err
} }
return &oauth2.Token{ return &oauth2.Token{
@ -69,7 +70,7 @@ func callTokenEndpoint(request interface{}, authFn interface{}, caller tokenEndp
} }
func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) { func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) {
privateKey, err := utils.BytesToPrivateKey(key) privateKey, err := crypto.BytesToPrivateKey(key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,7 +84,7 @@ func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error)
func SignedJWTProfileAssertion(clientID string, audience []string, expiration time.Duration, signer jose.Signer) (string, error) { func SignedJWTProfileAssertion(clientID string, audience []string, expiration time.Duration, signer jose.Signer) (string, error) {
iat := time.Now() iat := time.Now()
exp := iat.Add(expiration) exp := iat.Add(expiration)
return utils.Sign(&oidc.JWTTokenRequest{ return crypto.Sign(&oidc.JWTTokenRequest{
Issuer: clientID, Issuer: clientID,
Subject: clientID, Subject: clientID,
Audience: audience, Audience: audience,

View file

@ -1,17 +1,16 @@
package client package client
import ( import (
"context"
"net/url" "net/url"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
//JWTProfileExchange handles the oauth2 jwt profile exchange //JWTProfileExchange handles the oauth2 jwt profile exchange
func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller tokenEndpointCaller) (*oauth2.Token, error) { func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller tokenEndpointCaller) (*oauth2.Token, error) {
return CallTokenEndpoint(jwtProfileGrantRequest, caller) return CallTokenEndpoint(jwtProfileGrantRequest, caller)
} }
@ -22,7 +21,7 @@ func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption {
} }
} }
func ClientAssertionFormAuthorization(assertion string) utils.FormAuthorization { func ClientAssertionFormAuthorization(assertion string) http.FormAuthorization {
return func(values url.Values) { return func(values url.Values) {
values.Set("client_assertion", assertion) values.Set("client_assertion", assertion)
values.Set("client_assertion_type", oidc.ClientAssertionTypeJWTAssertion) values.Set("client_assertion_type", oidc.ClientAssertionTypeJWTAssertion)

View file

@ -89,5 +89,5 @@ func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return client.JWTProfileExchange(nil, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j)
} }

View file

@ -1,4 +1,4 @@
package utils package cli
import ( import (
"fmt" "fmt"

View file

@ -5,8 +5,8 @@ import (
"net/http" "net/http"
"github.com/caos/oidc/pkg/client/rp" "github.com/caos/oidc/pkg/client/rp"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
const ( const (
@ -28,9 +28,9 @@ func CodeFlow(ctx context.Context, relyingParty rp.RelyingParty, callbackPath, p
http.Handle(loginPath, rp.AuthURLHandler(stateProvider, relyingParty)) http.Handle(loginPath, rp.AuthURLHandler(stateProvider, relyingParty))
http.Handle(callbackPath, rp.CodeExchangeHandler(callback, relyingParty)) http.Handle(callbackPath, rp.CodeExchangeHandler(callback, relyingParty))
utils.StartServer(codeflowCtx, ":"+port) httphelper.StartServer(codeflowCtx, ":"+port)
utils.OpenBrowser("http://localhost:" + port + loginPath) OpenBrowser("http://localhost:" + port + loginPath)
return <-tokenChan return <-tokenChan
} }

View file

@ -5,8 +5,8 @@ import (
) )
//DelegationTokenRequest is an implementation of TokenExchangeRequest //DelegationTokenRequest is an implementation of TokenExchangeRequest
//it exchanges a "urn:ietf:params:oauth:token-type:access_token" with an optional //it exchanges an "urn:ietf:params:oauth:token-type:access_token" with an optional
//"urn:ietf:params:oauth:token-type:access_token" actor token for a //"urn:ietf:params:oauth:token-type:access_token" actor token for an
//"urn:ietf:params:oauth:token-type:access_token" delegation token //"urn:ietf:params:oauth:token-type:access_token" delegation token
func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest { func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...) return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)

View file

@ -7,9 +7,9 @@ import (
"net/http" "net/http"
"sync" "sync"
"github.com/caos/oidc/pkg/utils"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
) )
@ -207,7 +207,7 @@ func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey,
} }
keySet := new(jsonWebKeySet) keySet := new(jsonWebKeySet)
if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil { if err = httphelper.HttpRequest(r.httpClient, req, keySet); err != nil {
return nil, fmt.Errorf("oidc: failed to get keys: %v", err) return nil, fmt.Errorf("oidc: failed to get keys: %v", err)
} }
return keySet.Keys, nil return keySet.Keys, nil

View file

@ -13,8 +13,8 @@ import (
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/client" "github.com/caos/oidc/pkg/client"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
const ( const (
@ -39,7 +39,7 @@ type RelyingParty interface {
IsPKCE() bool IsPKCE() bool
//CookieHandler returns a http cookie handler used for various state transfer cookies //CookieHandler returns a http cookie handler used for various state transfer cookies
CookieHandler() *utils.CookieHandler CookieHandler() *httphelper.CookieHandler
//HttpClient returns a http client used for calls to the openid provider, e.g. calling token endpoint //HttpClient returns a http client used for calls to the openid provider, e.g. calling token endpoint
HttpClient() *http.Client HttpClient() *http.Client
@ -76,7 +76,7 @@ type relyingParty struct {
pkce bool pkce bool
httpClient *http.Client httpClient *http.Client
cookieHandler *utils.CookieHandler cookieHandler *httphelper.CookieHandler
errorHandler func(http.ResponseWriter, *http.Request, string, string, string) errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
idTokenVerifier IDTokenVerifier idTokenVerifier IDTokenVerifier
@ -96,7 +96,7 @@ func (rp *relyingParty) IsPKCE() bool {
return rp.pkce return rp.pkce
} }
func (rp *relyingParty) CookieHandler() *utils.CookieHandler { func (rp *relyingParty) CookieHandler() *httphelper.CookieHandler {
return rp.cookieHandler return rp.cookieHandler
} }
@ -136,7 +136,7 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingParty, error) { func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingParty, error) {
rp := &relyingParty{ rp := &relyingParty{
oauthConfig: config, oauthConfig: config,
httpClient: utils.DefaultHTTPClient, httpClient: httphelper.DefaultHTTPClient,
oauth2Only: true, oauth2Only: true,
} }
@ -161,7 +161,7 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco
RedirectURL: redirectURI, RedirectURL: redirectURI,
Scopes: scopes, Scopes: scopes,
}, },
httpClient: utils.DefaultHTTPClient, httpClient: httphelper.DefaultHTTPClient,
oauth2Only: false, oauth2Only: false,
} }
@ -181,11 +181,11 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco
return rp, nil return rp, nil
} }
//Option is the type for providing dynamic options to the DefaultRP //Option is the type for providing dynamic options to the relyingParty
type Option func(*relyingParty) error type Option func(*relyingParty) error
//WithCookieHandler set a `CookieHandler` for securing the various redirects //WithCookieHandler set a `CookieHandler` for securing the various redirects
func WithCookieHandler(cookieHandler *utils.CookieHandler) Option { func WithCookieHandler(cookieHandler *httphelper.CookieHandler) Option {
return func(rp *relyingParty) error { return func(rp *relyingParty) error {
rp.cookieHandler = cookieHandler rp.cookieHandler = cookieHandler
return nil return nil
@ -195,7 +195,7 @@ func WithCookieHandler(cookieHandler *utils.CookieHandler) Option {
//WithPKCE sets the RP to use PKCE (oauth2 code challenge) //WithPKCE sets the RP to use PKCE (oauth2 code challenge)
//it also sets a `CookieHandler` for securing the various redirects //it also sets a `CookieHandler` for securing the various redirects
//and exchanging the code challenge //and exchanging the code challenge
func WithPKCE(cookieHandler *utils.CookieHandler) Option { func WithPKCE(cookieHandler *httphelper.CookieHandler) Option {
return func(rp *relyingParty) error { return func(rp *relyingParty) error {
rp.pkce = true rp.pkce = true
rp.cookieHandler = cookieHandler rp.cookieHandler = cookieHandler
@ -246,7 +246,7 @@ func Discover(issuer string, httpClient *http.Client) (Endpoints, error) {
return Endpoints{}, err return Endpoints{}, err
} }
discoveryConfig := new(oidc.DiscoveryConfiguration) discoveryConfig := new(oidc.DiscoveryConfiguration)
err = utils.HttpRequest(httpClient, req, &discoveryConfig) err = httphelper.HttpRequest(httpClient, req, &discoveryConfig)
if err != nil { if err != nil {
return Endpoints{}, err return Endpoints{}, err
} }
@ -395,7 +395,7 @@ func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo,
} }
req.Header.Set("authorization", tokenType+" "+token) req.Header.Set("authorization", tokenType+" "+token)
userinfo := oidc.NewUserInfo() userinfo := oidc.NewUserInfo()
if err := utils.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil { if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
return nil, err return nil, err
} }
if userinfo.GetSubject() != subject { if userinfo.GetSubject() != subject {

View file

@ -7,8 +7,8 @@ import (
"time" "time"
"github.com/caos/oidc/pkg/client" "github.com/caos/oidc/pkg/client"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
type ResourceServer interface { type ResourceServer interface {
@ -39,7 +39,7 @@ func (r *resourceServer) AuthFn() (interface{}, error) {
func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) {
authorizer := func() (interface{}, error) { authorizer := func() (interface{}, error) {
return utils.AuthorizeBasic(clientID, clientSecret), nil return httphelper.AuthorizeBasic(clientID, clientSecret), nil
} }
return newResourceServer(issuer, authorizer, option...) return newResourceServer(issuer, authorizer, option...)
} }
@ -61,7 +61,7 @@ func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, opt
func newResourceServer(issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) { func newResourceServer(issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) {
rs := &resourceServer{ rs := &resourceServer{
issuer: issuer, issuer: issuer,
httpClient: utils.DefaultHTTPClient, httpClient: httphelper.DefaultHTTPClient,
} }
for _, optFunc := range options { for _, optFunc := range options {
optFunc(rs) optFunc(rs)
@ -111,12 +111,12 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.Intr
if err != nil { if err != nil {
return nil, err return nil, err
} }
req, err := utils.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp := oidc.NewIntrospectionResponse() resp := oidc.NewIntrospectionResponse()
if err := utils.HttpRequest(rp.HttpClient(), req, resp); err != nil { if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil {
return nil, err return nil, err
} }
return resp, nil return resp, nil

View file

@ -1,4 +1,4 @@
package utils package crypto
import ( import (
"crypto/aes" "crypto/aes"
@ -9,6 +9,10 @@ import (
"io" "io"
) )
var (
ErrCipherTextBlockSize = errors.New("ciphertext block size is too short")
)
func EncryptAES(data string, key string) (string, error) { func EncryptAES(data string, key string) (string, error) {
encrypted, err := EncryptBytesAES([]byte(data), key) encrypted, err := EncryptBytesAES([]byte(data), key)
if err != nil { if err != nil {
@ -55,8 +59,7 @@ func DecryptBytesAES(cipherText []byte, key string) ([]byte, error) {
} }
if len(cipherText) < aes.BlockSize { if len(cipherText) < aes.BlockSize {
err = errors.New("Ciphertext block size is too short!") return nil, ErrCipherTextBlockSize
return nil, err
} }
iv := cipherText[:aes.BlockSize] iv := cipherText[:aes.BlockSize]
cipherText = cipherText[aes.BlockSize:] cipherText = cipherText[aes.BlockSize:]

View file

@ -1,15 +1,20 @@
package utils package crypto
import ( import (
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"hash" "hash"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
) )
var (
ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm")
)
func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
switch sigAlgorithm { switch sigAlgorithm {
case jose.RS256, jose.ES256, jose.PS256: case jose.RS256, jose.ES256, jose.PS256:
@ -19,7 +24,7 @@ func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
case jose.RS512, jose.ES512, jose.PS512: case jose.RS512, jose.ES512, jose.PS512:
return sha512.New(), nil return sha512.New(), nil
default: default:
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm) return nil, fmt.Errorf("%w: %q", ErrUnsupportedAlgorithm, sigAlgorithm)
} }
} }

View file

@ -1,4 +1,4 @@
package utils package crypto
import ( import (
"crypto/rsa" "crypto/rsa"
@ -8,15 +8,7 @@ import (
func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) { func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(priv) block, _ := pem.Decode(priv)
enc := x509.IsEncryptedPEMBlock(block)
b := block.Bytes b := block.Bytes
var err error
if enc {
b, err = x509.DecryptPEMBlock(block, nil)
if err != nil {
return nil, err
}
}
key, err := x509.ParsePKCS1PrivateKey(b) key, err := x509.ParsePKCS1PrivateKey(b)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -1,4 +1,4 @@
package utils package crypto
import ( import (
"encoding/json" "encoding/json"

View file

@ -1,4 +1,4 @@
package utils package http
import ( import (
"errors" "errors"

View file

@ -1,4 +1,4 @@
package utils package http
import ( import (
"context" "context"
@ -14,7 +14,7 @@ import (
var ( var (
DefaultHTTPClient = &http.Client{ DefaultHTTPClient = &http.Client{
Timeout: time.Duration(30 * time.Second), Timeout: 30 * time.Second,
} }
) )

View file

@ -1,24 +1,26 @@
package utils package http
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"reflect"
"github.com/sirupsen/logrus"
) )
func MarshalJSON(w http.ResponseWriter, i interface{}) { func MarshalJSON(w http.ResponseWriter, i interface{}) {
b, err := json.Marshal(i) MarshalJSONWithStatus(w, i, http.StatusOK)
if err != nil { }
http.Error(w, err.Error(), http.StatusInternalServerError)
func MarshalJSONWithStatus(w http.ResponseWriter, i interface{}, status int) {
w.Header().Set("content-type", "application/json")
w.WriteHeader(status)
if i == nil || (reflect.ValueOf(i).Kind() == reflect.Ptr && reflect.ValueOf(i).IsNil()) {
return return
} }
w.Header().Set("content-type", "application/json") err := json.NewEncoder(w).Encode(i)
_, err = w.Write(b)
if err != nil { if err != nil {
logrus.Error("error writing response") http.Error(w, err.Error(), http.StatusInternalServerError)
} }
} }

View file

@ -1,8 +1,11 @@
package utils package http
import ( import (
"bytes" "bytes"
"net/http/httptest"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestConcatenateJSON(t *testing.T) { func TestConcatenateJSON(t *testing.T) {
@ -88,3 +91,66 @@ func TestConcatenateJSON(t *testing.T) {
}) })
} }
} }
func TestMarshalJSONWithStatus(t *testing.T) {
type args struct {
i interface{}
status int
}
type res struct {
statusCode int
body string
}
tests := []struct {
name string
args args
res res
}{
{
"empty ok",
args{
nil,
200,
},
res{
200,
"",
},
},
{
"string ok",
args{
"ok",
200,
},
res{
200,
`"ok"
`,
},
},
{
"struct ok",
args{
struct {
Test string `json:"test"`
}{"ok"},
200,
},
res{
200,
`{"test":"ok"}
`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
MarshalJSONWithStatus(w, tt.args.i, tt.args.status)
assert.Equal(t, tt.res.statusCode, w.Result().StatusCode)
assert.Equal(t, "application/json", w.Header().Get("content-type"))
assert.Equal(t, tt.res.body, w.Body.String())
})
}
}

View file

@ -42,6 +42,9 @@ const (
DisplayTouch Display = "touch" DisplayTouch Display = "touch"
DisplayWAP Display = "wap" DisplayWAP Display = "wap"
ResponseModeQuery ResponseMode = "query"
ResponseModeFragment ResponseMode = "fragment"
//PromptNone (`none`) disallows the Authorization Server to display any authentication or consent user interface pages. //PromptNone (`none`) disallows the Authorization Server to display any authentication or consent user interface pages.
//An error (login_required, interaction_required, ...) will be returned if the user is not already authenticated or consent is needed //An error (login_required, interaction_required, ...) will be returned if the user is not already authenticated or consent is needed
PromptNone = "none" PromptNone = "none"
@ -59,27 +62,28 @@ const (
//AuthRequest according to: //AuthRequest according to:
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest //https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type AuthRequest struct { type AuthRequest struct {
ID string Scopes SpaceDelimitedArray `json:"scope" schema:"scope"`
Scopes SpaceDelimitedArray `schema:"scope"` ResponseType ResponseType `json:"response_type" schema:"response_type"`
ResponseType ResponseType `schema:"response_type"` ClientID string `json:"client_id" schema:"client_id"`
ClientID string `schema:"client_id"` RedirectURI string `json:"redirect_uri" schema:"redirect_uri"`
RedirectURI string `schema:"redirect_uri"` //TODO: type
State string `schema:"state"` State string `json:"state" schema:"state"`
Nonce string `json:"nonce" schema:"nonce"`
// ResponseMode TODO: ? ResponseMode ResponseMode `json:"response_mode" schema:"response_mode"`
Display Display `json:"display" schema:"display"`
Prompt SpaceDelimitedArray `json:"prompt" schema:"prompt"`
MaxAge *uint `json:"max_age" schema:"max_age"`
UILocales Locales `json:"ui_locales" schema:"ui_locales"`
IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"`
LoginHint string `json:"login_hint" schema:"login_hint"`
ACRValues []string `json:"acr_values" schema:"acr_values"`
Nonce string `schema:"nonce"` CodeChallenge string `json:"code_challenge" schema:"code_challenge"`
Display Display `schema:"display"` CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"`
Prompt SpaceDelimitedArray `schema:"prompt"`
MaxAge *uint `schema:"max_age"`
UILocales Locales `schema:"ui_locales"`
IDTokenHint string `schema:"id_token_hint"`
LoginHint string `schema:"login_hint"`
ACRValues []string `schema:"acr_values"`
CodeChallenge string `schema:"code_challenge"` //RequestParam enables OIDC requests to be passed in a single, self-contained parameter (as JWT, called Request Object)
CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"` RequestParam string `schema:"request"`
} }
//GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface //GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface

View file

@ -3,7 +3,7 @@ package oidc
import ( import (
"crypto/sha256" "crypto/sha256"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/crypto"
) )
const ( const (
@ -19,7 +19,7 @@ type CodeChallenge struct {
} }
func NewSHACodeChallenge(code string) string { func NewSHACodeChallenge(code string) string {
return utils.HashString(sha256.New(), code, false) return crypto.HashString(sha256.New(), code, false)
} }
func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool { func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool {

View file

@ -9,48 +9,143 @@ const (
) )
type DiscoveryConfiguration struct { type DiscoveryConfiguration struct {
//Issuer is the identifier of the OP and is used in the tokens as `iss` claim.
Issuer string `json:"issuer,omitempty"` Issuer string `json:"issuer,omitempty"`
//AuthorizationEndpoint is the URL of the OAuth 2.0 Authorization Endpoint where all user interactive login start
AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"` AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"`
//TokenEndpoint is the URL of the OAuth 2.0 Token Endpoint where all tokens are issued, except when using Implicit Flow
TokenEndpoint string `json:"token_endpoint,omitempty"` TokenEndpoint string `json:"token_endpoint,omitempty"`
//IntrospectionEndpoint is the URL of the OAuth 2.0 Introspection Endpoint.
IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"`
//UserinfoEndpoint is the URL where an access_token can be used to retrieve the Userinfo.
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
//RevocationEndpoint is the URL of the OAuth 2.0 Revocation Endpoint.
RevocationEndpoint string `json:"revocation_endpoint,omitempty"` RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
//EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP.
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
//CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client.
CheckSessionIframe string `json:"check_session_iframe,omitempty"` CheckSessionIframe string `json:"check_session_iframe,omitempty"`
//JwksURI is the URL of the JSON Web Key Set. This site contains the signing keys that RPs can use to validate the signature.
//It may also contain the OP's encryption keys that RPs can use to encrypt request to the OP.
JwksURI string `json:"jwks_uri,omitempty"` JwksURI string `json:"jwks_uri,omitempty"`
//RegistrationEndpoint is the URL for the Dynamic Client Registration.
RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
//ScopesSupported lists an array of supported scopes. This list must not include every supported scope by the OP.
ScopesSupported []string `json:"scopes_supported,omitempty"` ScopesSupported []string `json:"scopes_supported,omitempty"`
//ResponseTypesSupported contains a list of the OAuth 2.0 response_type values that the OP supports (code, id_token, token id_token, ...).
ResponseTypesSupported []string `json:"response_types_supported,omitempty"` ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
//ResponseModesSupported contains a list of the OAuth 2.0 response_mode values that the OP supports. If omitted, the default value is ["query", "fragment"].
ResponseModesSupported []string `json:"response_modes_supported,omitempty"` ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
//GrantTypesSupported contains a list of the OAuth 2.0 grant_type values that the OP supports. If omitted, the default value is ["authorization_code", "implicit"].
GrantTypesSupported []GrantType `json:"grant_types_supported,omitempty"` GrantTypesSupported []GrantType `json:"grant_types_supported,omitempty"`
//ACRValuesSupported contains a list of Authentication Context Class References that the OP supports.
ACRValuesSupported []string `json:"acr_values_supported,omitempty"` ACRValuesSupported []string `json:"acr_values_supported,omitempty"`
//SubjectTypesSupported contains a list of Subject Identifier types that the OP supports (pairwise, public).
SubjectTypesSupported []string `json:"subject_types_supported,omitempty"` SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
//IDTokenSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the OP for the ID Token.
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"`
//IDTokenEncryptionAlgValuesSupported contains a list of JWE encryption algorithms (alg values) supported by the OP for the ID Token.
IDTokenEncryptionAlgValuesSupported []string `json:"id_token_encryption_alg_values_supported,omitempty"` IDTokenEncryptionAlgValuesSupported []string `json:"id_token_encryption_alg_values_supported,omitempty"`
//IDTokenEncryptionEncValuesSupported contains a list of JWE encryption algorithms (enc values) supported by the OP for the ID Token.
IDTokenEncryptionEncValuesSupported []string `json:"id_token_encryption_enc_values_supported,omitempty"` IDTokenEncryptionEncValuesSupported []string `json:"id_token_encryption_enc_values_supported,omitempty"`
//UserinfoSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the OP for UserInfo Endpoint.
UserinfoSigningAlgValuesSupported []string `json:"userinfo_signing_alg_values_supported,omitempty"` UserinfoSigningAlgValuesSupported []string `json:"userinfo_signing_alg_values_supported,omitempty"`
//UserinfoEncryptionAlgValuesSupported contains a list of JWE encryption algorithms (alg values) supported by the OP for the UserInfo Endpoint.
UserinfoEncryptionAlgValuesSupported []string `json:"userinfo_encryption_alg_values_supported,omitempty"` UserinfoEncryptionAlgValuesSupported []string `json:"userinfo_encryption_alg_values_supported,omitempty"`
//UserinfoEncryptionEncValuesSupported contains a list of JWE encryption algorithms (enc values) supported by the OP for the UserInfo Endpoint.
UserinfoEncryptionEncValuesSupported []string `json:"userinfo_encryption_enc_values_supported,omitempty"` UserinfoEncryptionEncValuesSupported []string `json:"userinfo_encryption_enc_values_supported,omitempty"`
//RequestObjectSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the OP for Request Objects.
//These algorithms are used both then the Request Object is passed by value (using the request parameter) and when it is passed by reference (using the request_uri parameter).
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported,omitempty"` RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported,omitempty"`
//RequestObjectEncryptionAlgValuesSupported contains a list of JWE encryption algorithms (alg values) supported by the OP for Request Objects.
//These algorithms are used both when the Request Object is passed by value and by reference.
RequestObjectEncryptionAlgValuesSupported []string `json:"request_object_encryption_alg_values_supported,omitempty"` RequestObjectEncryptionAlgValuesSupported []string `json:"request_object_encryption_alg_values_supported,omitempty"`
//RequestObjectEncryptionEncValuesSupported contains a list of JWE encryption algorithms (enc values) supported by the OP for Request Objects.
//These algorithms are used both when the Request Object is passed by value and by reference.
RequestObjectEncryptionEncValuesSupported []string `json:"request_object_encryption_enc_values_supported,omitempty"` RequestObjectEncryptionEncValuesSupported []string `json:"request_object_encryption_enc_values_supported,omitempty"`
//TokenEndpointAuthMethodsSupported contains a list of Client Authentication methods supported by the Token Endpoint. If omitted, the default is client_secret_basic.
TokenEndpointAuthMethodsSupported []AuthMethod `json:"token_endpoint_auth_methods_supported,omitempty"` TokenEndpointAuthMethodsSupported []AuthMethod `json:"token_endpoint_auth_methods_supported,omitempty"`
//TokenEndpointAuthSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the Token Endpoint
//for the signature of the JWT used to authenticate the Client by private_key_jwt and client_secret_jwt.
TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"`
//RevocationEndpointAuthMethodsSupported contains a list of Client Authentication methods supported by the Revocation Endpoint. If omitted, the default is client_secret_basic.
RevocationEndpointAuthMethodsSupported []AuthMethod `json:"revocation_endpoint_auth_methods_supported,omitempty"` RevocationEndpointAuthMethodsSupported []AuthMethod `json:"revocation_endpoint_auth_methods_supported,omitempty"`
//RevocationEndpointAuthSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the Revocation Endpoint
//for the signature of the JWT used to authenticate the Client by private_key_jwt and client_secret_jwt.
RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"` RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"`
//IntrospectionEndpointAuthMethodsSupported contains a list of Client Authentication methods supported by the Introspection Endpoint.
IntrospectionEndpointAuthMethodsSupported []AuthMethod `json:"introspection_endpoint_auth_methods_supported,omitempty"` IntrospectionEndpointAuthMethodsSupported []AuthMethod `json:"introspection_endpoint_auth_methods_supported,omitempty"`
//IntrospectionEndpointAuthSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the Revocation Endpoint
//for the signature of the JWT used to authenticate the Client by private_key_jwt and client_secret_jwt.
IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"` IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"`
//DisplayValuesSupported contains a list of display parameter values that the OP supports (page, popup, touch, wap).
DisplayValuesSupported []Display `json:"display_values_supported,omitempty"` DisplayValuesSupported []Display `json:"display_values_supported,omitempty"`
//ClaimTypesSupported contains a list of Claim Types that the OP supports (normal, aggregated, distributed). If omitted, the default is normal Claims.
ClaimTypesSupported []string `json:"claim_types_supported,omitempty"` ClaimTypesSupported []string `json:"claim_types_supported,omitempty"`
//ClaimsSupported contains a list of Claim Names the OP may be able to supply values for. This list might not be exhaustive.
ClaimsSupported []string `json:"claims_supported,omitempty"` ClaimsSupported []string `json:"claims_supported,omitempty"`
//ClaimsParameterSupported specifies whether the OP supports use of the `claims` parameter. If omitted, the default is false.
ClaimsParameterSupported bool `json:"claims_parameter_supported,omitempty"` ClaimsParameterSupported bool `json:"claims_parameter_supported,omitempty"`
//CodeChallengeMethodsSupported contains a list of Proof Key for Code Exchange (PKCE) code challenge methods supported by the OP.
CodeChallengeMethodsSupported []CodeChallengeMethod `json:"code_challenge_methods_supported,omitempty"` CodeChallengeMethodsSupported []CodeChallengeMethod `json:"code_challenge_methods_supported,omitempty"`
//ServiceDocumentation is a URL where developers can get information about the OP and its usage.
ServiceDocumentation string `json:"service_documentation,omitempty"` ServiceDocumentation string `json:"service_documentation,omitempty"`
//ClaimsLocalesSupported contains a list of BCP47 language tag values that the OP supports for values of Claims returned.
ClaimsLocalesSupported []language.Tag `json:"claims_locales_supported,omitempty"` ClaimsLocalesSupported []language.Tag `json:"claims_locales_supported,omitempty"`
//UILocalesSupported contains a list of BCP47 language tag values that the OP supports for the user interface.
UILocalesSupported []language.Tag `json:"ui_locales_supported,omitempty"` UILocalesSupported []language.Tag `json:"ui_locales_supported,omitempty"`
//RequestParameterSupported specifies whether the OP supports use of the `request` parameter. If omitted, the default value is false.
RequestParameterSupported bool `json:"request_parameter_supported,omitempty"` RequestParameterSupported bool `json:"request_parameter_supported,omitempty"`
RequestURIParameterSupported bool `json:"request_uri_parameter_supported"` //no omitempty because: If omitted, the default value is true
//RequestURIParameterSupported specifies whether the OP supports use of the `request_uri` parameter. If omitted, the default value is true. (therefore no omitempty)
RequestURIParameterSupported bool `json:"request_uri_parameter_supported"`
//RequireRequestURIRegistration specifies whether the OP requires any `request_uri` to be pre-registered using the request_uris registration parameter. If omitted, the default value is false.
RequireRequestURIRegistration bool `json:"require_request_uri_registration,omitempty"` RequireRequestURIRegistration bool `json:"require_request_uri_registration,omitempty"`
//OPPolicyURI is a URL the OP provides to the person registering the Client to read about the OP's requirements on how the RP can use the data provided by the OP.
OPPolicyURI string `json:"op_policy_uri,omitempty"` OPPolicyURI string `json:"op_policy_uri,omitempty"`
//OPTermsOfServiceURI is a URL the OpenID Provider provides to the person registering the Client to read about OpenID Provider's terms of service.
OPTermsOfServiceURI string `json:"op_tos_uri,omitempty"` OPTermsOfServiceURI string `json:"op_tos_uri,omitempty"`
} }

139
pkg/oidc/error.go Normal file
View file

@ -0,0 +1,139 @@
package oidc
import (
"errors"
"fmt"
)
type errorType string
const (
InvalidRequest errorType = "invalid_request"
InvalidScope errorType = "invalid_scope"
InvalidClient errorType = "invalid_client"
InvalidGrant errorType = "invalid_grant"
UnauthorizedClient errorType = "unauthorized_client"
UnsupportedGrantType errorType = "unsupported_grant_type"
ServerError errorType = "server_error"
InteractionRequired errorType = "interaction_required"
LoginRequired errorType = "login_required"
RequestNotSupported errorType = "request_not_supported"
)
var (
ErrInvalidRequest = func() *Error {
return &Error{
ErrorType: InvalidRequest,
}
}
ErrInvalidRequestRedirectURI = func() *Error {
return &Error{
ErrorType: InvalidRequest,
redirectDisabled: true,
}
}
ErrInvalidScope = func() *Error {
return &Error{
ErrorType: InvalidScope,
}
}
ErrInvalidClient = func() *Error {
return &Error{
ErrorType: InvalidClient,
}
}
ErrInvalidGrant = func() *Error {
return &Error{
ErrorType: InvalidGrant,
}
}
ErrUnauthorizedClient = func() *Error {
return &Error{
ErrorType: UnauthorizedClient,
}
}
ErrUnsupportedGrantType = func() *Error {
return &Error{
ErrorType: UnsupportedGrantType,
}
}
ErrServerError = func() *Error {
return &Error{
ErrorType: ServerError,
}
}
ErrInteractionRequired = func() *Error {
return &Error{
ErrorType: InteractionRequired,
}
}
ErrLoginRequired = func() *Error {
return &Error{
ErrorType: LoginRequired,
}
}
ErrRequestNotSupported = func() *Error {
return &Error{
ErrorType: RequestNotSupported,
}
}
)
type Error struct {
Parent error `json:"-" schema:"-"`
ErrorType errorType `json:"error" schema:"error"`
Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
redirectDisabled bool `schema:"-"`
}
func (e *Error) Error() string {
message := "ErrorType=" + string(e.ErrorType)
if e.Description != "" {
message += " Description=" + e.Description
}
if e.Parent != nil {
message += " Parent=" + e.Parent.Error()
}
return message
}
func (e *Error) Unwrap() error {
return e.Parent
}
func (e *Error) Is(target error) bool {
t, ok := target.(*Error)
if !ok {
return false
}
return e.ErrorType == t.ErrorType &&
(e.Description == t.Description || t.Description == "") &&
(e.State == t.State || t.State == "")
}
func (e *Error) WithParent(err error) *Error {
e.Parent = err
return e
}
func (e *Error) WithDescription(desc string, args ...interface{}) *Error {
e.Description = fmt.Sprintf(desc, args...)
return e
}
func (e *Error) IsRedirectDisabled() bool {
return e.redirectDisabled
}
// DefaultToServerError checks if the error is an Error
// if not the provided error will be wrapped into a ServerError
func DefaultToServerError(err error, description string) *Error {
oauth := new(Error)
if ok := errors.As(err, &oauth); !ok {
oauth.ErrorType = ServerError
oauth.Description = description
oauth.Parent = err
}
return oauth
}

View file

@ -42,181 +42,181 @@ type introspectionResponse struct {
claims map[string]interface{} claims map[string]interface{}
} }
func (u *introspectionResponse) IsActive() bool { func (i *introspectionResponse) IsActive() bool {
return u.Active return i.Active
} }
func (u *introspectionResponse) SetScopes(scope []string) { func (i *introspectionResponse) SetScopes(scope []string) {
u.Scope = scope i.Scope = scope
} }
func (u *introspectionResponse) SetClientID(id string) { func (i *introspectionResponse) SetClientID(id string) {
u.ClientID = id i.ClientID = id
} }
func (u *introspectionResponse) GetSubject() string { func (i *introspectionResponse) GetSubject() string {
return u.Subject return i.Subject
} }
func (u *introspectionResponse) GetName() string { func (i *introspectionResponse) GetName() string {
return u.Name return i.Name
} }
func (u *introspectionResponse) GetGivenName() string { func (i *introspectionResponse) GetGivenName() string {
return u.GivenName return i.GivenName
} }
func (u *introspectionResponse) GetFamilyName() string { func (i *introspectionResponse) GetFamilyName() string {
return u.FamilyName return i.FamilyName
} }
func (u *introspectionResponse) GetMiddleName() string { func (i *introspectionResponse) GetMiddleName() string {
return u.MiddleName return i.MiddleName
} }
func (u *introspectionResponse) GetNickname() string { func (i *introspectionResponse) GetNickname() string {
return u.Nickname return i.Nickname
} }
func (u *introspectionResponse) GetProfile() string { func (i *introspectionResponse) GetProfile() string {
return u.Profile return i.Profile
} }
func (u *introspectionResponse) GetPicture() string { func (i *introspectionResponse) GetPicture() string {
return u.Picture return i.Picture
} }
func (u *introspectionResponse) GetWebsite() string { func (i *introspectionResponse) GetWebsite() string {
return u.Website return i.Website
} }
func (u *introspectionResponse) GetGender() Gender { func (i *introspectionResponse) GetGender() Gender {
return u.Gender return i.Gender
} }
func (u *introspectionResponse) GetBirthdate() string { func (i *introspectionResponse) GetBirthdate() string {
return u.Birthdate return i.Birthdate
} }
func (u *introspectionResponse) GetZoneinfo() string { func (i *introspectionResponse) GetZoneinfo() string {
return u.Zoneinfo return i.Zoneinfo
} }
func (u *introspectionResponse) GetLocale() language.Tag { func (i *introspectionResponse) GetLocale() language.Tag {
return u.Locale return i.Locale
} }
func (u *introspectionResponse) GetPreferredUsername() string { func (i *introspectionResponse) GetPreferredUsername() string {
return u.PreferredUsername return i.PreferredUsername
} }
func (u *introspectionResponse) GetEmail() string { func (i *introspectionResponse) GetEmail() string {
return u.Email return i.Email
} }
func (u *introspectionResponse) IsEmailVerified() bool { func (i *introspectionResponse) IsEmailVerified() bool {
return bool(u.EmailVerified) return bool(i.EmailVerified)
} }
func (u *introspectionResponse) GetPhoneNumber() string { func (i *introspectionResponse) GetPhoneNumber() string {
return u.PhoneNumber return i.PhoneNumber
} }
func (u *introspectionResponse) IsPhoneNumberVerified() bool { func (i *introspectionResponse) IsPhoneNumberVerified() bool {
return u.PhoneNumberVerified return i.PhoneNumberVerified
} }
func (u *introspectionResponse) GetAddress() UserInfoAddress { func (i *introspectionResponse) GetAddress() UserInfoAddress {
return u.Address return i.Address
} }
func (u *introspectionResponse) GetClaim(key string) interface{} { func (i *introspectionResponse) GetClaim(key string) interface{} {
return u.claims[key] return i.claims[key]
} }
func (u *introspectionResponse) SetActive(active bool) { func (i *introspectionResponse) SetActive(active bool) {
u.Active = active i.Active = active
} }
func (u *introspectionResponse) SetSubject(sub string) { func (i *introspectionResponse) SetSubject(sub string) {
u.Subject = sub i.Subject = sub
} }
func (u *introspectionResponse) SetName(name string) { func (i *introspectionResponse) SetName(name string) {
u.Name = name i.Name = name
} }
func (u *introspectionResponse) SetGivenName(name string) { func (i *introspectionResponse) SetGivenName(name string) {
u.GivenName = name i.GivenName = name
} }
func (u *introspectionResponse) SetFamilyName(name string) { func (i *introspectionResponse) SetFamilyName(name string) {
u.FamilyName = name i.FamilyName = name
} }
func (u *introspectionResponse) SetMiddleName(name string) { func (i *introspectionResponse) SetMiddleName(name string) {
u.MiddleName = name i.MiddleName = name
} }
func (u *introspectionResponse) SetNickname(name string) { func (i *introspectionResponse) SetNickname(name string) {
u.Nickname = name i.Nickname = name
} }
func (u *introspectionResponse) SetUpdatedAt(date time.Time) { func (i *introspectionResponse) SetUpdatedAt(date time.Time) {
u.UpdatedAt = Time(date) i.UpdatedAt = Time(date)
} }
func (u *introspectionResponse) SetProfile(profile string) { func (i *introspectionResponse) SetProfile(profile string) {
u.Profile = profile i.Profile = profile
} }
func (u *introspectionResponse) SetPicture(picture string) { func (i *introspectionResponse) SetPicture(picture string) {
u.Picture = picture i.Picture = picture
} }
func (u *introspectionResponse) SetWebsite(website string) { func (i *introspectionResponse) SetWebsite(website string) {
u.Website = website i.Website = website
} }
func (u *introspectionResponse) SetGender(gender Gender) { func (i *introspectionResponse) SetGender(gender Gender) {
u.Gender = gender i.Gender = gender
} }
func (u *introspectionResponse) SetBirthdate(birthdate string) { func (i *introspectionResponse) SetBirthdate(birthdate string) {
u.Birthdate = birthdate i.Birthdate = birthdate
} }
func (u *introspectionResponse) SetZoneinfo(zoneInfo string) { func (i *introspectionResponse) SetZoneinfo(zoneInfo string) {
u.Zoneinfo = zoneInfo i.Zoneinfo = zoneInfo
} }
func (u *introspectionResponse) SetLocale(locale language.Tag) { func (i *introspectionResponse) SetLocale(locale language.Tag) {
u.Locale = locale i.Locale = locale
} }
func (u *introspectionResponse) SetPreferredUsername(name string) { func (i *introspectionResponse) SetPreferredUsername(name string) {
u.PreferredUsername = name i.PreferredUsername = name
} }
func (u *introspectionResponse) SetEmail(email string, verified bool) { func (i *introspectionResponse) SetEmail(email string, verified bool) {
u.Email = email i.Email = email
u.EmailVerified = boolString(verified) i.EmailVerified = boolString(verified)
} }
func (u *introspectionResponse) SetPhone(phone string, verified bool) { func (i *introspectionResponse) SetPhone(phone string, verified bool) {
u.PhoneNumber = phone i.PhoneNumber = phone
u.PhoneNumberVerified = verified i.PhoneNumberVerified = verified
} }
func (u *introspectionResponse) SetAddress(address UserInfoAddress) { func (i *introspectionResponse) SetAddress(address UserInfoAddress) {
u.Address = address i.Address = address
} }
func (u *introspectionResponse) AppendClaims(key string, value interface{}) { func (i *introspectionResponse) AppendClaims(key string, value interface{}) {
if u.claims == nil { if i.claims == nil {
u.claims = make(map[string]interface{}) i.claims = make(map[string]interface{})
} }
u.claims[key] = value i.claims[key] = value
} }
func (i *introspectionResponse) MarshalJSON() ([]byte, error) { func (i *introspectionResponse) MarshalJSON() ([]byte, error) {

6
pkg/oidc/revocation.go Normal file
View file

@ -0,0 +1,6 @@
package oidc
type RevocationRequest struct {
Token string `schema:"token"`
TokenTypeHint string `schema:"token_type_hint"`
}

View file

@ -1,10 +1,7 @@
package oidc package oidc
import ( import (
"crypto/rsa"
"crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"time" "time"
@ -12,7 +9,8 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/crypto"
"github.com/caos/oidc/pkg/http"
) )
const ( const (
@ -188,7 +186,7 @@ func (a *accessTokenClaims) MarshalJSON() ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return utils.ConcatenateJSON(b, info) return http.ConcatenateJSON(b, info)
} }
func (a *accessTokenClaims) UnmarshalJSON(data []byte) error { func (a *accessTokenClaims) UnmarshalJSON(data []byte) error {
@ -325,7 +323,7 @@ func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
return t.signatureAlg return t.signatureAlg
} }
//SetSignatureAlgorithm implements the IDTokenClaims interface //SetAccessTokenHash implements the IDTokenClaims interface
func (t *idTokenClaims) SetAccessTokenHash(hash string) { func (t *idTokenClaims) SetAccessTokenHash(hash string) {
t.AccessTokenHash = hash t.AccessTokenHash = hash
} }
@ -375,7 +373,7 @@ func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return utils.ConcatenateJSON(b, info) return http.ConcatenateJSON(b, info)
} }
func (t *idTokenClaims) UnmarshalJSON(data []byte) error { func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
@ -572,12 +570,12 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte,
} }
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
hash, err := utils.GetHashAlgorithm(sigAlgorithm) hash, err := crypto.GetHashAlgorithm(sigAlgorithm)
if err != nil { if err != nil {
return "", err return "", err
} }
return utils.HashString(hash, claim, true), nil return crypto.HashString(hash, claim, true), nil
} }
func AppendClientIDToAudience(clientID string, audience []string) []string { func AppendClientIDToAudience(clientID string, audience []string) []string {
@ -590,7 +588,7 @@ func AppendClientIDToAudience(clientID string, audience []string) []string {
} }
func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error) { func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error) {
privateKey, err := bytesToPrivateKey(assertion.GetPrivateKey()) privateKey, err := crypto.BytesToPrivateKey(assertion.GetPrivateKey())
if err != nil { if err != nil {
return "", err return "", err
} }
@ -613,21 +611,3 @@ func GenerateJWTProfileToken(assertion JWTProfileAssertionClaims) (string, error
} }
return signedAssertion.CompactSerialize() return signedAssertion.CompactSerialize()
} }
func bytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(priv)
enc := x509.IsEncryptedPEMBlock(block)
b := block.Bytes
var err error
if enc {
b, err = x509.DecryptPEMBlock(block, nil)
if err != nil {
return nil, err
}
}
key, err := x509.ParsePKCS1PrivateKey(b)
if err != nil {
return nil, err
}
return key, nil
}

View file

@ -12,7 +12,7 @@ const (
//GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow //GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
GrantTypeCode GrantType = "authorization_code" GrantTypeCode GrantType = "authorization_code"
//GrantTypeCode defines the grant_type `refresh_token` used for the Token Request in the Refresh Token Flow //GrantTypeRefreshToken defines the grant_type `refresh_token` used for the Token Request in the Refresh Token Flow
GrantTypeRefreshToken GrantType = "refresh_token" GrantTypeRefreshToken GrantType = "refresh_token"
//GrantTypeBearer defines the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant //GrantTypeBearer defines the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
@ -183,7 +183,7 @@ func (j *JWTTokenRequest) GetSubject() string {
return j.Subject return j.Subject
} }
//GetSubject implements the TokenRequest interface //GetScopes implements the TokenRequest interface
func (j *JWTTokenRequest) GetScopes() []string { func (j *JWTTokenRequest) GetScopes() []string {
return j.Scopes return j.Scopes
} }

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"golang.org/x/text/language" "golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
) )
type Audience []string type Audience []string
@ -66,6 +67,8 @@ type Prompt SpaceDelimitedArray
type ResponseType string type ResponseType string
type ResponseMode string
func (s SpaceDelimitedArray) Encode() string { func (s SpaceDelimitedArray) Encode() string {
return strings.Join(s, " ") return strings.Join(s, " ")
} }
@ -106,3 +109,16 @@ func (t *Time) UnmarshalJSON(data []byte) error {
func (t *Time) MarshalJSON() ([]byte, error) { func (t *Time) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Time(*t).UTC().Unix()) return json.Marshal(time.Time(*t).UTC().Unix())
} }
type RequestObject struct {
Issuer string `json:"iss"`
Audience Audience `json:"aud"`
AuthRequest
}
func (r *RequestObject) GetIssuer() string {
return r.Issuer
}
func (r *RequestObject) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
}

View file

@ -339,20 +339,20 @@ func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, fo
} }
} }
func (i *userinfo) MarshalJSON() ([]byte, error) { func (u *userinfo) MarshalJSON() ([]byte, error) {
type Alias userinfo type Alias userinfo
a := &struct { a := &struct {
*Alias *Alias
Locale interface{} `json:"locale,omitempty"` Locale interface{} `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"` UpdatedAt int64 `json:"updated_at,omitempty"`
}{ }{
Alias: (*Alias)(i), Alias: (*Alias)(u),
} }
if !i.Locale.IsRoot() { if !u.Locale.IsRoot() {
a.Locale = i.Locale a.Locale = u.Locale
} }
if !time.Time(i.UpdatedAt).IsZero() { if !time.Time(u.UpdatedAt).IsZero() {
a.UpdatedAt = time.Time(i.UpdatedAt).Unix() a.UpdatedAt = time.Time(u.UpdatedAt).Unix()
} }
b, err := json.Marshal(a) b, err := json.Marshal(a)
@ -360,34 +360,34 @@ func (i *userinfo) MarshalJSON() ([]byte, error) {
return nil, err return nil, err
} }
if len(i.claims) == 0 { if len(u.claims) == 0 {
return b, nil return b, nil
} }
err = json.Unmarshal(b, &i.claims) err = json.Unmarshal(b, &u.claims)
if err != nil { if err != nil {
return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims) return nil, fmt.Errorf("jws: invalid map of custom claims %v", u.claims)
} }
return json.Marshal(i.claims) return json.Marshal(u.claims)
} }
func (i *userinfo) UnmarshalJSON(data []byte) error { func (u *userinfo) UnmarshalJSON(data []byte) error {
type Alias userinfo type Alias userinfo
a := &struct { a := &struct {
Address *userInfoAddress `json:"address,omitempty"` Address *userInfoAddress `json:"address,omitempty"`
*Alias *Alias
UpdatedAt int64 `json:"update_at,omitempty"` UpdatedAt int64 `json:"update_at,omitempty"`
}{ }{
Alias: (*Alias)(i), Alias: (*Alias)(u),
} }
if err := json.Unmarshal(data, &a); err != nil { if err := json.Unmarshal(data, &a); err != nil {
return err return err
} }
i.Address = a.Address u.Address = a.Address
i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC()) u.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
if err := json.Unmarshal(data, &i.claims); err != nil { if err := json.Unmarshal(data, &u.claims); err != nil {
return err return err
} }

View file

@ -12,7 +12,7 @@ import (
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils" str "github.com/caos/oidc/pkg/strings"
) )
type Claims interface { type Claims interface {
@ -25,6 +25,10 @@ type Claims interface {
GetAuthenticationContextClassReference() string GetAuthenticationContextClassReference() string
GetAuthTime() time.Time GetAuthTime() time.Time
GetAuthorizedParty() string GetAuthorizedParty() string
ClaimsSignature
}
type ClaimsSignature interface {
SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
} }
@ -61,10 +65,10 @@ type Verifier interface {
type ACRVerifier func(string) error type ACRVerifier func(string) error
//DefaultACRVerifier implements `ACRVerifier` returning an error //DefaultACRVerifier implements `ACRVerifier` returning an error
//if non of the provided values matches the acr claim //if none of the provided values matches the acr claim
func DefaultACRVerifier(possibleValues []string) ACRVerifier { func DefaultACRVerifier(possibleValues []string) ACRVerifier {
return func(acr string) error { return func(acr string) error {
if !utils.Contains(possibleValues, acr) { if !str.Contains(possibleValues, acr) {
return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr) return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr)
} }
return nil return nil
@ -103,7 +107,7 @@ func CheckIssuer(claims Claims, issuer string) error {
} }
func CheckAudience(claims Claims, clientID string) error { func CheckAudience(claims Claims, clientID string) error {
if !utils.Contains(claims.GetAudience(), clientID) { if !str.Contains(claims.GetAudience(), clientID) {
return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID) return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID)
} }
@ -123,7 +127,7 @@ func CheckAuthorizedParty(claims Claims, clientID string) error {
return nil return nil
} }
func CheckSignature(ctx context.Context, token string, payload []byte, claims Claims, supportedSigAlgs []string, set KeySet) error { func CheckSignature(ctx context.Context, token string, payload []byte, claims ClaimsSignature, supportedSigAlgs []string, set KeySet) error {
jws, err := jose.ParseSigned(token) jws, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return ErrParse return ErrParse
@ -138,7 +142,7 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl
if len(supportedSigAlgs) == 0 { if len(supportedSigAlgs) == 0 {
supportedSigAlgs = []string{"RS256"} supportedSigAlgs = []string{"RS256"}
} }
if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) { if !str.Contains(supportedSigAlgs, sig.Header.Algorithm) {
return fmt.Errorf("%w: id token signed with unsupported algorithm, expected %q got %q", ErrSignatureUnsupportedAlg, supportedSigAlgs, sig.Header.Algorithm) return fmt.Errorf("%w: id token signed with unsupported algorithm, expected %q got %q", ErrSignatureUnsupportedAlg, supportedSigAlgs, sig.Header.Algorithm)
} }

View file

@ -2,7 +2,6 @@ package op
import ( import (
"context" "context"
"fmt"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -11,8 +10,9 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" str "github.com/caos/oidc/pkg/strings"
) )
type AuthRequest interface { type AuthRequest interface {
@ -26,6 +26,7 @@ type AuthRequest interface {
GetNonce() string GetNonce() string
GetRedirectURI() string GetRedirectURI() string
GetResponseType() oidc.ResponseType GetResponseType() oidc.ResponseType
GetResponseMode() oidc.ResponseMode
GetScopes() []string GetScopes() []string
GetState() string GetState() string
GetSubject() string GetSubject() string
@ -34,16 +35,17 @@ type AuthRequest interface {
type Authorizer interface { type Authorizer interface {
Storage() Storage Storage() Storage
Decoder() utils.Decoder Decoder() httphelper.Decoder
Encoder() utils.Encoder Encoder() httphelper.Encoder
Signer() Signer Signer() Signer
IDTokenHintVerifier() IDTokenHintVerifier IDTokenHintVerifier() IDTokenHintVerifier
Crypto() Crypto Crypto() Crypto
Issuer() string Issuer() string
RequestObjectSupported() bool
} }
//AuthorizeValidator is an extension of Authorizer interface //AuthorizeValidator is an extension of Authorizer interface
//implementing it's own validation mechanism for the auth request //implementing its own validation mechanism for the auth request
type AuthorizeValidator interface { type AuthorizeValidator interface {
Authorizer Authorizer
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error)
@ -69,6 +71,13 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(r.Context(), authReq, authorizer.Storage(), authorizer.Issuer())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
}
validation := ValidateAuthRequest validation := ValidateAuthRequest
if validater, ok := authorizer.(AuthorizeValidator); ok { if validater, ok := authorizer.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest validation = validater.ValidateAuthRequest
@ -78,33 +87,114 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
if authReq.RequestParam != "" {
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder())
return
}
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID) req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID)
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
return return
} }
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID()) client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
if err != nil { if err != nil {
AuthRequestError(w, r, req, err, authorizer.Encoder()) AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
return return
} }
RedirectToLogin(req.GetID(), client, w, r) RedirectToLogin(req.GetID(), client, w, r)
} }
//ParseAuthorizeRequest parsed the http request into a oidc.AuthRequest //ParseAuthorizeRequest parsed the http request into an oidc.AuthRequest
func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) { func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.AuthRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return nil, ErrInvalidRequest("cannot parse form") return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
} }
authReq := new(oidc.AuthRequest) authReq := new(oidc.AuthRequest)
err = decoder.Decode(authReq, r.Form) err = decoder.Decode(authReq, r.Form)
if err != nil { if err != nil {
return nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)) return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse auth request").WithParent(err)
} }
return authReq, nil return authReq, nil
} }
//ParseRequestObject parse the `request` parameter, validates the token including the signature
//and copies the token claims into the auth request
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) {
requestObject := new(oidc.RequestObject)
payload, err := oidc.ParseToken(authReq.RequestParam, requestObject)
if err != nil {
return nil, err
}
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
return authReq, oidc.ErrInvalidRequest()
}
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
return authReq, oidc.ErrInvalidRequest()
}
if requestObject.Issuer != requestObject.ClientID {
return authReq, oidc.ErrInvalidRequest()
}
if !str.Contains(requestObject.Audience, issuer) {
return authReq, oidc.ErrInvalidRequest()
}
keySet := &jwtProfileKeySet{storage, requestObject.Issuer}
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
return authReq, err
}
CopyRequestObjectToAuthRequest(authReq, requestObject)
return authReq, nil
}
//CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request
//and clears the `RequestParam` of the auth request
func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oidc.RequestObject) {
if str.Contains(authReq.Scopes, oidc.ScopeOpenID) && len(requestObject.Scopes) > 0 {
authReq.Scopes = requestObject.Scopes
}
if requestObject.RedirectURI != "" {
authReq.RedirectURI = requestObject.RedirectURI
}
if requestObject.State != "" {
authReq.State = requestObject.State
}
if requestObject.ResponseMode != "" {
authReq.ResponseMode = requestObject.ResponseMode
}
if requestObject.Nonce != "" {
authReq.Nonce = requestObject.Nonce
}
if requestObject.Display != "" {
authReq.Display = requestObject.Display
}
if len(requestObject.Prompt) > 0 {
authReq.Prompt = requestObject.Prompt
}
if requestObject.MaxAge != nil {
authReq.MaxAge = requestObject.MaxAge
}
if len(requestObject.UILocales) > 0 {
authReq.UILocales = requestObject.UILocales
}
if requestObject.IDTokenHint != "" {
authReq.IDTokenHint = requestObject.IDTokenHint
}
if requestObject.LoginHint != "" {
authReq.LoginHint = requestObject.LoginHint
}
if len(requestObject.ACRValues) > 0 {
authReq.ACRValues = requestObject.ACRValues
}
if requestObject.CodeChallenge != "" {
authReq.CodeChallenge = requestObject.CodeChallenge
}
if requestObject.CodeChallengeMethod != "" {
authReq.CodeChallengeMethod = requestObject.CodeChallengeMethod
}
authReq.RequestParam = ""
}
//ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed //ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) { func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) {
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
@ -113,7 +203,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
} }
client, err := storage.GetClientByClientID(ctx, authReq.ClientID) client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
if err != nil { if err != nil {
return "", ErrServerError(err.Error()) return "", oidc.DefaultToServerError(err, "unable to retrieve client by id")
} }
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes) authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil { if err != nil {
@ -132,7 +222,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) { func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) {
for _, prompt := range prompts { for _, prompt := range prompts {
if prompt == oidc.PromptNone && len(prompts) > 1 { if prompt == oidc.PromptNone && len(prompts) > 1 {
return nil, ErrInvalidRequest("The prompt parameter `none` must only be used as a single value") return nil, oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value")
} }
if prompt == oidc.PromptLogin { if prompt == oidc.PromptLogin {
maxAge = oidc.NewMaxAge(0) maxAge = oidc.NewMaxAge(0)
@ -144,7 +234,9 @@ func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error)
//ValidateAuthReqScopes validates the passed scopes //ValidateAuthReqScopes validates the passed scopes
func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
if len(scopes) == 0 { if len(scopes) == 0 {
return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.") return nil, oidc.ErrInvalidRequest().
WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " +
"If you have any questions, you may contact the administrator of the application.")
} }
openID := false openID := false
for i := len(scopes) - 1; i >= 0; i-- { for i := len(scopes) - 1; i >= 0; i-- {
@ -165,7 +257,9 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
} }
} }
if !openID { if !openID {
return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.") return nil, oidc.ErrInvalidScope().WithDescription("The scope openid is missing in your request. " +
"Please ensure the scope openid is added to the request. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return scopes, nil return scopes, nil
@ -174,19 +268,23 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type //ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error { func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error {
if uri == "" { if uri == "" {
return ErrInvalidRequestRedirectURI("The redirect_uri is missing in the request. Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("The redirect_uri is missing in the request. " +
"Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
} }
if strings.HasPrefix(uri, "https://") { if strings.HasPrefix(uri, "https://") {
if !utils.Contains(client.RedirectURIs(), uri) { if !str.Contains(client.RedirectURIs(), uri) {
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().
WithDescription("The requested redirect_uri is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return nil return nil
} }
if client.ApplicationType() == ApplicationTypeNative { if client.ApplicationType() == ApplicationTypeNative {
return validateAuthReqRedirectURINative(client, uri, responseType) return validateAuthReqRedirectURINative(client, uri, responseType)
} }
if !utils.Contains(client.RedirectURIs(), uri) { if !str.Contains(client.RedirectURIs(), uri) {
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
} }
if strings.HasPrefix(uri, "http://") { if strings.HasPrefix(uri, "http://") {
if client.DevMode() { if client.DevMode() {
@ -195,23 +293,27 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
if responseType == oidc.ResponseTypeCode && IsConfidentialType(client) { if responseType == oidc.ResponseTypeCode && IsConfidentialType(client) {
return nil return nil
} }
return ErrInvalidRequest("This client's redirect_uri is http and is not allowed. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is http and is not allowed. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return ErrInvalidRequest("This client's redirect_uri is using a custom schema and is not allowed. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is using a custom schema and is not allowed. " +
"If you have any questions, you may contact the administrator of the application.")
} }
//ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type //ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type
func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error { func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error {
parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri) parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri)
isCustomSchema := !strings.HasPrefix(uri, "http://") isCustomSchema := !strings.HasPrefix(uri, "http://")
if utils.Contains(client.RedirectURIs(), uri) { if str.Contains(client.RedirectURIs(), uri) {
if isLoopback || isCustomSchema { if isLoopback || isCustomSchema {
return nil return nil
} }
return ErrInvalidRequest("This client's redirect_uri is http and is not allowed. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is http and is not allowed. " +
"If you have any questions, you may contact the administrator of the application.")
} }
if !isLoopback { if !isLoopback {
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
} }
for _, uri := range client.RedirectURIs() { for _, uri := range client.RedirectURIs() {
redirectURI, ok := HTTPLoopbackOrLocalhost(uri) redirectURI, ok := HTTPLoopbackOrLocalhost(uri)
@ -219,7 +321,8 @@ func validateAuthReqRedirectURINative(client Client, uri string, responseType oi
return nil return nil
} }
} }
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration." +
" If you have any questions, you may contact the administrator of the application.")
} }
func equalURI(url1, url2 *url.URL) bool { func equalURI(url1, url2 *url.URL) bool {
@ -241,10 +344,12 @@ func HTTPLoopbackOrLocalhost(rawurl string) (*url.URL, bool) {
//ValidateAuthReqResponseType validates the passed response_type to the registered response types //ValidateAuthReqResponseType validates the passed response_type to the registered response types
func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error { func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error {
if responseType == "" { if responseType == "" {
return ErrInvalidRequest("The response type is missing in your request. If you have any questions, you may contact the administrator of the application.") return oidc.ErrInvalidRequest().WithDescription("The response type is missing in your request. " +
"If you have any questions, you may contact the administrator of the application.")
} }
if !ContainsResponseType(client.ResponseTypes(), responseType) { if !ContainsResponseType(client.ResponseTypes(), responseType) {
return ErrInvalidRequest("The requested response type is missing in the client configuration. If you have any questions, you may contact the administrator of the application.") return oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return nil return nil
} }
@ -257,7 +362,8 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
} }
claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier) claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier)
if err != nil { if err != nil {
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.") return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " +
"If you have any questions, you may contact the administrator of the application.")
} }
return claims.GetSubject(), nil return claims.GetSubject(), nil
} }
@ -279,7 +385,9 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
return return
} }
if !authReq.Done() { if !authReq.Done() {
AuthRequestError(w, r, authReq, ErrInteractionRequired("Unfortunately, the user may is not logged in and/or additional interaction is required."), authorizer.Encoder()) AuthRequestError(w, r, authReq,
oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."),
authorizer.Encoder())
return return
} }
AuthResponse(authReq, authorizer, w, r) AuthResponse(authReq, authorizer, w, r)
@ -306,9 +414,17 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code) codeResponse := struct {
if authReq.GetState() != "" { code string
callback = callback + "&state=" + authReq.GetState() state string
}{
code: code,
state: authReq.GetState(),
}
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
} }
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)
} }
@ -321,12 +437,11 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder()) callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
if err != nil { if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return return
} }
callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
http.Redirect(w, r, callback, http.StatusFound) http.Redirect(w, r, callback, http.StatusFound)
} }
@ -346,3 +461,22 @@ func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Sto
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) { func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
return crypto.Encrypt(authReq.GetID()) return crypto.Encrypt(authReq.GetID())
} }
//AuthResponseURL encodes the authorization response (successful and error) and sets it as query or fragment values
//depending on the response_mode and response_type
func AuthResponseURL(redirectURI string, responseType oidc.ResponseType, responseMode oidc.ResponseMode, response interface{}, encoder httphelper.Encoder) (string, error) {
params, err := httphelper.URLEncodeResponse(response, encoder)
if err != nil {
return "", oidc.ErrServerError().WithParent(err)
}
if responseMode == oidc.ResponseModeQuery {
return redirectURI + "?" + params, nil
}
if responseMode == oidc.ResponseModeFragment {
return redirectURI + "#" + params, nil
}
if responseType == "" || responseType == oidc.ResponseTypeCode {
return redirectURI + "?" + params, nil
}
return redirectURI + "#" + params, nil
}

View file

@ -1,6 +1,8 @@
package op_test package op_test
import ( import (
"context"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -11,10 +13,10 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op" "github.com/caos/oidc/pkg/op"
"github.com/caos/oidc/pkg/op/mock" "github.com/caos/oidc/pkg/op/mock"
"github.com/caos/oidc/pkg/utils"
) )
// //
@ -75,7 +77,7 @@ import (
func TestParseAuthorizeRequest(t *testing.T) { func TestParseAuthorizeRequest(t *testing.T) {
type args struct { type args struct {
r *http.Request r *http.Request
decoder utils.Decoder decoder httphelper.Decoder
} }
type res struct { type res struct {
want *oidc.AuthRequest want *oidc.AuthRequest
@ -101,7 +103,7 @@ func TestParseAuthorizeRequest(t *testing.T) {
"decoding error", "decoding error",
args{ args{
&http.Request{URL: &url.URL{RawQuery: "unknown=value"}}, &http.Request{URL: &url.URL{RawQuery: "unknown=value"}},
func() utils.Decoder { func() httphelper.Decoder {
decoder := schema.NewDecoder() decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(false) decoder.IgnoreUnknownKeys(false)
return decoder return decoder
@ -116,7 +118,7 @@ func TestParseAuthorizeRequest(t *testing.T) {
"parsing ok", "parsing ok",
args{ args{
&http.Request{URL: &url.URL{RawQuery: "scope=openid"}}, &http.Request{URL: &url.URL{RawQuery: "scope=openid"}},
func() utils.Decoder { func() httphelper.Decoder {
decoder := schema.NewDecoder() decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(false) decoder.IgnoreUnknownKeys(false)
return decoder return decoder
@ -150,44 +152,138 @@ func TestValidateAuthRequest(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
wantErr bool wantErr error
}{ }{
//TODO:
// {
// "oauth2 spec"
// }
{ {
"scope missing fails", "scope missing fails",
args{&oidc.AuthRequest{}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidRequest(),
}, },
{ {
"scope openid missing fails", "scope openid missing fails",
args{&oidc.AuthRequest{Scopes: []string{"profile"}}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{Scopes: []string{"profile"}}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidScope(),
}, },
{ {
"response_type missing fails", "response_type missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidRequest(),
}, },
{ {
"client_id missing fails", "client_id missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidRequest(),
}, },
{ {
"redirect_uri missing fails", "redirect_uri missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, mock.NewMockStorageExpectValidClientID(t), nil}, args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, mock.NewMockStorageExpectValidClientID(t), nil},
true, oidc.ErrInvalidRequest(),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage, tt.args.verifier) _, err := op.ValidateAuthRequest(context.TODO(), tt.args.authRequest, tt.args.storage, tt.args.verifier)
if (err != nil) != tt.wantErr { if tt.wantErr == nil && err != nil {
t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
} }
if tt.wantErr != nil && !errors.Is(err, tt.wantErr) {
t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.wantErr)
}
})
}
}
func TestValidateAuthReqPrompt(t *testing.T) {
type args struct {
prompts []string
maxAge *uint
}
type res struct {
maxAge *uint
err error
}
tests := []struct {
name string
args args
res res
}{
{
"no prompts and maxAge, ok",
args{
nil,
nil,
},
res{
nil,
nil,
},
},
{
"no prompts but maxAge, ok",
args{
nil,
oidc.NewMaxAge(10),
},
res{
oidc.NewMaxAge(10),
nil,
},
},
{
"prompt none, ok",
args{
[]string{"none"},
oidc.NewMaxAge(10),
},
res{
oidc.NewMaxAge(10),
nil,
},
},
{
"prompt none with others, err",
args{
[]string{"none", "login"},
oidc.NewMaxAge(10),
},
res{
nil,
oidc.ErrInvalidRequest(),
},
},
{
"prompt login, ok",
args{
[]string{"login"},
nil,
},
res{
oidc.NewMaxAge(0),
nil,
},
},
{
"prompt login with maxAge, ok",
args{
[]string{"login"},
oidc.NewMaxAge(10),
},
res{
oidc.NewMaxAge(0),
nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
maxAge, err := op.ValidateAuthReqPrompt(tt.args.prompts, tt.args.maxAge)
if tt.res.err == nil && err != nil {
t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
}
if tt.res.err != nil && !errors.Is(err, tt.res.err) {
t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.res.err)
}
assert.Equal(t, tt.res.maxAge, maxAge)
}) })
} }
} }
@ -465,6 +561,80 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
} }
} }
func TestLoopbackOrLocalhost(t *testing.T) {
type args struct {
url string
}
tests := []struct {
name string
args args
want bool
}{
{
"not parsable, false",
args{url: string('\n')},
false,
},
{
"not http, false",
args{url: "localhost/test"},
false,
},
{
"not http, false",
args{url: "http://localhost.com/test"},
false,
},
{
"v4 no port ok",
args{url: "http://127.0.0.1/test"},
true,
},
{
"v6 short no port ok",
args{url: "http://[::1]/test"},
true,
},
{
"v6 long no port ok",
args{url: "http://[0:0:0:0:0:0:0:1]/test"},
true,
},
{
"locahost no port ok",
args{url: "http://localhost/test"},
true,
},
{
"v4 with port ok",
args{url: "http://127.0.0.1:4200/test"},
true,
},
{
"v6 short with port ok",
args{url: "http://[::1]:4200/test"},
true,
},
{
"v6 long with port ok",
args{url: "http://[0:0:0:0:0:0:0:1]:4200/test"},
true,
},
{
"localhost with port ok",
args{url: "http://localhost:4200/test"},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, got := op.HTTPLoopbackOrLocalhost(tt.args.url); got != tt.want {
t.Errorf("loopbackOrLocalhost() = %v, want %v", got, tt.want)
}
})
}
}
func TestValidateAuthReqResponseType(t *testing.T) { func TestValidateAuthReqResponseType(t *testing.T) {
type args struct { type args struct {
responseType oidc.ResponseType responseType oidc.ResponseType
@ -534,100 +704,122 @@ func TestRedirectToLogin(t *testing.T) {
} }
} }
func TestAuthorizeCallback(t *testing.T) { func TestAuthResponseURL(t *testing.T) {
type args struct { type args struct {
w http.ResponseWriter redirectURI string
r *http.Request responseType oidc.ResponseType
authorizer op.Authorizer responseMode oidc.ResponseMode
response interface{}
encoder httphelper.Encoder
} }
tests := []struct { type res struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer)
})
}
}
func TestAuthResponse(t *testing.T) {
type args struct {
authReq op.AuthRequest
authorizer op.Authorizer
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r)
})
}
}
func Test_LoopbackOrLocalhost(t *testing.T) {
type args struct {
url string url string
err error
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want bool res res
}{ }{
{ {
"v4 no port ok", "encoding error",
args{url: "http://127.0.0.1/test"}, args{
true, "uri",
oidc.ResponseTypeCode,
"",
map[string]interface{}{"test": "test"},
&mockEncoder{
errors.New("error encoding"),
},
},
res{
"",
oidc.ErrServerError(),
},
}, },
{ {
"v6 short no port ok", "response mode query",
args{url: "http://[::1]/test"}, args{
true, "uri",
oidc.ResponseTypeIDToken,
oidc.ResponseModeQuery,
map[string][]string{"test": {"test"}},
&mockEncoder{},
},
res{
"uri?test=test",
nil,
},
}, },
{ {
"v6 long no port ok", "response mode fragment",
args{url: "http://[0:0:0:0:0:0:0:1]/test"}, args{
true, "uri",
oidc.ResponseTypeCode,
oidc.ResponseModeFragment,
map[string][]string{"test": {"test"}},
&mockEncoder{},
},
res{
"uri#test=test",
nil,
},
}, },
{ {
"locahost no port ok", "response type code",
args{url: "http://localhost/test"}, args{
true, "uri",
oidc.ResponseTypeCode,
"",
map[string][]string{"test": {"test"}},
&mockEncoder{},
},
res{
"uri?test=test",
nil,
},
}, },
{ {
"v4 with port ok", "response type id token",
args{url: "http://127.0.0.1:4200/test"}, args{
true, "uri",
oidc.ResponseTypeIDToken,
"",
map[string][]string{"test": {"test"}},
&mockEncoder{},
}, },
{ res{
"v6 short with port ok", "uri#test=test",
args{url: "http://[::1]:4200/test"}, nil,
true,
}, },
{
"v6 long with port ok",
args{url: "http://[0:0:0:0:0:0:0:1]:4200/test"},
true,
},
{
"localhost with port ok",
args{url: "http://localhost:4200/test"},
true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if _, got := op.HTTPLoopbackOrLocalhost(tt.args.url); got != tt.want { got, err := op.AuthResponseURL(tt.args.redirectURI, tt.args.responseType, tt.args.responseMode, tt.args.response, tt.args.encoder)
t.Errorf("loopbackOrLocalhost() = %v, want %v", got, tt.want) if tt.res.err == nil && err != nil {
t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
}
if tt.res.err != nil && !errors.Is(err, tt.res.err) {
t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.res.err)
}
if got != tt.res.url {
t.Errorf("AuthResponseURL() got = %v, want %v", got, tt.res.url)
} }
}) })
} }
} }
type mockEncoder struct {
err error
}
func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error {
if m.err != nil {
return m.err
}
for s, strings := range src.(map[string][]string) {
dst[s] = strings
}
return nil
}

View file

@ -16,15 +16,23 @@ type Configuration interface {
TokenEndpoint() Endpoint TokenEndpoint() Endpoint
IntrospectionEndpoint() Endpoint IntrospectionEndpoint() Endpoint
UserinfoEndpoint() Endpoint UserinfoEndpoint() Endpoint
RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint KeysEndpoint() Endpoint
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
CodeMethodS256Supported() bool CodeMethodS256Supported() bool
AuthMethodPrivateKeyJWTSupported() bool AuthMethodPrivateKeyJWTSupported() bool
TokenEndpointSigningAlgorithmsSupported() []string
GrantTypeRefreshTokenSupported() bool GrantTypeRefreshTokenSupported() bool
GrantTypeTokenExchangeSupported() bool GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool GrantTypeJWTAuthorizationSupported() bool
IntrospectionAuthMethodPrivateKeyJWTSupported() bool
IntrospectionEndpointSigningAlgorithmsSupported() []string
RevocationAuthMethodPrivateKeyJWTSupported() bool
RevocationEndpointSigningAlgorithmsSupported() []string
RequestObjectSupported() bool
RequestObjectSigningAlgorithmsSupported() []string
SupportedUILocales() []language.Tag SupportedUILocales() []language.Tag
} }

View file

@ -61,6 +61,7 @@ func TestValidateIssuer(t *testing.T) {
}, },
} }
//ensure env is not set //ensure env is not set
//nolint:errcheck
os.Unsetenv(OidcDevMode) os.Unsetenv(OidcDevMode)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -86,6 +87,7 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
false, false,
}, },
} }
//nolint:errcheck
os.Setenv(OidcDevMode, "true") os.Setenv(OidcDevMode, "true")
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -1,7 +1,7 @@
package op package op
import ( import (
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/crypto"
) )
type Crypto interface { type Crypto interface {
@ -18,9 +18,9 @@ func NewAESCrypto(key [32]byte) Crypto {
} }
func (c *aesCrypto) Encrypt(s string) (string, error) { func (c *aesCrypto) Encrypt(s string) (string, error) {
return utils.EncryptAES(s, c.key) return crypto.EncryptAES(s, c.key)
} }
func (c *aesCrypto) Decrypt(s string) (string, error) { func (c *aesCrypto) Decrypt(s string) (string, error) {
return utils.DecryptAES(s, c.key) return crypto.DecryptAES(s, c.key)
} }

View file

@ -3,8 +3,8 @@ package op
import ( import (
"net/http" "net/http"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http.Request) { func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http.Request) {
@ -14,7 +14,7 @@ func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http
} }
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
utils.MarshalJSON(w, config) httphelper.MarshalJSON(w, config)
} }
func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration { func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration {
@ -24,6 +24,7 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()), TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()), IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()),
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()), UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
RevocationEndpoint: c.RevocationEndpoint().Absolute(c.Issuer()),
EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()), EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()),
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()), JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
ScopesSupported: Scopes(c), ScopesSupported: Scopes(c),
@ -31,11 +32,17 @@ func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfigurati
GrantTypesSupported: GrantTypes(c), GrantTypesSupported: GrantTypes(c),
SubjectTypesSupported: SubjectTypes(c), SubjectTypesSupported: SubjectTypes(c),
IDTokenSigningAlgValuesSupported: SigAlgorithms(s), IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(c),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(c), TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(c),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(c),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(c),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(c), IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(c),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(c),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(c),
ClaimsSupported: SupportedClaims(c), ClaimsSupported: SupportedClaims(c),
CodeChallengeMethodsSupported: CodeChallengeMethods(c), CodeChallengeMethodsSupported: CodeChallengeMethods(c),
UILocalesSupported: c.SupportedUILocales(), UILocalesSupported: c.SupportedUILocales(),
RequestParameterSupported: c.RequestObjectSupported(),
} }
} }
@ -45,6 +52,7 @@ var DefaultSupportedScopes = []string{
oidc.ScopeEmail, oidc.ScopeEmail,
oidc.ScopePhone, oidc.ScopePhone,
oidc.ScopeAddress, oidc.ScopeAddress,
oidc.ScopeOfflineAccess,
} }
func Scopes(c Configuration) []string { func Scopes(c Configuration) []string {
@ -127,6 +135,13 @@ func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod {
return authMethods return authMethods
} }
func TokenSigAlgorithms(c Configuration) []string {
if !c.AuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.TokenEndpointSigningAlgorithmsSupported()
}
func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod { func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{ authMethods := []oidc.AuthMethod{
oidc.AuthMethodBasic, oidc.AuthMethodBasic,
@ -137,6 +152,20 @@ func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
return authMethods return authMethods
} }
func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod { func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
codeMethods := make([]oidc.CodeChallengeMethod, 0, 1) codeMethods := make([]oidc.CodeChallengeMethod, 0, 1)
if c.CodeMethodS256Supported() { if c.CodeMethodS256Supported() {
@ -144,3 +173,24 @@ func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
} }
return codeMethods return codeMethods
} }
func IntrospectionSigAlgorithms(c Configuration) []string {
if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.IntrospectionEndpointSigningAlgorithmsSupported()
}
func RevocationSigAlgorithms(c Configuration) []string {
if !c.RevocationAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.RevocationEndpointSigningAlgorithmsSupported()
}
func RequestObjectSigAlgorithms(c Configuration) []string {
if !c.RequestObjectSupported() {
return nil
}
return c.RequestObjectSigningAlgorithmsSupported()
}

View file

@ -37,7 +37,10 @@ func TestDiscover(t *testing.T) {
op.Discover(tt.args.w, tt.args.config) op.Discover(tt.args.w, tt.args.config)
rec := tt.args.w.(*httptest.ResponseRecorder) rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, `{"issuer":"https://issuer.com","request_uri_parameter_supported":false}`, rec.Body.String()) require.Equal(t,
`{"issuer":"https://issuer.com","request_uri_parameter_supported":false}
`,
rec.Body.String())
}) })
} }
} }

View file

@ -1,105 +1,46 @@
package op package op
import ( import (
"fmt"
"net/http" "net/http"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
const (
InvalidRequest errorType = "invalid_request"
InvalidRequestURI errorType = "invalid_request_uri"
InteractionRequired errorType = "interaction_required"
ServerError errorType = "server_error"
)
var (
ErrInvalidRequest = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrInvalidRequestRedirectURI = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequestURI,
Description: description,
redirectDisabled: true,
}
}
ErrInteractionRequired = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InteractionRequired,
Description: description,
}
}
ErrServerError = func(description string) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
type errorType string
type ErrAuthRequest interface { type ErrAuthRequest interface {
GetRedirectURI() string GetRedirectURI() string
GetResponseType() oidc.ResponseType GetResponseType() oidc.ResponseType
GetState() string GetState() string
} }
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder utils.Encoder) { func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) {
if authReq == nil { if authReq == nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
e, ok := err.(*OAuthError) e := oidc.DefaultToServerError(err, err.Error())
if !ok { if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
e = new(OAuthError)
e.ErrorType = ServerError
e.Description = err.Error()
}
e.State = authReq.GetState()
if authReq.GetRedirectURI() == "" || e.redirectDisabled {
http.Error(w, e.Description, http.StatusBadRequest) http.Error(w, e.Description, http.StatusBadRequest)
return return
} }
params, err := utils.URLEncodeResponse(e, encoder) e.State = authReq.GetState()
var responseMode oidc.ResponseMode
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode()
}
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
url := authReq.GetRedirectURI()
responseType := authReq.GetResponseType()
if responseType == "" || responseType == oidc.ResponseTypeCode {
url += "?" + params
} else {
url += "#" + params
}
http.Redirect(w, r, url, http.StatusFound) http.Redirect(w, r, url, http.StatusFound)
} }
func RequestError(w http.ResponseWriter, r *http.Request, err error) { func RequestError(w http.ResponseWriter, r *http.Request, err error) {
e, ok := err.(*OAuthError) e := oidc.DefaultToServerError(err, err.Error())
if !ok { status := http.StatusBadRequest
e = new(OAuthError) if e.ErrorType == oidc.InvalidClient {
e.ErrorType = ServerError status = 401
e.Description = err.Error()
} }
w.WriteHeader(http.StatusBadRequest) httphelper.MarshalJSONWithStatus(w, e, status)
utils.MarshalJSON(w, e)
}
type OAuthError struct {
ErrorType errorType `json:"error" schema:"error"`
Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
redirectDisabled bool `json:"-" schema:"-"`
}
func (e *OAuthError) Error() string {
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
} }

View file

@ -6,7 +6,7 @@ import (
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils" httphelper "github.com/caos/oidc/pkg/http"
) )
type KeyProvider interface { type KeyProvider interface {
@ -22,9 +22,8 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
keySet, err := k.GetKeySet(r.Context()) keySet, err := k.GetKeySet(r.Context())
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError)
utils.MarshalJSON(w, err)
return return
} }
utils.MarshalJSON(w, keySet) httphelper.MarshalJSON(w, keySet)
} }

100
pkg/op/keys_test.go Normal file
View file

@ -0,0 +1,100 @@
package op_test
import (
"crypto/rsa"
"math/big"
"net/http"
"net/http/httptest"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
"github.com/caos/oidc/pkg/op/mock"
)
func TestKeys(t *testing.T) {
type args struct {
k op.KeyProvider
}
type res struct {
statusCode int
contentType string
body string
}
tests := []struct {
name string
args args
res res
}{
{
name: "error",
args: args{
k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, oidc.ErrServerError())
return m
}(),
},
res: res{
statusCode: http.StatusInternalServerError,
contentType: "application/json",
body: `{"error":"server_error"}
`,
},
},
{
name: "empty list",
args: args{
k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, nil)
return m
}(),
},
res: res{
statusCode: http.StatusOK,
contentType: "application/json",
},
},
{
name: "list",
args: args{
k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(
&jose.JSONWebKeySet{Keys: []jose.JSONWebKey{
{
Key: &rsa.PublicKey{
N: big.NewInt(1),
E: 1,
},
KeyID: "id",
},
}},
nil,
)
return m
}(),
},
res: res{
statusCode: http.StatusOK,
contentType: "application/json",
body: `{"keys":[{"kty":"RSA","kid":"id","n":"AQ","e":"AQ"}]}
`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
op.Keys(w, httptest.NewRequest("GET", "/keys", nil), tt.args.k)
assert.Equal(t, tt.res.statusCode, w.Result().StatusCode)
assert.Equal(t, tt.res.contentType, w.Header().Get("content-type"))
assert.Equal(t, tt.res.body, w.Body.String())
})
}
}

View file

@ -7,8 +7,8 @@ package mock
import ( import (
reflect "reflect" reflect "reflect"
http "github.com/caos/oidc/pkg/http"
op "github.com/caos/oidc/pkg/op" op "github.com/caos/oidc/pkg/op"
utils "github.com/caos/oidc/pkg/utils"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
@ -50,10 +50,10 @@ func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call {
} }
// Decoder mocks base method. // Decoder mocks base method.
func (m *MockAuthorizer) Decoder() utils.Decoder { func (m *MockAuthorizer) Decoder() http.Decoder {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Decoder") ret := m.ctrl.Call(m, "Decoder")
ret0, _ := ret[0].(utils.Decoder) ret0, _ := ret[0].(http.Decoder)
return ret0 return ret0
} }
@ -64,10 +64,10 @@ func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call {
} }
// Encoder mocks base method. // Encoder mocks base method.
func (m *MockAuthorizer) Encoder() utils.Encoder { func (m *MockAuthorizer) Encoder() http.Encoder {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Encoder") ret := m.ctrl.Call(m, "Encoder")
ret0, _ := ret[0].(utils.Encoder) ret0, _ := ret[0].(http.Encoder)
return ret0 return ret0
} }
@ -105,6 +105,20 @@ func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
} }
// RequestObjectSupported mocks base method.
func (m *MockAuthorizer) RequestObjectSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestObjectSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// RequestObjectSupported indicates an expected call of RequestObjectSupported.
func (mr *MockAuthorizerMockRecorder) RequestObjectSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockAuthorizer)(nil).RequestObjectSupported))
}
// Signer mocks base method. // Signer mocks base method.
func (m *MockAuthorizer) Signer() op.Signer { func (m *MockAuthorizer) Signer() op.Signer {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -147,6 +147,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeTokenExchangeSupported() *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeTokenExchangeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeTokenExchangeSupported)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeTokenExchangeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeTokenExchangeSupported))
} }
// IntrospectionAuthMethodPrivateKeyJWTSupported mocks base method.
func (m *MockConfiguration) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntrospectionAuthMethodPrivateKeyJWTSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// IntrospectionAuthMethodPrivateKeyJWTSupported indicates an expected call of IntrospectionAuthMethodPrivateKeyJWTSupported.
func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionAuthMethodPrivateKeyJWTSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionAuthMethodPrivateKeyJWTSupported))
}
// IntrospectionEndpoint mocks base method. // IntrospectionEndpoint mocks base method.
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint { func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -161,6 +175,20 @@ func (mr *MockConfigurationMockRecorder) IntrospectionEndpoint() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpoint", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpoint)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpoint", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpoint))
} }
// IntrospectionEndpointSigningAlgorithmsSupported mocks base method.
func (m *MockConfiguration) IntrospectionEndpointSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntrospectionEndpointSigningAlgorithmsSupported")
ret0, _ := ret[0].([]string)
return ret0
}
// IntrospectionEndpointSigningAlgorithmsSupported indicates an expected call of IntrospectionEndpointSigningAlgorithmsSupported.
func (mr *MockConfigurationMockRecorder) IntrospectionEndpointSigningAlgorithmsSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpointSigningAlgorithmsSupported))
}
// Issuer mocks base method. // Issuer mocks base method.
func (m *MockConfiguration) Issuer() string { func (m *MockConfiguration) Issuer() string {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -189,6 +217,76 @@ func (mr *MockConfigurationMockRecorder) KeysEndpoint() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeysEndpoint", reflect.TypeOf((*MockConfiguration)(nil).KeysEndpoint)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeysEndpoint", reflect.TypeOf((*MockConfiguration)(nil).KeysEndpoint))
} }
// RequestObjectSigningAlgorithmsSupported mocks base method.
func (m *MockConfiguration) RequestObjectSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestObjectSigningAlgorithmsSupported")
ret0, _ := ret[0].([]string)
return ret0
}
// RequestObjectSigningAlgorithmsSupported indicates an expected call of RequestObjectSigningAlgorithmsSupported.
func (mr *MockConfigurationMockRecorder) RequestObjectSigningAlgorithmsSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).RequestObjectSigningAlgorithmsSupported))
}
// RequestObjectSupported mocks base method.
func (m *MockConfiguration) RequestObjectSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestObjectSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// RequestObjectSupported indicates an expected call of RequestObjectSupported.
func (mr *MockConfigurationMockRecorder) RequestObjectSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockConfiguration)(nil).RequestObjectSupported))
}
// RevocationAuthMethodPrivateKeyJWTSupported mocks base method.
func (m *MockConfiguration) RevocationAuthMethodPrivateKeyJWTSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationAuthMethodPrivateKeyJWTSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// RevocationAuthMethodPrivateKeyJWTSupported indicates an expected call of RevocationAuthMethodPrivateKeyJWTSupported.
func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationAuthMethodPrivateKeyJWTSupported", reflect.TypeOf((*MockConfiguration)(nil).RevocationAuthMethodPrivateKeyJWTSupported))
}
// RevocationEndpoint mocks base method.
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// RevocationEndpoint indicates an expected call of RevocationEndpoint.
func (mr *MockConfigurationMockRecorder) RevocationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).RevocationEndpoint))
}
// RevocationEndpointSigningAlgorithmsSupported mocks base method.
func (m *MockConfiguration) RevocationEndpointSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationEndpointSigningAlgorithmsSupported")
ret0, _ := ret[0].([]string)
return ret0
}
// RevocationEndpointSigningAlgorithmsSupported indicates an expected call of RevocationEndpointSigningAlgorithmsSupported.
func (mr *MockConfigurationMockRecorder) RevocationEndpointSigningAlgorithmsSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).RevocationEndpointSigningAlgorithmsSupported))
}
// SupportedUILocales mocks base method. // SupportedUILocales mocks base method.
func (m *MockConfiguration) SupportedUILocales() []language.Tag { func (m *MockConfiguration) SupportedUILocales() []language.Tag {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -217,6 +315,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpoint() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpoint", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpoint)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpoint", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpoint))
} }
// TokenEndpointSigningAlgorithmsSupported mocks base method.
func (m *MockConfiguration) TokenEndpointSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TokenEndpointSigningAlgorithmsSupported")
ret0, _ := ret[0].([]string)
return ret0
}
// TokenEndpointSigningAlgorithmsSupported indicates an expected call of TokenEndpointSigningAlgorithmsSupported.
func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
}
// UserinfoEndpoint mocks base method. // UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -5,3 +5,4 @@ package mock
//go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client //go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration //go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer //go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer
//go:generate mockgen -package mock -destination ./key.mock.go github.com/caos/oidc/pkg/op KeyProvider

51
pkg/op/mock/key.mock.go Normal file
View file

@ -0,0 +1,51 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/oidc/pkg/op (interfaces: KeyProvider)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
)
// MockKeyProvider is a mock of KeyProvider interface.
type MockKeyProvider struct {
ctrl *gomock.Controller
recorder *MockKeyProviderMockRecorder
}
// MockKeyProviderMockRecorder is the mock recorder for MockKeyProvider.
type MockKeyProviderMockRecorder struct {
mock *MockKeyProvider
}
// NewMockKeyProvider creates a new mock instance.
func NewMockKeyProvider(ctrl *gomock.Controller) *MockKeyProvider {
mock := &MockKeyProvider{ctrl: ctrl}
mock.recorder = &MockKeyProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockKeyProvider) EXPECT() *MockKeyProviderMockRecorder {
return m.recorder
}
// GetKeySet mocks base method.
func (m *MockKeyProvider) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*jose.JSONWebKeySet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKeySet indicates an expected call of GetKeySet.
func (mr *MockKeyProviderMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockKeyProvider)(nil).GetKeySet), arg0)
}

View file

@ -230,6 +230,20 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
} }
// RevokeToken mocks base method.
func (m *MockStorage) RevokeToken(arg0 context.Context, arg1, arg2, arg3 string) *oidc.Error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*oidc.Error)
return ret0
}
// RevokeToken indicates an expected call of RevokeToken.
func (mr *MockStorageMockRecorder) RevokeToken(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeToken", reflect.TypeOf((*MockStorage)(nil).RevokeToken), arg0, arg1, arg2, arg3)
}
// SaveAuthCode mocks base method. // SaveAuthCode mocks base method.
func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error { func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -12,8 +12,8 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
const ( const (
@ -23,6 +23,7 @@ const (
defaultTokenEndpoint = "oauth/token" defaultTokenEndpoint = "oauth/token"
defaultIntrospectEndpoint = "oauth/introspect" defaultIntrospectEndpoint = "oauth/introspect"
defaultUserinfoEndpoint = "userinfo" defaultUserinfoEndpoint = "userinfo"
defaultRevocationEndpoint = "revoke"
defaultEndSessionEndpoint = "end_session" defaultEndSessionEndpoint = "end_session"
defaultKeysEndpoint = "keys" defaultKeysEndpoint = "keys"
) )
@ -33,6 +34,7 @@ var (
Token: NewEndpoint(defaultTokenEndpoint), Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint),
Userinfo: NewEndpoint(defaultUserinfoEndpoint), Userinfo: NewEndpoint(defaultUserinfoEndpoint),
Revocation: NewEndpoint(defaultRevocationEndpoint),
EndSession: NewEndpoint(defaultEndSessionEndpoint), EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint), JwksURI: NewEndpoint(defaultKeysEndpoint),
} }
@ -41,8 +43,8 @@ var (
type OpenIDProvider interface { type OpenIDProvider interface {
Configuration Configuration
Storage() Storage Storage() Storage
Decoder() utils.Decoder Decoder() httphelper.Decoder
Encoder() utils.Encoder Encoder() httphelper.Encoder
IDTokenHintVerifier() IDTokenHintVerifier IDTokenHintVerifier() IDTokenHintVerifier
AccessTokenVerifier() AccessTokenVerifier AccessTokenVerifier() AccessTokenVerifier
Crypto() Crypto Crypto() Crypto
@ -74,6 +76,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
router.Handle(o.TokenEndpoint().Relative(), intercept(tokenHandler(o))) router.Handle(o.TokenEndpoint().Relative(), intercept(tokenHandler(o)))
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o))) router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o)))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage())) router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
return router return router
@ -84,8 +87,10 @@ type Config struct {
CryptoKey [32]byte CryptoKey [32]byte
DefaultLogoutRedirectURI string DefaultLogoutRedirectURI string
CodeMethodS256 bool CodeMethodS256 bool
AuthMethodPost bool
AuthMethodPrivateKeyJWT bool AuthMethodPrivateKeyJWT bool
GrantTypeRefreshToken bool GrantTypeRefreshToken bool
RequestObjectSupported bool
SupportedUILocales []language.Tag SupportedUILocales []language.Tag
} }
@ -94,6 +99,7 @@ type endpoints struct {
Token Endpoint Token Endpoint
Introspection Endpoint Introspection Endpoint
Userinfo Endpoint Userinfo Endpoint
Revocation Endpoint
EndSession Endpoint EndSession Endpoint
CheckSessionIframe Endpoint CheckSessionIframe Endpoint
JwksURI Endpoint JwksURI Endpoint
@ -148,7 +154,6 @@ type openidProvider struct {
decoder *schema.Decoder decoder *schema.Decoder
encoder *schema.Encoder encoder *schema.Encoder
interceptors []HttpInterceptor interceptors []HttpInterceptor
retry func(int) (bool, int)
timer <-chan time.Time timer <-chan time.Time
} }
@ -172,6 +177,10 @@ func (o *openidProvider) UserinfoEndpoint() Endpoint {
return o.endpoints.Userinfo return o.endpoints.Userinfo
} }
func (o *openidProvider) RevocationEndpoint() Endpoint {
return o.endpoints.Revocation
}
func (o *openidProvider) EndSessionEndpoint() Endpoint { func (o *openidProvider) EndSessionEndpoint() Endpoint {
return o.endpoints.EndSession return o.endpoints.EndSession
} }
@ -181,7 +190,7 @@ func (o *openidProvider) KeysEndpoint() Endpoint {
} }
func (o *openidProvider) AuthMethodPostSupported() bool { func (o *openidProvider) AuthMethodPostSupported() bool {
return true //todo: config return o.config.AuthMethodPost
} }
func (o *openidProvider) CodeMethodS256Supported() bool { func (o *openidProvider) CodeMethodS256Supported() bool {
@ -192,6 +201,10 @@ func (o *openidProvider) AuthMethodPrivateKeyJWTSupported() bool {
return o.config.AuthMethodPrivateKeyJWT return o.config.AuthMethodPrivateKeyJWT
} }
func (o *openidProvider) TokenEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) GrantTypeRefreshTokenSupported() bool { func (o *openidProvider) GrantTypeRefreshTokenSupported() bool {
return o.config.GrantTypeRefreshToken return o.config.GrantTypeRefreshToken
} }
@ -204,6 +217,30 @@ func (o *openidProvider) GrantTypeJWTAuthorizationSupported() bool {
return true return true
} }
func (o *openidProvider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
return true
}
func (o *openidProvider) IntrospectionEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) RevocationAuthMethodPrivateKeyJWTSupported() bool {
return true
}
func (o *openidProvider) RevocationEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) RequestObjectSupported() bool {
return o.config.RequestObjectSupported
}
func (o *openidProvider) RequestObjectSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) SupportedUILocales() []language.Tag { func (o *openidProvider) SupportedUILocales() []language.Tag {
return o.config.SupportedUILocales return o.config.SupportedUILocales
} }
@ -212,11 +249,11 @@ func (o *openidProvider) Storage() Storage {
return o.storage return o.storage
} }
func (o *openidProvider) Decoder() utils.Decoder { func (o *openidProvider) Decoder() httphelper.Decoder {
return o.decoder return o.decoder
} }
func (o *openidProvider) Encoder() utils.Encoder { func (o *openidProvider) Encoder() httphelper.Encoder {
return o.encoder return o.encoder
} }
@ -332,6 +369,16 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
if err := endpoint.Validate(); err != nil {
return err
}
o.endpoints.Revocation = endpoint
return nil
}
}
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error { return func(o *openidProvider) error {
if err := endpoint.Validate(); err != nil { if err := endpoint.Validate(); err != nil {
@ -352,11 +399,12 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
} }
} }
func WithCustomEndpoints(auth, token, userInfo, endSession, keys Endpoint) Option { func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
return func(o *openidProvider) error { return func(o *openidProvider) error {
o.endpoints.Authorization = auth o.endpoints.Authorization = auth
o.endpoints.Token = token o.endpoints.Token = token
o.endpoints.Userinfo = userInfo o.endpoints.Userinfo = userInfo
o.endpoints.Revocation = revocation
o.endpoints.EndSession = endSession o.endpoints.EndSession = endSession
o.endpoints.JwksURI = keys o.endpoints.JwksURI = keys
return nil return nil

View file

@ -5,7 +5,7 @@ import (
"errors" "errors"
"net/http" "net/http"
"github.com/caos/oidc/pkg/utils" httphelper "github.com/caos/oidc/pkg/http"
) )
type ProbesFn func(context.Context) error type ProbesFn func(context.Context) error
@ -49,7 +49,7 @@ func ReadyStorage(s Storage) ProbesFn {
} }
func ok(w http.ResponseWriter) { func ok(w http.ResponseWriter) {
utils.MarshalJSON(w, status{"ok"}) httphelper.MarshalJSON(w, status{"ok"})
} }
type status struct { type status struct {

View file

@ -4,12 +4,12 @@ import (
"context" "context"
"net/http" "net/http"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
type SessionEnder interface { type SessionEnder interface {
Decoder() utils.Decoder Decoder() httphelper.Decoder
Storage() Storage Storage() Storage
IDTokenHintVerifier() IDTokenHintVerifier IDTokenHintVerifier() IDTokenHintVerifier
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
@ -38,21 +38,21 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
} }
err = ender.Storage().TerminateSession(r.Context(), session.UserID, clientID) err = ender.Storage().TerminateSession(r.Context(), session.UserID, clientID)
if err != nil { if err != nil {
RequestError(w, r, ErrServerError("error terminating session")) RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"))
return return
} }
http.Redirect(w, r, session.RedirectURI, http.StatusFound) http.Redirect(w, r, session.RedirectURI, http.StatusFound)
} }
func ParseEndSessionRequest(r *http.Request, decoder utils.Decoder) (*oidc.EndSessionRequest, error) { func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.EndSessionRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error parsing form") return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
} }
req := new(oidc.EndSessionRequest) req := new(oidc.EndSessionRequest)
err = decoder.Decode(req, r.Form) err = decoder.Decode(req, r.Form)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error decoding form") return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
} }
return req, nil return req, nil
} }
@ -64,12 +64,12 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
} }
claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier()) claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier())
if err != nil { if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid") return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
} }
session.UserID = claims.GetSubject() session.UserID = claims.GetSubject()
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty()) session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
if err != nil { if err != nil {
return nil, ErrServerError("") return nil, oidc.DefaultToServerError(err, "")
} }
if req.PostLogoutRedirectURI == "" { if req.PostLogoutRedirectURI == "" {
session.RedirectURI = ender.DefaultLogoutRedirectURI() session.RedirectURI = ender.DefaultLogoutRedirectURI()
@ -81,5 +81,5 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
return session, nil return session, nil
} }
} }
return nil, ErrInvalidRequest("post_logout_redirect_uri invalid") return nil, oidc.ErrInvalidRequest().WithDescription("post_logout_redirect_uri invalid")
} }

View file

@ -20,7 +20,8 @@ type AuthStorage interface {
CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error)
TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (RefreshTokenRequest, error) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (RefreshTokenRequest, error)
TerminateSession(context.Context, string, string) error TerminateSession(ctx context.Context, userID string, clientID string) error
RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error
GetSigningKey(context.Context, chan<- jose.SigningKey) GetSigningKey(context.Context, chan<- jose.SigningKey)
GetKeySet(context.Context) (*jose.JSONWebKeySet, error) GetKeySet(context.Context) (*jose.JSONWebKeySet, error)

View file

@ -4,8 +4,9 @@ import (
"context" "context"
"time" "time"
"github.com/caos/oidc/pkg/crypto"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/strings"
) )
type TokenCreator interface { type TokenCreator interface {
@ -64,7 +65,7 @@ func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storag
func needsRefreshToken(tokenRequest TokenRequest, client Client) bool { func needsRefreshToken(tokenRequest TokenRequest, client Client) bool {
switch req := tokenRequest.(type) { switch req := tokenRequest.(type) {
case AuthRequest: case AuthRequest:
return utils.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken) return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
case RefreshTokenRequest: case RefreshTokenRequest:
return true return true
default: default:
@ -104,7 +105,7 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex
} }
claims.SetPrivateClaims(privateClaims) claims.SetPrivateClaims(privateClaims)
} }
return utils.Sign(claims, signer.Signer()) return crypto.Sign(claims, signer.Signer())
} }
type IDTokenRequest interface { type IDTokenRequest interface {
@ -151,7 +152,7 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
claims.SetCodeHash(codeHash) claims.SetCodeHash(codeHash)
} }
return utils.Sign(claims, signer.Signer()) return crypto.Sign(claims, signer.Signer())
} }
func removeUserinfoScopes(scopes []string) []string { func removeUserinfoScopes(scopes []string) []string {
@ -167,5 +168,5 @@ func removeUserinfoScopes(scopes []string) []string {
newScopeList = append(newScopeList, scope) newScopeList = append(newScopeList, scope)
} }
} }
return scopes return newScopeList
} }

View file

@ -2,11 +2,10 @@ package op
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
//CodeExchange handles the OAuth 2.0 authorization_code grant, including //CodeExchange handles the OAuth 2.0 authorization_code grant, including
@ -17,7 +16,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
RequestError(w, r, err) RequestError(w, r, err)
} }
if tokenReq.Code == "" { if tokenReq.Code == "" {
RequestError(w, r, ErrInvalidRequest("code missing")) RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"))
return return
} }
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
@ -30,11 +29,11 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
utils.MarshalJSON(w, resp) httphelper.MarshalJSON(w, resp)
} }
//ParseAccessTokenRequest parsed the http request into a oidc.AccessTokenRequest //ParseAccessTokenRequest parsed the http request into a oidc.AccessTokenRequest
func ParseAccessTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.AccessTokenRequest, error) { func ParseAccessTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.AccessTokenRequest, error) {
request := new(oidc.AccessTokenRequest) request := new(oidc.AccessTokenRequest)
err := ParseAuthenticatedTokenRequest(r, decoder, request) err := ParseAuthenticatedTokenRequest(r, decoder, request)
if err != nil { if err != nil {
@ -51,13 +50,13 @@ func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenR
return nil, nil, err return nil, nil, err
} }
if client.GetID() != authReq.GetClientID() { if client.GetID() != authReq.GetClientID() {
return nil, nil, ErrInvalidRequest("invalid auth code") return nil, nil, oidc.ErrInvalidGrant()
} }
if !ValidateGrantType(client, oidc.GrantTypeCode) { if !ValidateGrantType(client, oidc.GrantTypeCode) {
return nil, nil, ErrInvalidRequest("invalid_grant") return nil, nil, oidc.ErrUnauthorizedClient()
} }
if tokenReq.RedirectURI != authReq.GetRedirectURI() { if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return nil, nil, ErrInvalidRequest("redirect_uri does not correspond") return nil, nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
} }
return authReq, client, nil return authReq, client, nil
} }
@ -68,7 +67,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger) jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger)
if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() { if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() {
return nil, nil, errors.New("auth_method private_key_jwt not supported") return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
} }
client, err = AuthorizePrivateJWTKey(ctx, tokenReq.ClientAssertion, jwtExchanger) client, err = AuthorizePrivateJWTKey(ctx, tokenReq.ClientAssertion, jwtExchanger)
if err != nil { if err != nil {
@ -79,10 +78,10 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
} }
client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID) client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, oidc.ErrInvalidClient().WithParent(err)
} }
if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT { if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT {
return nil, nil, errors.New("invalid_grant") return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
} }
if client.AuthMethod() == oidc.AuthMethodNone { if client.AuthMethod() == oidc.AuthMethodNone {
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
@ -93,9 +92,12 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
return request, client, err return request, client, err
} }
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("auth_method post not supported") return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
} }
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()) err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
if err != nil {
return nil, nil, err
}
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code) request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
return request, client, err return request, client, err
} }
@ -104,7 +106,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
func AuthRequestByCode(ctx context.Context, storage Storage, code string) (AuthRequest, error) { func AuthRequestByCode(ctx context.Context, storage Storage, code string) (AuthRequest, error) {
authReq, err := storage.AuthRequestByCode(ctx, code) authReq, err := storage.AuthRequestByCode(ctx, code)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("invalid code") return nil, oidc.ErrInvalidGrant().WithDescription("invalid code").WithParent(err)
} }
return authReq, nil return authReq, nil
} }

View file

@ -3,28 +3,9 @@ package op
import ( import (
"errors" "errors"
"net/http" "net/http"
"github.com/caos/oidc/pkg/oidc"
) )
//TokenExchange will handle the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange") //TokenExchange will handle the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange")
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenRequest, err := ParseTokenExchangeRequest(w, r) RequestError(w, r, errors.New("unimplemented"))
if err != nil {
RequestError(w, r, err)
return
}
err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage())
if err != nil {
RequestError(w, r, err)
return
}
}
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
return nil, errors.New("Unimplemented") //TODO: impl
}
func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
return errors.New("Unimplemented") //TODO: impl
} }

View file

@ -5,12 +5,12 @@ import (
"net/http" "net/http"
"net/url" "net/url"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
type Introspector interface { type Introspector interface {
Decoder() utils.Decoder Decoder() httphelper.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
AccessTokenVerifier() AccessTokenVerifier AccessTokenVerifier() AccessTokenVerifier
@ -36,16 +36,16 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
} }
tokenID, subject, ok := getTokenIDAndSubject(r.Context(), introspector, token) tokenID, subject, ok := getTokenIDAndSubject(r.Context(), introspector, token)
if !ok { if !ok {
utils.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
return return
} }
err = introspector.Storage().SetIntrospectionFromToken(r.Context(), response, tokenID, subject, clientID) err = introspector.Storage().SetIntrospectionFromToken(r.Context(), response, tokenID, subject, clientID)
if err != nil { if err != nil {
utils.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
return return
} }
response.SetActive(true) response.SetActive(true)
utils.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
} }
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) { func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {

View file

@ -5,8 +5,8 @@ import (
"net/http" "net/http"
"time" "time"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
type JWTAuthorizationGrantExchanger interface { type JWTAuthorizationGrantExchanger interface {
@ -37,18 +37,18 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizati
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
utils.MarshalJSON(w, resp) httphelper.MarshalJSON(w, resp)
} }
func ParseJWTProfileGrantRequest(r *http.Request, decoder utils.Decoder) (*oidc.JWTProfileGrantRequest, error) { func ParseJWTProfileGrantRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error parsing form") return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
} }
tokenReq := new(oidc.JWTProfileGrantRequest) tokenReq := new(oidc.JWTProfileGrantRequest)
err = decoder.Decode(tokenReq, r.Form) err = decoder.Decode(tokenReq, r.Form)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error decoding form") return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
} }
return tokenReq, nil return tokenReq, nil
} }
@ -74,6 +74,6 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea
//ParseJWTProfileRequest has been renamed to ParseJWTProfileGrantRequest //ParseJWTProfileRequest has been renamed to ParseJWTProfileGrantRequest
// //
//deprecated: use ParseJWTProfileGrantRequest //deprecated: use ParseJWTProfileGrantRequest
func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*oidc.JWTProfileGrantRequest, error) { func ParseJWTProfileRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) {
return ParseJWTProfileGrantRequest(r, decoder) return ParseJWTProfileGrantRequest(r, decoder)
} }

View file

@ -6,8 +6,9 @@ import (
"net/http" "net/http"
"time" "time"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/strings"
) )
type RefreshTokenRequest interface { type RefreshTokenRequest interface {
@ -37,11 +38,11 @@ func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exch
RequestError(w, r, err) RequestError(w, r, err)
return return
} }
utils.MarshalJSON(w, resp) httphelper.MarshalJSON(w, resp)
} }
//ParseRefreshTokenRequest parsed the http request into a oidc.RefreshTokenRequest //ParseRefreshTokenRequest parsed the http request into a oidc.RefreshTokenRequest
func ParseRefreshTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.RefreshTokenRequest, error) { func ParseRefreshTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.RefreshTokenRequest, error) {
request := new(oidc.RefreshTokenRequest) request := new(oidc.RefreshTokenRequest)
err := ParseAuthenticatedTokenRequest(r, decoder, request) err := ParseAuthenticatedTokenRequest(r, decoder, request)
if err != nil { if err != nil {
@ -54,14 +55,14 @@ func ParseRefreshTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.Ref
//and returns the data representing the original auth request corresponding to the refresh_token //and returns the data representing the original auth request corresponding to the refresh_token
func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (RefreshTokenRequest, Client, error) { func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (RefreshTokenRequest, Client, error) {
if tokenReq.RefreshToken == "" { if tokenReq.RefreshToken == "" {
return nil, nil, ErrInvalidRequest("code missing") return nil, nil, oidc.ErrInvalidRequest().WithDescription("refresh_token missing")
} }
request, client, err := AuthorizeRefreshClient(ctx, tokenReq, exchanger) request, client, err := AuthorizeRefreshClient(ctx, tokenReq, exchanger)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if client.GetID() != request.GetClientID() { if client.GetID() != request.GetClientID() {
return nil, nil, ErrInvalidRequest("invalid auth code") return nil, nil, oidc.ErrInvalidGrant()
} }
if err = ValidateRefreshTokenScopes(tokenReq.Scopes, request); err != nil { if err = ValidateRefreshTokenScopes(tokenReq.Scopes, request); err != nil {
return nil, nil, err return nil, nil, err
@ -77,15 +78,15 @@ func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTok
return nil return nil
} }
for _, scope := range requestedScopes { for _, scope := range requestedScopes {
if !utils.Contains(authRequest.GetScopes(), scope) { if !strings.Contains(authRequest.GetScopes(), scope) {
return errors.New("invalid_scope") return oidc.ErrInvalidScope()
} }
} }
authRequest.SetCurrentScopes(requestedScopes) authRequest.SetCurrentScopes(requestedScopes)
return nil return nil
} }
//AuthorizeCodeClient checks the authorization of the client and that the used method was the one previously registered. //AuthorizeRefreshClient checks the authorization of the client and that the used method was the one previously registered.
//It than returns the data representing the original auth request corresponding to the refresh_token //It than returns the data representing the original auth request corresponding to the refresh_token
func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (request RefreshTokenRequest, client Client, err error) { func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (request RefreshTokenRequest, client Client, err error) {
if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
@ -98,7 +99,7 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ
return nil, nil, err return nil, nil, err
} }
if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) { if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
return nil, nil, ErrInvalidRequest("invalid_grant") return nil, nil, oidc.ErrUnauthorizedClient()
} }
request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken) request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
return request, client, err return request, client, err
@ -108,17 +109,17 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ
return nil, nil, err return nil, nil, err
} }
if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) { if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
return nil, nil, ErrInvalidRequest("invalid_grant") return nil, nil, oidc.ErrUnauthorizedClient()
} }
if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT { if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT {
return nil, nil, errors.New("invalid_grant") return nil, nil, oidc.ErrInvalidClient()
} }
if client.AuthMethod() == oidc.AuthMethodNone { if client.AuthMethod() == oidc.AuthMethodNone {
request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken) request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
return request, client, err return request, client, err
} }
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("auth_method post not supported") return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
} }
if err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()); err != nil { if err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()); err != nil {
return nil, nil, err return nil, nil, err
@ -132,7 +133,7 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ
func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) { func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) {
request, err := storage.TokenRequestByRefreshToken(ctx, refreshToken) request, err := storage.TokenRequestByRefreshToken(ctx, refreshToken)
if err != nil { if err != nil {
return nil, ErrInvalidRequest("invalid refreshToken") return nil, oidc.ErrInvalidGrant().WithParent(err)
} }
return request, nil return request, nil
} }

View file

@ -5,14 +5,14 @@ import (
"net/http" "net/http"
"net/url" "net/url"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
type Exchanger interface { type Exchanger interface {
Issuer() string Issuer() string
Storage() Storage Storage() Storage
Decoder() utils.Decoder Decoder() httphelper.Decoder
Signer() Signer Signer() Signer
Crypto() Crypto Crypto() Crypto
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
@ -24,7 +24,8 @@ type Exchanger interface {
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
switch r.FormValue("grant_type") { grantType := r.FormValue("grant_type")
switch grantType {
case string(oidc.GrantTypeCode): case string(oidc.GrantTypeCode):
CodeExchange(w, r, exchanger) CodeExchange(w, r, exchanger)
return return
@ -44,14 +45,14 @@ func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Reque
return return
} }
case "": case "":
RequestError(w, r, ErrInvalidRequest("grant_type missing")) RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
return return
} }
RequestError(w, r, ErrInvalidRequest("grant_type not supported")) RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType))
} }
} }
//authenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest //AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest
//it is implemented by oidc.AuthRequest and oidc.RefreshTokenRequest //it is implemented by oidc.AuthRequest and oidc.RefreshTokenRequest
type AuthenticatedTokenRequest interface { type AuthenticatedTokenRequest interface {
SetClientID(string) SetClientID(string)
@ -60,48 +61,49 @@ type AuthenticatedTokenRequest interface {
//ParseAuthenticatedTokenRequest parses the client_id and client_secret from the HTTP request from either //ParseAuthenticatedTokenRequest parses the client_id and client_secret from the HTTP request from either
//HTTP Basic Auth header or form body and sets them into the provided authenticatedTokenRequest interface //HTTP Basic Auth header or form body and sets them into the provided authenticatedTokenRequest interface
func ParseAuthenticatedTokenRequest(r *http.Request, decoder utils.Decoder, request AuthenticatedTokenRequest) error { func ParseAuthenticatedTokenRequest(r *http.Request, decoder httphelper.Decoder, request AuthenticatedTokenRequest) error {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return ErrInvalidRequest("error parsing form") return oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
} }
err = decoder.Decode(request, r.Form) err = decoder.Decode(request, r.Form)
if err != nil { if err != nil {
return ErrInvalidRequest("error decoding form") return oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
} }
clientID, clientSecret, ok := r.BasicAuth() clientID, clientSecret, ok := r.BasicAuth()
if ok { if !ok {
return nil
}
clientID, err = url.QueryUnescape(clientID) clientID, err = url.QueryUnescape(clientID)
if err != nil { if err != nil {
return ErrInvalidRequest("invalid basic auth header") return oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
} }
clientSecret, err = url.QueryUnescape(clientSecret) clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil { if err != nil {
return ErrInvalidRequest("invalid basic auth header") return oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
} }
request.SetClientID(clientID) request.SetClientID(clientID)
request.SetClientSecret(clientSecret) request.SetClientSecret(clientSecret)
}
return nil return nil
} }
//AuthorizeRefreshClientByClientIDSecret authorizes a client by validating the client_id and client_secret (Basic Auth and POST) //AuthorizeClientIDSecret authorizes a client by validating the client_id and client_secret (Basic Auth and POST)
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage Storage) error { func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage Storage) error {
err := storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret) err := storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
if err != nil { if err != nil {
return err //TODO: wrap? return oidc.ErrInvalidClient().WithDescription("invalid client_id / client_secret").WithParent(err)
} }
return nil return nil
} }
//AuthorizeCodeClientByCodeChallenge authorizes a client by validating the code_verifier against the previously sent //AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent
//code_challenge of the auth request (PKCE) //code_challenge of the auth request (PKCE)
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error { func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error {
if tokenReq.CodeVerifier == "" { if tokenReq.CodeVerifier == "" {
return ErrInvalidRequest("code_challenge required") return oidc.ErrInvalidRequest().WithDescription("code_challenge required")
} }
if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) { if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) {
return ErrInvalidRequest("code_challenge invalid") return oidc.ErrInvalidGrant().WithDescription("invalid code challenge")
} }
return nil return nil
} }
@ -118,7 +120,7 @@ func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchang
return nil, err return nil, err
} }
if client.AuthMethod() != oidc.AuthMethodPrivateKeyJWT { if client.AuthMethod() != oidc.AuthMethodPrivateKeyJWT {
return nil, ErrInvalidRequest("invalid_client") return nil, oidc.ErrInvalidClient()
} }
return client, nil return client, nil
} }

136
pkg/op/token_revocation.go Normal file
View file

@ -0,0 +1,136 @@
package op
import (
"context"
"net/http"
"net/url"
"strings"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
)
type Revoker interface {
Decoder() httphelper.Decoder
Crypto() Crypto
Storage() Storage
AccessTokenVerifier() AccessTokenVerifier
AuthMethodPrivateKeyJWTSupported() bool
AuthMethodPostSupported() bool
}
type RevokerJWTProfile interface {
Revoker
JWTProfileVerifier() JWTProfileVerifier
}
func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Revoke(w, r, revoker)
}
}
func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) {
token, _, clientID, err := ParseTokenRevocationRequest(r, revoker)
if err != nil {
RevocationRequestError(w, r, err)
return
}
tokenID, subject, ok := getTokenIDAndSubjectForRevocation(r.Context(), revoker, token)
if ok {
token = tokenID
}
if err := revoker.Storage().RevokeToken(r.Context(), token, subject, clientID); err != nil {
RevocationRequestError(w, r, err)
return
}
httphelper.MarshalJSON(w, nil)
}
func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, tokenTypeHint, clientID string, err error) {
err = r.ParseForm()
if err != nil {
return "", "", "", oidc.ErrInvalidRequest().WithDescription("unable to parse request").WithParent(err)
}
req := new(struct {
oidc.RevocationRequest
oidc.ClientAssertionParams //for auth_method private_key_jwt
ClientID string `schema:"client_id"` //for auth_method none and post
ClientSecret string `schema:"client_secret"` //for auth_method post
})
err = revoker.Decoder().Decode(req, r.Form)
if err != nil {
return "", "", "", oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
if req.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
revokerJWTProfile, ok := revoker.(RevokerJWTProfile)
if !ok || !revoker.AuthMethodPrivateKeyJWTSupported() {
return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
}
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier())
if err == nil {
return req.Token, req.TokenTypeHint, profile.Issuer, nil
}
return "", "", "", err
}
clientID, clientSecret, ok := r.BasicAuth()
if ok {
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
if err = AuthorizeClientIDSecret(r.Context(), clientID, clientSecret, revoker.Storage()); err != nil {
return "", "", "", err
}
return req.Token, req.TokenTypeHint, clientID, nil
}
if req.ClientID == "" {
return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid authorization")
}
client, err := revoker.Storage().GetClientByClientID(r.Context(), req.ClientID)
if err != nil {
return "", "", "", oidc.ErrInvalidClient().WithParent(err)
}
if req.ClientSecret == "" {
if client.AuthMethod() != oidc.AuthMethodNone {
return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid authorization")
}
return req.Token, req.TokenTypeHint, req.ClientID, nil
}
if client.AuthMethod() == oidc.AuthMethodPost && !revoker.AuthMethodPostSupported() {
return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
}
if err = AuthorizeClientIDSecret(r.Context(), req.ClientID, req.ClientSecret, revoker.Storage()); err != nil {
return "", "", "", err
}
return req.Token, req.TokenTypeHint, req.ClientID, nil
}
func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) {
e := oidc.DefaultToServerError(err, err.Error())
status := http.StatusBadRequest
if e.ErrorType == oidc.InvalidClient {
status = 401
}
httphelper.MarshalJSONWithStatus(w, e, status)
}
func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
if err == nil {
splitToken := strings.Split(tokenIDSubject, ":")
if len(splitToken) != 2 {
return "", "", false
}
return splitToken[0], splitToken[1], true
}
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
if err != nil {
return "", "", false
}
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
}

View file

@ -6,12 +6,12 @@ import (
"net/http" "net/http"
"strings" "strings"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
) )
type UserinfoProvider interface { type UserinfoProvider interface {
Decoder() utils.Decoder Decoder() httphelper.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
AccessTokenVerifier() AccessTokenVerifier AccessTokenVerifier() AccessTokenVerifier
@ -37,14 +37,13 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
info := oidc.NewUserInfo() info := oidc.NewUserInfo()
err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin")) err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin"))
if err != nil { if err != nil {
w.WriteHeader(http.StatusForbidden) httphelper.MarshalJSONWithStatus(w, err, http.StatusForbidden)
utils.MarshalJSON(w, err)
return return
} }
utils.MarshalJSON(w, info) httphelper.MarshalJSON(w, info)
} }
func ParseUserinfoRequest(r *http.Request, decoder utils.Decoder) (string, error) { func ParseUserinfoRequest(r *http.Request, decoder httphelper.Decoder) (string, error) {
accessToken, err := getAccessToken(r) accessToken, err := getAccessToken(r)
if err == nil { if err == nil {
return accessToken, nil return accessToken, nil

View file

@ -49,7 +49,7 @@ func (i *accessTokenVerifier) KeySet() oidc.KeySet {
} }
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier { func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier {
verifier := &idTokenHintVerifier{ verifier := &accessTokenVerifier{
issuer: issuer, issuer: issuer,
keySet: keySet, keySet: keySet,
} }

View file

@ -1,4 +1,4 @@
package utils package strings
func Contains(list []string, needle string) bool { func Contains(list []string, needle string) bool {
for _, item := range list { for _, item := range list {

View file

@ -1,4 +1,4 @@
package utils package strings
import "testing" import "testing"