feat: Token Revocation, Request Object and OP Certification (#130)

FEATURES (and FIXES):
- support OAuth 2.0 Token Revocation [RFC 7009](https://datatracker.ietf.org/doc/html/rfc7009)
- handle request object using `request` parameter [OIDC Core 1.0 Request Object](https://openid.net/specs/openid-connect-core-1_0.html#RequestObject)
- handle response mode
- added some information to the discovery endpoint:
  - revocation_endpoint (added with token revocation) 
  - revocation_endpoint_auth_methods_supported (added with token revocation)
  - revocation_endpoint_auth_signing_alg_values_supported (added with token revocation)
  - token_endpoint_auth_signing_alg_values_supported (was missing)
  - introspection_endpoint_auth_signing_alg_values_supported (was missing)
  - request_object_signing_alg_values_supported (added with request object)
  - request_parameter_supported (added with request object)
 - fixed `removeUserinfoScopes ` now returns the scopes without "userinfo" scopes (profile, email, phone, addedd) [source diff](https://github.com/caos/oidc/pull/130/files#diff-fad50c8c0f065d4dbc49d6c6a38f09c992c8f5d651a479ba00e31b500543559eL170-R171)
- improved error handling (pkg/oidc/error.go) and fixed some wrong OAuth errors (e.g. `invalid_grant` instead of `invalid_request`)
- improved MarshalJSON and added MarshalJSONWithStatus
- removed deprecated PEM decryption from `BytesToPrivateKey`  [source diff](https://github.com/caos/oidc/pull/130/files#diff-fe246e428e399ccff599627c71764de51387b60b4df84c67de3febd0954e859bL11-L19)
- NewAccessTokenVerifier now uses correct (internal) `accessTokenVerifier` [source diff](https://github.com/caos/oidc/pull/130/files#diff-3a01c7500ead8f35448456ef231c7c22f8d291710936cac91de5edeef52ffc72L52-R52)

BREAKING CHANGE:
- move functions from `utils` package into separate packages
- added various methods to the (OP) `Configuration` interface [source diff](https://github.com/caos/oidc/pull/130/files#diff-2538e0dfc772fdc37f057aecd6fcc2943f516c24e8be794cce0e368a26d20a82R19-R32)
- added revocationEndpoint to `WithCustomEndpoints ` [source diff](https://github.com/caos/oidc/pull/130/files#diff-19ae13a743eb7cebbb96492798b1bec556673eb6236b1387e38d722900bae1c3L355-R391)
- remove unnecessary context parameter from JWTProfileExchange [source diff](https://github.com/caos/oidc/pull/130/files#diff-4ed8f6affa4a9631fa8a034b3d5752fbb6a819107141aae00029014e950f7b4cL14)
This commit is contained in:
Livio Amstutz 2021-11-02 13:21:35 +01:00 committed by GitHub
parent 763d3334e7
commit eb10752e48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
63 changed files with 1738 additions and 624 deletions

View file

@ -2,7 +2,6 @@ package op
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
@ -11,8 +10,9 @@ import (
"github.com/gorilla/mux"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
str "github.com/caos/oidc/pkg/strings"
)
type AuthRequest interface {
@ -26,6 +26,7 @@ type AuthRequest interface {
GetNonce() string
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetResponseMode() oidc.ResponseMode
GetScopes() []string
GetState() string
GetSubject() string
@ -34,16 +35,17 @@ type AuthRequest interface {
type Authorizer interface {
Storage() Storage
Decoder() utils.Decoder
Encoder() utils.Encoder
Decoder() httphelper.Decoder
Encoder() httphelper.Encoder
Signer() Signer
IDTokenHintVerifier() IDTokenHintVerifier
Crypto() Crypto
Issuer() string
RequestObjectSupported() bool
}
//AuthorizeValidator is an extension of Authorizer interface
//implementing it's own validation mechanism for the auth request
//implementing its own validation mechanism for the auth request
type AuthorizeValidator interface {
Authorizer
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error)
@ -69,6 +71,13 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(r.Context(), authReq, authorizer.Storage(), authorizer.Issuer())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
}
validation := ValidateAuthRequest
if validater, ok := authorizer.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest
@ -78,33 +87,114 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
if authReq.RequestParam != "" {
AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder())
return
}
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID)
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
return
}
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, err, authorizer.Encoder())
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
return
}
RedirectToLogin(req.GetID(), client, w, r)
}
//ParseAuthorizeRequest parsed the http request into a oidc.AuthRequest
func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) {
//ParseAuthorizeRequest parsed the http request into an oidc.AuthRequest
func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.AuthRequest, error) {
err := r.ParseForm()
if err != nil {
return nil, ErrInvalidRequest("cannot parse form")
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
}
authReq := new(oidc.AuthRequest)
err = decoder.Decode(authReq, r.Form)
if err != nil {
return nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err))
return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse auth request").WithParent(err)
}
return authReq, nil
}
//ParseRequestObject parse the `request` parameter, validates the token including the signature
//and copies the token claims into the auth request
func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) {
requestObject := new(oidc.RequestObject)
payload, err := oidc.ParseToken(authReq.RequestParam, requestObject)
if err != nil {
return nil, err
}
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
return authReq, oidc.ErrInvalidRequest()
}
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
return authReq, oidc.ErrInvalidRequest()
}
if requestObject.Issuer != requestObject.ClientID {
return authReq, oidc.ErrInvalidRequest()
}
if !str.Contains(requestObject.Audience, issuer) {
return authReq, oidc.ErrInvalidRequest()
}
keySet := &jwtProfileKeySet{storage, requestObject.Issuer}
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
return authReq, err
}
CopyRequestObjectToAuthRequest(authReq, requestObject)
return authReq, nil
}
//CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request
//and clears the `RequestParam` of the auth request
func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oidc.RequestObject) {
if str.Contains(authReq.Scopes, oidc.ScopeOpenID) && len(requestObject.Scopes) > 0 {
authReq.Scopes = requestObject.Scopes
}
if requestObject.RedirectURI != "" {
authReq.RedirectURI = requestObject.RedirectURI
}
if requestObject.State != "" {
authReq.State = requestObject.State
}
if requestObject.ResponseMode != "" {
authReq.ResponseMode = requestObject.ResponseMode
}
if requestObject.Nonce != "" {
authReq.Nonce = requestObject.Nonce
}
if requestObject.Display != "" {
authReq.Display = requestObject.Display
}
if len(requestObject.Prompt) > 0 {
authReq.Prompt = requestObject.Prompt
}
if requestObject.MaxAge != nil {
authReq.MaxAge = requestObject.MaxAge
}
if len(requestObject.UILocales) > 0 {
authReq.UILocales = requestObject.UILocales
}
if requestObject.IDTokenHint != "" {
authReq.IDTokenHint = requestObject.IDTokenHint
}
if requestObject.LoginHint != "" {
authReq.LoginHint = requestObject.LoginHint
}
if len(requestObject.ACRValues) > 0 {
authReq.ACRValues = requestObject.ACRValues
}
if requestObject.CodeChallenge != "" {
authReq.CodeChallenge = requestObject.CodeChallenge
}
if requestObject.CodeChallengeMethod != "" {
authReq.CodeChallengeMethod = requestObject.CodeChallengeMethod
}
authReq.RequestParam = ""
}
//ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) {
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
@ -113,7 +203,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
}
client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
if err != nil {
return "", ErrServerError(err.Error())
return "", oidc.DefaultToServerError(err, "unable to retrieve client by id")
}
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil {
@ -132,7 +222,7 @@ func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage
func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) {
for _, prompt := range prompts {
if prompt == oidc.PromptNone && len(prompts) > 1 {
return nil, ErrInvalidRequest("The prompt parameter `none` must only be used as a single value")
return nil, oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value")
}
if prompt == oidc.PromptLogin {
maxAge = oidc.NewMaxAge(0)
@ -144,7 +234,9 @@ func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error)
//ValidateAuthReqScopes validates the passed scopes
func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
if len(scopes) == 0 {
return nil, ErrInvalidRequest("The scope of your request is missing. Please ensure some scopes are requested. If you have any questions, you may contact the administrator of the application.")
return nil, oidc.ErrInvalidRequest().
WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " +
"If you have any questions, you may contact the administrator of the application.")
}
openID := false
for i := len(scopes) - 1; i >= 0; i-- {
@ -165,7 +257,9 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
}
}
if !openID {
return nil, ErrInvalidRequest("The scope openid is missing in your request. Please ensure the scope openid is added to the request. If you have any questions, you may contact the administrator of the application.")
return nil, oidc.ErrInvalidScope().WithDescription("The scope openid is missing in your request. " +
"Please ensure the scope openid is added to the request. " +
"If you have any questions, you may contact the administrator of the application.")
}
return scopes, nil
@ -174,19 +268,23 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
//ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error {
if uri == "" {
return ErrInvalidRequestRedirectURI("The redirect_uri is missing in the request. Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
return oidc.ErrInvalidRequestRedirectURI().WithDescription("The redirect_uri is missing in the request. " +
"Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
}
if strings.HasPrefix(uri, "https://") {
if !utils.Contains(client.RedirectURIs(), uri) {
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.")
if !str.Contains(client.RedirectURIs(), uri) {
return oidc.ErrInvalidRequestRedirectURI().
WithDescription("The requested redirect_uri is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
}
return nil
}
if client.ApplicationType() == ApplicationTypeNative {
return validateAuthReqRedirectURINative(client, uri, responseType)
}
if !utils.Contains(client.RedirectURIs(), uri) {
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.")
if !str.Contains(client.RedirectURIs(), uri) {
return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
}
if strings.HasPrefix(uri, "http://") {
if client.DevMode() {
@ -195,23 +293,27 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
if responseType == oidc.ResponseTypeCode && IsConfidentialType(client) {
return nil
}
return ErrInvalidRequest("This client's redirect_uri is http and is not allowed. If you have any questions, you may contact the administrator of the application.")
return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is http and is not allowed. " +
"If you have any questions, you may contact the administrator of the application.")
}
return ErrInvalidRequest("This client's redirect_uri is using a custom schema and is not allowed. If you have any questions, you may contact the administrator of the application.")
return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is using a custom schema and is not allowed. " +
"If you have any questions, you may contact the administrator of the application.")
}
//ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type
func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error {
parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri)
isCustomSchema := !strings.HasPrefix(uri, "http://")
if utils.Contains(client.RedirectURIs(), uri) {
if str.Contains(client.RedirectURIs(), uri) {
if isLoopback || isCustomSchema {
return nil
}
return ErrInvalidRequest("This client's redirect_uri is http and is not allowed. If you have any questions, you may contact the administrator of the application.")
return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is http and is not allowed. " +
"If you have any questions, you may contact the administrator of the application.")
}
if !isLoopback {
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.")
return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
}
for _, uri := range client.RedirectURIs() {
redirectURI, ok := HTTPLoopbackOrLocalhost(uri)
@ -219,7 +321,8 @@ func validateAuthReqRedirectURINative(client Client, uri string, responseType oi
return nil
}
}
return ErrInvalidRequestRedirectURI("The requested redirect_uri is missing in the client configuration. If you have any questions, you may contact the administrator of the application.")
return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration." +
" If you have any questions, you may contact the administrator of the application.")
}
func equalURI(url1, url2 *url.URL) bool {
@ -241,10 +344,12 @@ func HTTPLoopbackOrLocalhost(rawurl string) (*url.URL, bool) {
//ValidateAuthReqResponseType validates the passed response_type to the registered response types
func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error {
if responseType == "" {
return ErrInvalidRequest("The response type is missing in your request. If you have any questions, you may contact the administrator of the application.")
return oidc.ErrInvalidRequest().WithDescription("The response type is missing in your request. " +
"If you have any questions, you may contact the administrator of the application.")
}
if !ContainsResponseType(client.ResponseTypes(), responseType) {
return ErrInvalidRequest("The requested response type is missing in the client configuration. If you have any questions, you may contact the administrator of the application.")
return oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " +
"If you have any questions, you may contact the administrator of the application.")
}
return nil
}
@ -257,7 +362,8 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
}
claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier)
if err != nil {
return "", ErrInvalidRequest("The id_token_hint is invalid. If you have any questions, you may contact the administrator of the application.")
return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " +
"If you have any questions, you may contact the administrator of the application.")
}
return claims.GetSubject(), nil
}
@ -279,7 +385,9 @@ func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Author
return
}
if !authReq.Done() {
AuthRequestError(w, r, authReq, ErrInteractionRequired("Unfortunately, the user may is not logged in and/or additional interaction is required."), authorizer.Encoder())
AuthRequestError(w, r, authReq,
oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."),
authorizer.Encoder())
return
}
AuthResponse(authReq, authorizer, w, r)
@ -306,9 +414,17 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
if authReq.GetState() != "" {
callback = callback + "&state=" + authReq.GetState()
codeResponse := struct {
code string
state string
}{
code: code,
state: authReq.GetState(),
}
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
http.Redirect(w, r, callback, http.StatusFound)
}
@ -321,12 +437,11 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
http.Redirect(w, r, callback, http.StatusFound)
}
@ -346,3 +461,22 @@ func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Sto
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
return crypto.Encrypt(authReq.GetID())
}
//AuthResponseURL encodes the authorization response (successful and error) and sets it as query or fragment values
//depending on the response_mode and response_type
func AuthResponseURL(redirectURI string, responseType oidc.ResponseType, responseMode oidc.ResponseMode, response interface{}, encoder httphelper.Encoder) (string, error) {
params, err := httphelper.URLEncodeResponse(response, encoder)
if err != nil {
return "", oidc.ErrServerError().WithParent(err)
}
if responseMode == oidc.ResponseModeQuery {
return redirectURI + "?" + params, nil
}
if responseMode == oidc.ResponseModeFragment {
return redirectURI + "#" + params, nil
}
if responseType == "" || responseType == oidc.ResponseTypeCode {
return redirectURI + "?" + params, nil
}
return redirectURI + "#" + params, nil
}

View file

@ -1,6 +1,8 @@
package op_test
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"net/url"
@ -11,10 +13,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
"github.com/caos/oidc/pkg/op/mock"
"github.com/caos/oidc/pkg/utils"
)
//
@ -75,7 +77,7 @@ import (
func TestParseAuthorizeRequest(t *testing.T) {
type args struct {
r *http.Request
decoder utils.Decoder
decoder httphelper.Decoder
}
type res struct {
want *oidc.AuthRequest
@ -101,7 +103,7 @@ func TestParseAuthorizeRequest(t *testing.T) {
"decoding error",
args{
&http.Request{URL: &url.URL{RawQuery: "unknown=value"}},
func() utils.Decoder {
func() httphelper.Decoder {
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(false)
return decoder
@ -116,7 +118,7 @@ func TestParseAuthorizeRequest(t *testing.T) {
"parsing ok",
args{
&http.Request{URL: &url.URL{RawQuery: "scope=openid"}},
func() utils.Decoder {
func() httphelper.Decoder {
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(false)
return decoder
@ -150,44 +152,138 @@ func TestValidateAuthRequest(t *testing.T) {
tests := []struct {
name string
args args
wantErr bool
wantErr error
}{
//TODO:
// {
// "oauth2 spec"
// }
{
"scope missing fails",
args{&oidc.AuthRequest{}, mock.NewMockStorageExpectValidClientID(t), nil},
true,
oidc.ErrInvalidRequest(),
},
{
"scope openid missing fails",
args{&oidc.AuthRequest{Scopes: []string{"profile"}}, mock.NewMockStorageExpectValidClientID(t), nil},
true,
oidc.ErrInvalidScope(),
},
{
"response_type missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil},
true,
oidc.ErrInvalidRequest(),
},
{
"client_id missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, mock.NewMockStorageExpectValidClientID(t), nil},
true,
oidc.ErrInvalidRequest(),
},
{
"redirect_uri missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, mock.NewMockStorageExpectValidClientID(t), nil},
true,
oidc.ErrInvalidRequest(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage, tt.args.verifier)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr)
_, err := op.ValidateAuthRequest(context.TODO(), tt.args.authRequest, tt.args.storage, tt.args.verifier)
if tt.wantErr == nil && err != nil {
t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
}
if tt.wantErr != nil && !errors.Is(err, tt.wantErr) {
t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.wantErr)
}
})
}
}
func TestValidateAuthReqPrompt(t *testing.T) {
type args struct {
prompts []string
maxAge *uint
}
type res struct {
maxAge *uint
err error
}
tests := []struct {
name string
args args
res res
}{
{
"no prompts and maxAge, ok",
args{
nil,
nil,
},
res{
nil,
nil,
},
},
{
"no prompts but maxAge, ok",
args{
nil,
oidc.NewMaxAge(10),
},
res{
oidc.NewMaxAge(10),
nil,
},
},
{
"prompt none, ok",
args{
[]string{"none"},
oidc.NewMaxAge(10),
},
res{
oidc.NewMaxAge(10),
nil,
},
},
{
"prompt none with others, err",
args{
[]string{"none", "login"},
oidc.NewMaxAge(10),
},
res{
nil,
oidc.ErrInvalidRequest(),
},
},
{
"prompt login, ok",
args{
[]string{"login"},
nil,
},
res{
oidc.NewMaxAge(0),
nil,
},
},
{
"prompt login with maxAge, ok",
args{
[]string{"login"},
oidc.NewMaxAge(10),
},
res{
oidc.NewMaxAge(0),
nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
maxAge, err := op.ValidateAuthReqPrompt(tt.args.prompts, tt.args.maxAge)
if tt.res.err == nil && err != nil {
t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
}
if tt.res.err != nil && !errors.Is(err, tt.res.err) {
t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.res.err)
}
assert.Equal(t, tt.res.maxAge, maxAge)
})
}
}
@ -465,6 +561,80 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
}
}
func TestLoopbackOrLocalhost(t *testing.T) {
type args struct {
url string
}
tests := []struct {
name string
args args
want bool
}{
{
"not parsable, false",
args{url: string('\n')},
false,
},
{
"not http, false",
args{url: "localhost/test"},
false,
},
{
"not http, false",
args{url: "http://localhost.com/test"},
false,
},
{
"v4 no port ok",
args{url: "http://127.0.0.1/test"},
true,
},
{
"v6 short no port ok",
args{url: "http://[::1]/test"},
true,
},
{
"v6 long no port ok",
args{url: "http://[0:0:0:0:0:0:0:1]/test"},
true,
},
{
"locahost no port ok",
args{url: "http://localhost/test"},
true,
},
{
"v4 with port ok",
args{url: "http://127.0.0.1:4200/test"},
true,
},
{
"v6 short with port ok",
args{url: "http://[::1]:4200/test"},
true,
},
{
"v6 long with port ok",
args{url: "http://[0:0:0:0:0:0:0:1]:4200/test"},
true,
},
{
"localhost with port ok",
args{url: "http://localhost:4200/test"},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, got := op.HTTPLoopbackOrLocalhost(tt.args.url); got != tt.want {
t.Errorf("loopbackOrLocalhost() = %v, want %v", got, tt.want)
}
})
}
}
func TestValidateAuthReqResponseType(t *testing.T) {
type args struct {
responseType oidc.ResponseType
@ -534,100 +704,122 @@ func TestRedirectToLogin(t *testing.T) {
}
}
func TestAuthorizeCallback(t *testing.T) {
func TestAuthResponseURL(t *testing.T) {
type args struct {
w http.ResponseWriter
r *http.Request
authorizer op.Authorizer
redirectURI string
responseType oidc.ResponseType
responseMode oidc.ResponseMode
response interface{}
encoder httphelper.Encoder
}
tests := []struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer)
})
}
}
func TestAuthResponse(t *testing.T) {
type args struct {
authReq op.AuthRequest
authorizer op.Authorizer
w http.ResponseWriter
r *http.Request
}
tests := []struct {
name string
args args
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r)
})
}
}
func Test_LoopbackOrLocalhost(t *testing.T) {
type args struct {
type res struct {
url string
err error
}
tests := []struct {
name string
args args
want bool
res res
}{
{
"v4 no port ok",
args{url: "http://127.0.0.1/test"},
true,
"encoding error",
args{
"uri",
oidc.ResponseTypeCode,
"",
map[string]interface{}{"test": "test"},
&mockEncoder{
errors.New("error encoding"),
},
},
res{
"",
oidc.ErrServerError(),
},
},
{
"v6 short no port ok",
args{url: "http://[::1]/test"},
true,
"response mode query",
args{
"uri",
oidc.ResponseTypeIDToken,
oidc.ResponseModeQuery,
map[string][]string{"test": {"test"}},
&mockEncoder{},
},
res{
"uri?test=test",
nil,
},
},
{
"v6 long no port ok",
args{url: "http://[0:0:0:0:0:0:0:1]/test"},
true,
"response mode fragment",
args{
"uri",
oidc.ResponseTypeCode,
oidc.ResponseModeFragment,
map[string][]string{"test": {"test"}},
&mockEncoder{},
},
res{
"uri#test=test",
nil,
},
},
{
"locahost no port ok",
args{url: "http://localhost/test"},
true,
"response type code",
args{
"uri",
oidc.ResponseTypeCode,
"",
map[string][]string{"test": {"test"}},
&mockEncoder{},
},
res{
"uri?test=test",
nil,
},
},
{
"v4 with port ok",
args{url: "http://127.0.0.1:4200/test"},
true,
},
{
"v6 short with port ok",
args{url: "http://[::1]:4200/test"},
true,
},
{
"v6 long with port ok",
args{url: "http://[0:0:0:0:0:0:0:1]:4200/test"},
true,
},
{
"localhost with port ok",
args{url: "http://localhost:4200/test"},
true,
"response type id token",
args{
"uri",
oidc.ResponseTypeIDToken,
"",
map[string][]string{"test": {"test"}},
&mockEncoder{},
},
res{
"uri#test=test",
nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, got := op.HTTPLoopbackOrLocalhost(tt.args.url); got != tt.want {
t.Errorf("loopbackOrLocalhost() = %v, want %v", got, tt.want)
got, err := op.AuthResponseURL(tt.args.redirectURI, tt.args.responseType, tt.args.responseMode, tt.args.response, tt.args.encoder)
if tt.res.err == nil && err != nil {
t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
}
if tt.res.err != nil && !errors.Is(err, tt.res.err) {
t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.res.err)
}
if got != tt.res.url {
t.Errorf("AuthResponseURL() got = %v, want %v", got, tt.res.url)
}
})
}
}
type mockEncoder struct {
err error
}
func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error {
if m.err != nil {
return m.err
}
for s, strings := range src.(map[string][]string) {
dst[s] = strings
}
return nil
}

View file

@ -16,15 +16,23 @@ type Configuration interface {
TokenEndpoint() Endpoint
IntrospectionEndpoint() Endpoint
UserinfoEndpoint() Endpoint
RevocationEndpoint() Endpoint
EndSessionEndpoint() Endpoint
KeysEndpoint() Endpoint
AuthMethodPostSupported() bool
CodeMethodS256Supported() bool
AuthMethodPrivateKeyJWTSupported() bool
TokenEndpointSigningAlgorithmsSupported() []string
GrantTypeRefreshTokenSupported() bool
GrantTypeTokenExchangeSupported() bool
GrantTypeJWTAuthorizationSupported() bool
IntrospectionAuthMethodPrivateKeyJWTSupported() bool
IntrospectionEndpointSigningAlgorithmsSupported() []string
RevocationAuthMethodPrivateKeyJWTSupported() bool
RevocationEndpointSigningAlgorithmsSupported() []string
RequestObjectSupported() bool
RequestObjectSigningAlgorithmsSupported() []string
SupportedUILocales() []language.Tag
}

View file

@ -61,6 +61,7 @@ func TestValidateIssuer(t *testing.T) {
},
}
//ensure env is not set
//nolint:errcheck
os.Unsetenv(OidcDevMode)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -86,6 +87,7 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
false,
},
}
//nolint:errcheck
os.Setenv(OidcDevMode, "true")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -1,7 +1,7 @@
package op
import (
"github.com/caos/oidc/pkg/utils"
"github.com/caos/oidc/pkg/crypto"
)
type Crypto interface {
@ -18,9 +18,9 @@ func NewAESCrypto(key [32]byte) Crypto {
}
func (c *aesCrypto) Encrypt(s string) (string, error) {
return utils.EncryptAES(s, c.key)
return crypto.EncryptAES(s, c.key)
}
func (c *aesCrypto) Decrypt(s string) (string, error) {
return utils.DecryptAES(s, c.key)
return crypto.DecryptAES(s, c.key)
}

View file

@ -3,8 +3,8 @@ package op
import (
"net/http"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http.Request) {
@ -14,28 +14,35 @@ func discoveryHandler(c Configuration, s Signer) func(http.ResponseWriter, *http
}
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
utils.MarshalJSON(w, config)
httphelper.MarshalJSON(w, config)
}
func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration {
return &oidc.DiscoveryConfiguration{
Issuer: c.Issuer(),
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()),
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()),
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
ScopesSupported: Scopes(c),
ResponseTypesSupported: ResponseTypes(c),
GrantTypesSupported: GrantTypes(c),
SubjectTypesSupported: SubjectTypes(c),
IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(c),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(c),
ClaimsSupported: SupportedClaims(c),
CodeChallengeMethodsSupported: CodeChallengeMethods(c),
UILocalesSupported: c.SupportedUILocales(),
Issuer: c.Issuer(),
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
IntrospectionEndpoint: c.IntrospectionEndpoint().Absolute(c.Issuer()),
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
RevocationEndpoint: c.RevocationEndpoint().Absolute(c.Issuer()),
EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()),
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
ScopesSupported: Scopes(c),
ResponseTypesSupported: ResponseTypes(c),
GrantTypesSupported: GrantTypes(c),
SubjectTypesSupported: SubjectTypes(c),
IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(c),
TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(c),
TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(c),
IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(c),
IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(c),
RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(c),
RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(c),
ClaimsSupported: SupportedClaims(c),
CodeChallengeMethodsSupported: CodeChallengeMethods(c),
UILocalesSupported: c.SupportedUILocales(),
RequestParameterSupported: c.RequestObjectSupported(),
}
}
@ -45,6 +52,7 @@ var DefaultSupportedScopes = []string{
oidc.ScopeEmail,
oidc.ScopePhone,
oidc.ScopeAddress,
oidc.ScopeOfflineAccess,
}
func Scopes(c Configuration) []string {
@ -127,6 +135,13 @@ func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod {
return authMethods
}
func TokenSigAlgorithms(c Configuration) []string {
if !c.AuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.TokenEndpointSigningAlgorithmsSupported()
}
func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodBasic,
@ -137,6 +152,20 @@ func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
return authMethods
}
func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
authMethods := []oidc.AuthMethod{
oidc.AuthMethodNone,
oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
authMethods = append(authMethods, oidc.AuthMethodPost)
}
if c.AuthMethodPrivateKeyJWTSupported() {
authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
codeMethods := make([]oidc.CodeChallengeMethod, 0, 1)
if c.CodeMethodS256Supported() {
@ -144,3 +173,24 @@ func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
}
return codeMethods
}
func IntrospectionSigAlgorithms(c Configuration) []string {
if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.IntrospectionEndpointSigningAlgorithmsSupported()
}
func RevocationSigAlgorithms(c Configuration) []string {
if !c.RevocationAuthMethodPrivateKeyJWTSupported() {
return nil
}
return c.RevocationEndpointSigningAlgorithmsSupported()
}
func RequestObjectSigAlgorithms(c Configuration) []string {
if !c.RequestObjectSupported() {
return nil
}
return c.RequestObjectSigningAlgorithmsSupported()
}

View file

@ -37,7 +37,10 @@ func TestDiscover(t *testing.T) {
op.Discover(tt.args.w, tt.args.config)
rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, `{"issuer":"https://issuer.com","request_uri_parameter_supported":false}`, rec.Body.String())
require.Equal(t,
`{"issuer":"https://issuer.com","request_uri_parameter_supported":false}
`,
rec.Body.String())
})
}
}

View file

@ -1,105 +1,46 @@
package op
import (
"fmt"
"net/http"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
const (
InvalidRequest errorType = "invalid_request"
InvalidRequestURI errorType = "invalid_request_uri"
InteractionRequired errorType = "interaction_required"
ServerError errorType = "server_error"
)
var (
ErrInvalidRequest = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequest,
Description: description,
}
}
ErrInvalidRequestRedirectURI = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InvalidRequestURI,
Description: description,
redirectDisabled: true,
}
}
ErrInteractionRequired = func(description string) *OAuthError {
return &OAuthError{
ErrorType: InteractionRequired,
Description: description,
}
}
ErrServerError = func(description string) *OAuthError {
return &OAuthError{
ErrorType: ServerError,
Description: description,
}
}
)
type errorType string
type ErrAuthRequest interface {
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetState() string
}
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder utils.Encoder) {
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) {
if authReq == nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
e, ok := err.(*OAuthError)
if !ok {
e = new(OAuthError)
e.ErrorType = ServerError
e.Description = err.Error()
}
e.State = authReq.GetState()
if authReq.GetRedirectURI() == "" || e.redirectDisabled {
e := oidc.DefaultToServerError(err, err.Error())
if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
http.Error(w, e.Description, http.StatusBadRequest)
return
}
params, err := utils.URLEncodeResponse(e, encoder)
e.State = authReq.GetState()
var responseMode oidc.ResponseMode
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode()
}
url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
url := authReq.GetRedirectURI()
responseType := authReq.GetResponseType()
if responseType == "" || responseType == oidc.ResponseTypeCode {
url += "?" + params
} else {
url += "#" + params
}
http.Redirect(w, r, url, http.StatusFound)
}
func RequestError(w http.ResponseWriter, r *http.Request, err error) {
e, ok := err.(*OAuthError)
if !ok {
e = new(OAuthError)
e.ErrorType = ServerError
e.Description = err.Error()
e := oidc.DefaultToServerError(err, err.Error())
status := http.StatusBadRequest
if e.ErrorType == oidc.InvalidClient {
status = 401
}
w.WriteHeader(http.StatusBadRequest)
utils.MarshalJSON(w, e)
}
type OAuthError struct {
ErrorType errorType `json:"error" schema:"error"`
Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
redirectDisabled bool `json:"-" schema:"-"`
}
func (e *OAuthError) Error() string {
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
httphelper.MarshalJSONWithStatus(w, e, status)
}

View file

@ -6,7 +6,7 @@ import (
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/utils"
httphelper "github.com/caos/oidc/pkg/http"
)
type KeyProvider interface {
@ -22,9 +22,8 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
keySet, err := k.GetKeySet(r.Context())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
utils.MarshalJSON(w, err)
httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError)
return
}
utils.MarshalJSON(w, keySet)
httphelper.MarshalJSON(w, keySet)
}

100
pkg/op/keys_test.go Normal file
View file

@ -0,0 +1,100 @@
package op_test
import (
"crypto/rsa"
"math/big"
"net/http"
"net/http/httptest"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"gopkg.in/square/go-jose.v2"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op"
"github.com/caos/oidc/pkg/op/mock"
)
func TestKeys(t *testing.T) {
type args struct {
k op.KeyProvider
}
type res struct {
statusCode int
contentType string
body string
}
tests := []struct {
name string
args args
res res
}{
{
name: "error",
args: args{
k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, oidc.ErrServerError())
return m
}(),
},
res: res{
statusCode: http.StatusInternalServerError,
contentType: "application/json",
body: `{"error":"server_error"}
`,
},
},
{
name: "empty list",
args: args{
k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(nil, nil)
return m
}(),
},
res: res{
statusCode: http.StatusOK,
contentType: "application/json",
},
},
{
name: "list",
args: args{
k: func() op.KeyProvider {
m := mock.NewMockKeyProvider(gomock.NewController(t))
m.EXPECT().GetKeySet(gomock.Any()).Return(
&jose.JSONWebKeySet{Keys: []jose.JSONWebKey{
{
Key: &rsa.PublicKey{
N: big.NewInt(1),
E: 1,
},
KeyID: "id",
},
}},
nil,
)
return m
}(),
},
res: res{
statusCode: http.StatusOK,
contentType: "application/json",
body: `{"keys":[{"kty":"RSA","kid":"id","n":"AQ","e":"AQ"}]}
`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
op.Keys(w, httptest.NewRequest("GET", "/keys", nil), tt.args.k)
assert.Equal(t, tt.res.statusCode, w.Result().StatusCode)
assert.Equal(t, tt.res.contentType, w.Header().Get("content-type"))
assert.Equal(t, tt.res.body, w.Body.String())
})
}
}

View file

@ -7,8 +7,8 @@ package mock
import (
reflect "reflect"
http "github.com/caos/oidc/pkg/http"
op "github.com/caos/oidc/pkg/op"
utils "github.com/caos/oidc/pkg/utils"
gomock "github.com/golang/mock/gomock"
)
@ -50,10 +50,10 @@ func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call {
}
// Decoder mocks base method.
func (m *MockAuthorizer) Decoder() utils.Decoder {
func (m *MockAuthorizer) Decoder() http.Decoder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Decoder")
ret0, _ := ret[0].(utils.Decoder)
ret0, _ := ret[0].(http.Decoder)
return ret0
}
@ -64,10 +64,10 @@ func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call {
}
// Encoder mocks base method.
func (m *MockAuthorizer) Encoder() utils.Encoder {
func (m *MockAuthorizer) Encoder() http.Encoder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Encoder")
ret0, _ := ret[0].(utils.Encoder)
ret0, _ := ret[0].(http.Encoder)
return ret0
}
@ -105,6 +105,20 @@ func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
}
// RequestObjectSupported mocks base method.
func (m *MockAuthorizer) RequestObjectSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestObjectSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// RequestObjectSupported indicates an expected call of RequestObjectSupported.
func (mr *MockAuthorizerMockRecorder) RequestObjectSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockAuthorizer)(nil).RequestObjectSupported))
}
// Signer mocks base method.
func (m *MockAuthorizer) Signer() op.Signer {
m.ctrl.T.Helper()

View file

@ -147,6 +147,20 @@ func (mr *MockConfigurationMockRecorder) GrantTypeTokenExchangeSupported() *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeTokenExchangeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeTokenExchangeSupported))
}
// IntrospectionAuthMethodPrivateKeyJWTSupported mocks base method.
func (m *MockConfiguration) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntrospectionAuthMethodPrivateKeyJWTSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// IntrospectionAuthMethodPrivateKeyJWTSupported indicates an expected call of IntrospectionAuthMethodPrivateKeyJWTSupported.
func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionAuthMethodPrivateKeyJWTSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionAuthMethodPrivateKeyJWTSupported))
}
// IntrospectionEndpoint mocks base method.
func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint {
m.ctrl.T.Helper()
@ -161,6 +175,20 @@ func (mr *MockConfigurationMockRecorder) IntrospectionEndpoint() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpoint", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpoint))
}
// IntrospectionEndpointSigningAlgorithmsSupported mocks base method.
func (m *MockConfiguration) IntrospectionEndpointSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IntrospectionEndpointSigningAlgorithmsSupported")
ret0, _ := ret[0].([]string)
return ret0
}
// IntrospectionEndpointSigningAlgorithmsSupported indicates an expected call of IntrospectionEndpointSigningAlgorithmsSupported.
func (mr *MockConfigurationMockRecorder) IntrospectionEndpointSigningAlgorithmsSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpointSigningAlgorithmsSupported))
}
// Issuer mocks base method.
func (m *MockConfiguration) Issuer() string {
m.ctrl.T.Helper()
@ -189,6 +217,76 @@ func (mr *MockConfigurationMockRecorder) KeysEndpoint() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeysEndpoint", reflect.TypeOf((*MockConfiguration)(nil).KeysEndpoint))
}
// RequestObjectSigningAlgorithmsSupported mocks base method.
func (m *MockConfiguration) RequestObjectSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestObjectSigningAlgorithmsSupported")
ret0, _ := ret[0].([]string)
return ret0
}
// RequestObjectSigningAlgorithmsSupported indicates an expected call of RequestObjectSigningAlgorithmsSupported.
func (mr *MockConfigurationMockRecorder) RequestObjectSigningAlgorithmsSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).RequestObjectSigningAlgorithmsSupported))
}
// RequestObjectSupported mocks base method.
func (m *MockConfiguration) RequestObjectSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestObjectSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// RequestObjectSupported indicates an expected call of RequestObjectSupported.
func (mr *MockConfigurationMockRecorder) RequestObjectSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockConfiguration)(nil).RequestObjectSupported))
}
// RevocationAuthMethodPrivateKeyJWTSupported mocks base method.
func (m *MockConfiguration) RevocationAuthMethodPrivateKeyJWTSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationAuthMethodPrivateKeyJWTSupported")
ret0, _ := ret[0].(bool)
return ret0
}
// RevocationAuthMethodPrivateKeyJWTSupported indicates an expected call of RevocationAuthMethodPrivateKeyJWTSupported.
func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationAuthMethodPrivateKeyJWTSupported", reflect.TypeOf((*MockConfiguration)(nil).RevocationAuthMethodPrivateKeyJWTSupported))
}
// RevocationEndpoint mocks base method.
func (m *MockConfiguration) RevocationEndpoint() op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationEndpoint")
ret0, _ := ret[0].(op.Endpoint)
return ret0
}
// RevocationEndpoint indicates an expected call of RevocationEndpoint.
func (mr *MockConfigurationMockRecorder) RevocationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).RevocationEndpoint))
}
// RevocationEndpointSigningAlgorithmsSupported mocks base method.
func (m *MockConfiguration) RevocationEndpointSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevocationEndpointSigningAlgorithmsSupported")
ret0, _ := ret[0].([]string)
return ret0
}
// RevocationEndpointSigningAlgorithmsSupported indicates an expected call of RevocationEndpointSigningAlgorithmsSupported.
func (mr *MockConfigurationMockRecorder) RevocationEndpointSigningAlgorithmsSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).RevocationEndpointSigningAlgorithmsSupported))
}
// SupportedUILocales mocks base method.
func (m *MockConfiguration) SupportedUILocales() []language.Tag {
m.ctrl.T.Helper()
@ -217,6 +315,20 @@ func (mr *MockConfigurationMockRecorder) TokenEndpoint() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpoint", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpoint))
}
// TokenEndpointSigningAlgorithmsSupported mocks base method.
func (m *MockConfiguration) TokenEndpointSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TokenEndpointSigningAlgorithmsSupported")
ret0, _ := ret[0].([]string)
return ret0
}
// TokenEndpointSigningAlgorithmsSupported indicates an expected call of TokenEndpointSigningAlgorithmsSupported.
func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
}
// UserinfoEndpoint mocks base method.
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
m.ctrl.T.Helper()

View file

@ -5,3 +5,4 @@ package mock
//go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer
//go:generate mockgen -package mock -destination ./key.mock.go github.com/caos/oidc/pkg/op KeyProvider

51
pkg/op/mock/key.mock.go Normal file
View file

@ -0,0 +1,51 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/oidc/pkg/op (interfaces: KeyProvider)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
jose "gopkg.in/square/go-jose.v2"
)
// MockKeyProvider is a mock of KeyProvider interface.
type MockKeyProvider struct {
ctrl *gomock.Controller
recorder *MockKeyProviderMockRecorder
}
// MockKeyProviderMockRecorder is the mock recorder for MockKeyProvider.
type MockKeyProviderMockRecorder struct {
mock *MockKeyProvider
}
// NewMockKeyProvider creates a new mock instance.
func NewMockKeyProvider(ctrl *gomock.Controller) *MockKeyProvider {
mock := &MockKeyProvider{ctrl: ctrl}
mock.recorder = &MockKeyProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockKeyProvider) EXPECT() *MockKeyProviderMockRecorder {
return m.recorder
}
// GetKeySet mocks base method.
func (m *MockKeyProvider) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetKeySet", arg0)
ret0, _ := ret[0].(*jose.JSONWebKeySet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetKeySet indicates an expected call of GetKeySet.
func (mr *MockKeyProviderMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockKeyProvider)(nil).GetKeySet), arg0)
}

View file

@ -230,6 +230,20 @@ func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
}
// RevokeToken mocks base method.
func (m *MockStorage) RevokeToken(arg0 context.Context, arg1, arg2, arg3 string) *oidc.Error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*oidc.Error)
return ret0
}
// RevokeToken indicates an expected call of RevokeToken.
func (mr *MockStorageMockRecorder) RevokeToken(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeToken", reflect.TypeOf((*MockStorage)(nil).RevokeToken), arg0, arg1, arg2, arg3)
}
// SaveAuthCode mocks base method.
func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()

View file

@ -12,8 +12,8 @@ import (
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
const (
@ -23,6 +23,7 @@ const (
defaultTokenEndpoint = "oauth/token"
defaultIntrospectEndpoint = "oauth/introspect"
defaultUserinfoEndpoint = "userinfo"
defaultRevocationEndpoint = "revoke"
defaultEndSessionEndpoint = "end_session"
defaultKeysEndpoint = "keys"
)
@ -33,6 +34,7 @@ var (
Token: NewEndpoint(defaultTokenEndpoint),
Introspection: NewEndpoint(defaultIntrospectEndpoint),
Userinfo: NewEndpoint(defaultUserinfoEndpoint),
Revocation: NewEndpoint(defaultRevocationEndpoint),
EndSession: NewEndpoint(defaultEndSessionEndpoint),
JwksURI: NewEndpoint(defaultKeysEndpoint),
}
@ -41,8 +43,8 @@ var (
type OpenIDProvider interface {
Configuration
Storage() Storage
Decoder() utils.Decoder
Encoder() utils.Encoder
Decoder() httphelper.Decoder
Encoder() httphelper.Encoder
IDTokenHintVerifier() IDTokenHintVerifier
AccessTokenVerifier() AccessTokenVerifier
Crypto() Crypto
@ -74,6 +76,7 @@ func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router
router.Handle(o.TokenEndpoint().Relative(), intercept(tokenHandler(o)))
router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
router.Handle(o.EndSessionEndpoint().Relative(), intercept(endSessionHandler(o)))
router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
return router
@ -84,8 +87,10 @@ type Config struct {
CryptoKey [32]byte
DefaultLogoutRedirectURI string
CodeMethodS256 bool
AuthMethodPost bool
AuthMethodPrivateKeyJWT bool
GrantTypeRefreshToken bool
RequestObjectSupported bool
SupportedUILocales []language.Tag
}
@ -94,6 +99,7 @@ type endpoints struct {
Token Endpoint
Introspection Endpoint
Userinfo Endpoint
Revocation Endpoint
EndSession Endpoint
CheckSessionIframe Endpoint
JwksURI Endpoint
@ -148,7 +154,6 @@ type openidProvider struct {
decoder *schema.Decoder
encoder *schema.Encoder
interceptors []HttpInterceptor
retry func(int) (bool, int)
timer <-chan time.Time
}
@ -172,6 +177,10 @@ func (o *openidProvider) UserinfoEndpoint() Endpoint {
return o.endpoints.Userinfo
}
func (o *openidProvider) RevocationEndpoint() Endpoint {
return o.endpoints.Revocation
}
func (o *openidProvider) EndSessionEndpoint() Endpoint {
return o.endpoints.EndSession
}
@ -181,7 +190,7 @@ func (o *openidProvider) KeysEndpoint() Endpoint {
}
func (o *openidProvider) AuthMethodPostSupported() bool {
return true //todo: config
return o.config.AuthMethodPost
}
func (o *openidProvider) CodeMethodS256Supported() bool {
@ -192,6 +201,10 @@ func (o *openidProvider) AuthMethodPrivateKeyJWTSupported() bool {
return o.config.AuthMethodPrivateKeyJWT
}
func (o *openidProvider) TokenEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) GrantTypeRefreshTokenSupported() bool {
return o.config.GrantTypeRefreshToken
}
@ -204,6 +217,30 @@ func (o *openidProvider) GrantTypeJWTAuthorizationSupported() bool {
return true
}
func (o *openidProvider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
return true
}
func (o *openidProvider) IntrospectionEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) RevocationAuthMethodPrivateKeyJWTSupported() bool {
return true
}
func (o *openidProvider) RevocationEndpointSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) RequestObjectSupported() bool {
return o.config.RequestObjectSupported
}
func (o *openidProvider) RequestObjectSigningAlgorithmsSupported() []string {
return []string{"RS256"}
}
func (o *openidProvider) SupportedUILocales() []language.Tag {
return o.config.SupportedUILocales
}
@ -212,11 +249,11 @@ func (o *openidProvider) Storage() Storage {
return o.storage
}
func (o *openidProvider) Decoder() utils.Decoder {
func (o *openidProvider) Decoder() httphelper.Decoder {
return o.decoder
}
func (o *openidProvider) Encoder() utils.Encoder {
func (o *openidProvider) Encoder() httphelper.Encoder {
return o.encoder
}
@ -332,6 +369,16 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomRevocationEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
if err := endpoint.Validate(); err != nil {
return err
}
o.endpoints.Revocation = endpoint
return nil
}
}
func WithCustomEndSessionEndpoint(endpoint Endpoint) Option {
return func(o *openidProvider) error {
if err := endpoint.Validate(); err != nil {
@ -352,11 +399,12 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option {
}
}
func WithCustomEndpoints(auth, token, userInfo, endSession, keys Endpoint) Option {
func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option {
return func(o *openidProvider) error {
o.endpoints.Authorization = auth
o.endpoints.Token = token
o.endpoints.Userinfo = userInfo
o.endpoints.Revocation = revocation
o.endpoints.EndSession = endSession
o.endpoints.JwksURI = keys
return nil

View file

@ -5,7 +5,7 @@ import (
"errors"
"net/http"
"github.com/caos/oidc/pkg/utils"
httphelper "github.com/caos/oidc/pkg/http"
)
type ProbesFn func(context.Context) error
@ -49,7 +49,7 @@ func ReadyStorage(s Storage) ProbesFn {
}
func ok(w http.ResponseWriter) {
utils.MarshalJSON(w, status{"ok"})
httphelper.MarshalJSON(w, status{"ok"})
}
type status struct {

View file

@ -4,12 +4,12 @@ import (
"context"
"net/http"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
type SessionEnder interface {
Decoder() utils.Decoder
Decoder() httphelper.Decoder
Storage() Storage
IDTokenHintVerifier() IDTokenHintVerifier
DefaultLogoutRedirectURI() string
@ -38,21 +38,21 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
}
err = ender.Storage().TerminateSession(r.Context(), session.UserID, clientID)
if err != nil {
RequestError(w, r, ErrServerError("error terminating session"))
RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"))
return
}
http.Redirect(w, r, session.RedirectURI, http.StatusFound)
}
func ParseEndSessionRequest(r *http.Request, decoder utils.Decoder) (*oidc.EndSessionRequest, error) {
func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.EndSessionRequest, error) {
err := r.ParseForm()
if err != nil {
return nil, ErrInvalidRequest("error parsing form")
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
req := new(oidc.EndSessionRequest)
err = decoder.Decode(req, r.Form)
if err != nil {
return nil, ErrInvalidRequest("error decoding form")
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
return req, nil
}
@ -64,12 +64,12 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
}
claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier())
if err != nil {
return nil, ErrInvalidRequest("id_token_hint invalid")
return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
}
session.UserID = claims.GetSubject()
session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.GetAuthorizedParty())
if err != nil {
return nil, ErrServerError("")
return nil, oidc.DefaultToServerError(err, "")
}
if req.PostLogoutRedirectURI == "" {
session.RedirectURI = ender.DefaultLogoutRedirectURI()
@ -81,5 +81,5 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
return session, nil
}
}
return nil, ErrInvalidRequest("post_logout_redirect_uri invalid")
return nil, oidc.ErrInvalidRequest().WithDescription("post_logout_redirect_uri invalid")
}

View file

@ -20,7 +20,8 @@ type AuthStorage interface {
CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error)
TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (RefreshTokenRequest, error)
TerminateSession(context.Context, string, string) error
TerminateSession(ctx context.Context, userID string, clientID string) error
RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error
GetSigningKey(context.Context, chan<- jose.SigningKey)
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)

View file

@ -4,8 +4,9 @@ import (
"context"
"time"
"github.com/caos/oidc/pkg/crypto"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
"github.com/caos/oidc/pkg/strings"
)
type TokenCreator interface {
@ -64,7 +65,7 @@ func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storag
func needsRefreshToken(tokenRequest TokenRequest, client Client) bool {
switch req := tokenRequest.(type) {
case AuthRequest:
return utils.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
case RefreshTokenRequest:
return true
default:
@ -104,7 +105,7 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex
}
claims.SetPrivateClaims(privateClaims)
}
return utils.Sign(claims, signer.Signer())
return crypto.Sign(claims, signer.Signer())
}
type IDTokenRequest interface {
@ -151,7 +152,7 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
claims.SetCodeHash(codeHash)
}
return utils.Sign(claims, signer.Signer())
return crypto.Sign(claims, signer.Signer())
}
func removeUserinfoScopes(scopes []string) []string {
@ -167,5 +168,5 @@ func removeUserinfoScopes(scopes []string) []string {
newScopeList = append(newScopeList, scope)
}
}
return scopes
return newScopeList
}

View file

@ -2,11 +2,10 @@ package op
import (
"context"
"errors"
"net/http"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
//CodeExchange handles the OAuth 2.0 authorization_code grant, including
@ -17,7 +16,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
RequestError(w, r, err)
}
if tokenReq.Code == "" {
RequestError(w, r, ErrInvalidRequest("code missing"))
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"))
return
}
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
@ -30,11 +29,11 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
RequestError(w, r, err)
return
}
utils.MarshalJSON(w, resp)
httphelper.MarshalJSON(w, resp)
}
//ParseAccessTokenRequest parsed the http request into a oidc.AccessTokenRequest
func ParseAccessTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.AccessTokenRequest, error) {
func ParseAccessTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.AccessTokenRequest, error) {
request := new(oidc.AccessTokenRequest)
err := ParseAuthenticatedTokenRequest(r, decoder, request)
if err != nil {
@ -51,13 +50,13 @@ func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenR
return nil, nil, err
}
if client.GetID() != authReq.GetClientID() {
return nil, nil, ErrInvalidRequest("invalid auth code")
return nil, nil, oidc.ErrInvalidGrant()
}
if !ValidateGrantType(client, oidc.GrantTypeCode) {
return nil, nil, ErrInvalidRequest("invalid_grant")
return nil, nil, oidc.ErrUnauthorizedClient()
}
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
return nil, nil, ErrInvalidRequest("redirect_uri does not correspond")
return nil, nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
}
return authReq, client, nil
}
@ -68,7 +67,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger)
if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() {
return nil, nil, errors.New("auth_method private_key_jwt not supported")
return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
}
client, err = AuthorizePrivateJWTKey(ctx, tokenReq.ClientAssertion, jwtExchanger)
if err != nil {
@ -79,10 +78,10 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
}
client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
if err != nil {
return nil, nil, err
return nil, nil, oidc.ErrInvalidClient().WithParent(err)
}
if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT {
return nil, nil, errors.New("invalid_grant")
return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
}
if client.AuthMethod() == oidc.AuthMethodNone {
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
@ -93,9 +92,12 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
return request, client, err
}
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("auth_method post not supported")
return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
}
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
if err != nil {
return nil, nil, err
}
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
return request, client, err
}
@ -104,7 +106,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
func AuthRequestByCode(ctx context.Context, storage Storage, code string) (AuthRequest, error) {
authReq, err := storage.AuthRequestByCode(ctx, code)
if err != nil {
return nil, ErrInvalidRequest("invalid code")
return nil, oidc.ErrInvalidGrant().WithDescription("invalid code").WithParent(err)
}
return authReq, nil
}

View file

@ -3,28 +3,9 @@ package op
import (
"errors"
"net/http"
"github.com/caos/oidc/pkg/oidc"
)
//TokenExchange will handle the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange")
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
tokenRequest, err := ParseTokenExchangeRequest(w, r)
if err != nil {
RequestError(w, r, err)
return
}
err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage())
if err != nil {
RequestError(w, r, err)
return
}
}
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
return nil, errors.New("Unimplemented") //TODO: impl
}
func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
return errors.New("Unimplemented") //TODO: impl
RequestError(w, r, errors.New("unimplemented"))
}

View file

@ -5,12 +5,12 @@ import (
"net/http"
"net/url"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
type Introspector interface {
Decoder() utils.Decoder
Decoder() httphelper.Decoder
Crypto() Crypto
Storage() Storage
AccessTokenVerifier() AccessTokenVerifier
@ -36,16 +36,16 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
}
tokenID, subject, ok := getTokenIDAndSubject(r.Context(), introspector, token)
if !ok {
utils.MarshalJSON(w, response)
httphelper.MarshalJSON(w, response)
return
}
err = introspector.Storage().SetIntrospectionFromToken(r.Context(), response, tokenID, subject, clientID)
if err != nil {
utils.MarshalJSON(w, response)
httphelper.MarshalJSON(w, response)
return
}
response.SetActive(true)
utils.MarshalJSON(w, response)
httphelper.MarshalJSON(w, response)
}
func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {

View file

@ -5,8 +5,8 @@ import (
"net/http"
"time"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
type JWTAuthorizationGrantExchanger interface {
@ -37,18 +37,18 @@ func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizati
RequestError(w, r, err)
return
}
utils.MarshalJSON(w, resp)
httphelper.MarshalJSON(w, resp)
}
func ParseJWTProfileGrantRequest(r *http.Request, decoder utils.Decoder) (*oidc.JWTProfileGrantRequest, error) {
func ParseJWTProfileGrantRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) {
err := r.ParseForm()
if err != nil {
return nil, ErrInvalidRequest("error parsing form")
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
tokenReq := new(oidc.JWTProfileGrantRequest)
err = decoder.Decode(tokenReq, r.Form)
if err != nil {
return nil, ErrInvalidRequest("error decoding form")
return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
return tokenReq, nil
}
@ -74,6 +74,6 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea
//ParseJWTProfileRequest has been renamed to ParseJWTProfileGrantRequest
//
//deprecated: use ParseJWTProfileGrantRequest
func ParseJWTProfileRequest(r *http.Request, decoder utils.Decoder) (*oidc.JWTProfileGrantRequest, error) {
func ParseJWTProfileRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) {
return ParseJWTProfileGrantRequest(r, decoder)
}

View file

@ -6,8 +6,9 @@ import (
"net/http"
"time"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
"github.com/caos/oidc/pkg/strings"
)
type RefreshTokenRequest interface {
@ -37,11 +38,11 @@ func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exch
RequestError(w, r, err)
return
}
utils.MarshalJSON(w, resp)
httphelper.MarshalJSON(w, resp)
}
//ParseRefreshTokenRequest parsed the http request into a oidc.RefreshTokenRequest
func ParseRefreshTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.RefreshTokenRequest, error) {
func ParseRefreshTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.RefreshTokenRequest, error) {
request := new(oidc.RefreshTokenRequest)
err := ParseAuthenticatedTokenRequest(r, decoder, request)
if err != nil {
@ -54,14 +55,14 @@ func ParseRefreshTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.Ref
//and returns the data representing the original auth request corresponding to the refresh_token
func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (RefreshTokenRequest, Client, error) {
if tokenReq.RefreshToken == "" {
return nil, nil, ErrInvalidRequest("code missing")
return nil, nil, oidc.ErrInvalidRequest().WithDescription("refresh_token missing")
}
request, client, err := AuthorizeRefreshClient(ctx, tokenReq, exchanger)
if err != nil {
return nil, nil, err
}
if client.GetID() != request.GetClientID() {
return nil, nil, ErrInvalidRequest("invalid auth code")
return nil, nil, oidc.ErrInvalidGrant()
}
if err = ValidateRefreshTokenScopes(tokenReq.Scopes, request); err != nil {
return nil, nil, err
@ -77,15 +78,15 @@ func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTok
return nil
}
for _, scope := range requestedScopes {
if !utils.Contains(authRequest.GetScopes(), scope) {
return errors.New("invalid_scope")
if !strings.Contains(authRequest.GetScopes(), scope) {
return oidc.ErrInvalidScope()
}
}
authRequest.SetCurrentScopes(requestedScopes)
return nil
}
//AuthorizeCodeClient checks the authorization of the client and that the used method was the one previously registered.
//AuthorizeRefreshClient checks the authorization of the client and that the used method was the one previously registered.
//It than returns the data representing the original auth request corresponding to the refresh_token
func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (request RefreshTokenRequest, client Client, err error) {
if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
@ -98,7 +99,7 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ
return nil, nil, err
}
if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
return nil, nil, ErrInvalidRequest("invalid_grant")
return nil, nil, oidc.ErrUnauthorizedClient()
}
request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
return request, client, err
@ -108,17 +109,17 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ
return nil, nil, err
}
if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
return nil, nil, ErrInvalidRequest("invalid_grant")
return nil, nil, oidc.ErrUnauthorizedClient()
}
if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT {
return nil, nil, errors.New("invalid_grant")
return nil, nil, oidc.ErrInvalidClient()
}
if client.AuthMethod() == oidc.AuthMethodNone {
request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
return request, client, err
}
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, errors.New("auth_method post not supported")
return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
}
if err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()); err != nil {
return nil, nil, err
@ -132,7 +133,7 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ
func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) {
request, err := storage.TokenRequestByRefreshToken(ctx, refreshToken)
if err != nil {
return nil, ErrInvalidRequest("invalid refreshToken")
return nil, oidc.ErrInvalidGrant().WithParent(err)
}
return request, nil
}

View file

@ -5,14 +5,14 @@ import (
"net/http"
"net/url"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
type Exchanger interface {
Issuer() string
Storage() Storage
Decoder() utils.Decoder
Decoder() httphelper.Decoder
Signer() Signer
Crypto() Crypto
AuthMethodPostSupported() bool
@ -24,7 +24,8 @@ type Exchanger interface {
func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.FormValue("grant_type") {
grantType := r.FormValue("grant_type")
switch grantType {
case string(oidc.GrantTypeCode):
CodeExchange(w, r, exchanger)
return
@ -44,14 +45,14 @@ func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Reque
return
}
case "":
RequestError(w, r, ErrInvalidRequest("grant_type missing"))
RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"))
return
}
RequestError(w, r, ErrInvalidRequest("grant_type not supported"))
RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType))
}
}
//authenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest
//AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest
//it is implemented by oidc.AuthRequest and oidc.RefreshTokenRequest
type AuthenticatedTokenRequest interface {
SetClientID(string)
@ -60,48 +61,49 @@ type AuthenticatedTokenRequest interface {
//ParseAuthenticatedTokenRequest parses the client_id and client_secret from the HTTP request from either
//HTTP Basic Auth header or form body and sets them into the provided authenticatedTokenRequest interface
func ParseAuthenticatedTokenRequest(r *http.Request, decoder utils.Decoder, request AuthenticatedTokenRequest) error {
func ParseAuthenticatedTokenRequest(r *http.Request, decoder httphelper.Decoder, request AuthenticatedTokenRequest) error {
err := r.ParseForm()
if err != nil {
return ErrInvalidRequest("error parsing form")
return oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
err = decoder.Decode(request, r.Form)
if err != nil {
return ErrInvalidRequest("error decoding form")
return oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
clientID, clientSecret, ok := r.BasicAuth()
if ok {
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return ErrInvalidRequest("invalid basic auth header")
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return ErrInvalidRequest("invalid basic auth header")
}
request.SetClientID(clientID)
request.SetClientSecret(clientSecret)
if !ok {
return nil
}
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
request.SetClientID(clientID)
request.SetClientSecret(clientSecret)
return nil
}
//AuthorizeRefreshClientByClientIDSecret authorizes a client by validating the client_id and client_secret (Basic Auth and POST)
//AuthorizeClientIDSecret authorizes a client by validating the client_id and client_secret (Basic Auth and POST)
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage Storage) error {
err := storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
if err != nil {
return err //TODO: wrap?
return oidc.ErrInvalidClient().WithDescription("invalid client_id / client_secret").WithParent(err)
}
return nil
}
//AuthorizeCodeClientByCodeChallenge authorizes a client by validating the code_verifier against the previously sent
//AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent
//code_challenge of the auth request (PKCE)
func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error {
if tokenReq.CodeVerifier == "" {
return ErrInvalidRequest("code_challenge required")
return oidc.ErrInvalidRequest().WithDescription("code_challenge required")
}
if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) {
return ErrInvalidRequest("code_challenge invalid")
return oidc.ErrInvalidGrant().WithDescription("invalid code challenge")
}
return nil
}
@ -118,7 +120,7 @@ func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchang
return nil, err
}
if client.AuthMethod() != oidc.AuthMethodPrivateKeyJWT {
return nil, ErrInvalidRequest("invalid_client")
return nil, oidc.ErrInvalidClient()
}
return client, nil
}

136
pkg/op/token_revocation.go Normal file
View file

@ -0,0 +1,136 @@
package op
import (
"context"
"net/http"
"net/url"
"strings"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
)
type Revoker interface {
Decoder() httphelper.Decoder
Crypto() Crypto
Storage() Storage
AccessTokenVerifier() AccessTokenVerifier
AuthMethodPrivateKeyJWTSupported() bool
AuthMethodPostSupported() bool
}
type RevokerJWTProfile interface {
Revoker
JWTProfileVerifier() JWTProfileVerifier
}
func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Revoke(w, r, revoker)
}
}
func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) {
token, _, clientID, err := ParseTokenRevocationRequest(r, revoker)
if err != nil {
RevocationRequestError(w, r, err)
return
}
tokenID, subject, ok := getTokenIDAndSubjectForRevocation(r.Context(), revoker, token)
if ok {
token = tokenID
}
if err := revoker.Storage().RevokeToken(r.Context(), token, subject, clientID); err != nil {
RevocationRequestError(w, r, err)
return
}
httphelper.MarshalJSON(w, nil)
}
func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, tokenTypeHint, clientID string, err error) {
err = r.ParseForm()
if err != nil {
return "", "", "", oidc.ErrInvalidRequest().WithDescription("unable to parse request").WithParent(err)
}
req := new(struct {
oidc.RevocationRequest
oidc.ClientAssertionParams //for auth_method private_key_jwt
ClientID string `schema:"client_id"` //for auth_method none and post
ClientSecret string `schema:"client_secret"` //for auth_method post
})
err = revoker.Decoder().Decode(req, r.Form)
if err != nil {
return "", "", "", oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
if req.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
revokerJWTProfile, ok := revoker.(RevokerJWTProfile)
if !ok || !revoker.AuthMethodPrivateKeyJWTSupported() {
return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
}
profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier())
if err == nil {
return req.Token, req.TokenTypeHint, profile.Issuer, nil
}
return "", "", "", err
}
clientID, clientSecret, ok := r.BasicAuth()
if ok {
clientID, err = url.QueryUnescape(clientID)
if err != nil {
return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
clientSecret, err = url.QueryUnescape(clientSecret)
if err != nil {
return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
}
if err = AuthorizeClientIDSecret(r.Context(), clientID, clientSecret, revoker.Storage()); err != nil {
return "", "", "", err
}
return req.Token, req.TokenTypeHint, clientID, nil
}
if req.ClientID == "" {
return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid authorization")
}
client, err := revoker.Storage().GetClientByClientID(r.Context(), req.ClientID)
if err != nil {
return "", "", "", oidc.ErrInvalidClient().WithParent(err)
}
if req.ClientSecret == "" {
if client.AuthMethod() != oidc.AuthMethodNone {
return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid authorization")
}
return req.Token, req.TokenTypeHint, req.ClientID, nil
}
if client.AuthMethod() == oidc.AuthMethodPost && !revoker.AuthMethodPostSupported() {
return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
}
if err = AuthorizeClientIDSecret(r.Context(), req.ClientID, req.ClientSecret, revoker.Storage()); err != nil {
return "", "", "", err
}
return req.Token, req.TokenTypeHint, req.ClientID, nil
}
func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) {
e := oidc.DefaultToServerError(err, err.Error())
status := http.StatusBadRequest
if e.ErrorType == oidc.InvalidClient {
status = 401
}
httphelper.MarshalJSONWithStatus(w, e, status)
}
func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
if err == nil {
splitToken := strings.Split(tokenIDSubject, ":")
if len(splitToken) != 2 {
return "", "", false
}
return splitToken[0], splitToken[1], true
}
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier())
if err != nil {
return "", "", false
}
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true
}

View file

@ -6,12 +6,12 @@ import (
"net/http"
"strings"
httphelper "github.com/caos/oidc/pkg/http"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils"
)
type UserinfoProvider interface {
Decoder() utils.Decoder
Decoder() httphelper.Decoder
Crypto() Crypto
Storage() Storage
AccessTokenVerifier() AccessTokenVerifier
@ -37,14 +37,13 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
info := oidc.NewUserInfo()
err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin"))
if err != nil {
w.WriteHeader(http.StatusForbidden)
utils.MarshalJSON(w, err)
httphelper.MarshalJSONWithStatus(w, err, http.StatusForbidden)
return
}
utils.MarshalJSON(w, info)
httphelper.MarshalJSON(w, info)
}
func ParseUserinfoRequest(r *http.Request, decoder utils.Decoder) (string, error) {
func ParseUserinfoRequest(r *http.Request, decoder httphelper.Decoder) (string, error) {
accessToken, err := getAccessToken(r)
if err == nil {
return accessToken, nil

View file

@ -49,7 +49,7 @@ func (i *accessTokenVerifier) KeySet() oidc.KeySet {
}
func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet) AccessTokenVerifier {
verifier := &idTokenHintVerifier{
verifier := &accessTokenVerifier{
issuer: issuer,
keySet: keySet,
}