initial commit
This commit is contained in:
commit
6d0890e280
68 changed files with 5986 additions and 0 deletions
151
pkg/oidc/authorization.go
Normal file
151
pkg/oidc/authorization.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
const (
|
||||
ScopeOpenID = "openid"
|
||||
|
||||
ResponseTypeCode ResponseType = "code"
|
||||
ResponseTypeIDToken ResponseType = "id_token token"
|
||||
ResponseTypeIDTokenOnly ResponseType = "id_token"
|
||||
|
||||
DisplayPage Display = "page"
|
||||
DisplayPopup Display = "popup"
|
||||
DisplayTouch Display = "touch"
|
||||
DisplayWAP Display = "wap"
|
||||
|
||||
PromptNone Prompt = "none"
|
||||
PromptLogin Prompt = "login"
|
||||
PromptConsent Prompt = "consent"
|
||||
PromptSelectAccount Prompt = "select_account"
|
||||
|
||||
GrantTypeCode GrantType = "authorization_code"
|
||||
|
||||
BearerToken = "Bearer"
|
||||
)
|
||||
|
||||
var displayValues = map[string]Display{
|
||||
"page": DisplayPage,
|
||||
"popup": DisplayPopup,
|
||||
"touch": DisplayTouch,
|
||||
"wap": DisplayWAP,
|
||||
}
|
||||
|
||||
//AuthRequest according to:
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
//
|
||||
type AuthRequest struct {
|
||||
ID string
|
||||
Scopes Scopes `schema:"scope"`
|
||||
ResponseType ResponseType `schema:"response_type"`
|
||||
ClientID string `schema:"client_id"`
|
||||
RedirectURI string `schema:"redirect_uri"` //TODO: type
|
||||
|
||||
State string `schema:"state"`
|
||||
|
||||
// ResponseMode TODO: ?
|
||||
|
||||
Nonce string `schema:"nonce"`
|
||||
Display Display `schema:"display"`
|
||||
Prompt Prompt `schema:"prompt"`
|
||||
MaxAge uint32 `schema:"max_age"`
|
||||
UILocales Locales `schema:"ui_locales"`
|
||||
IDTokenHint string `schema:"id_token_hint"`
|
||||
LoginHint string `schema:"login_hint"`
|
||||
ACRValues []string `schema:"acr_values"`
|
||||
|
||||
CodeChallenge string `schema:"code_challenge"`
|
||||
CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"`
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetRedirectURI() string {
|
||||
return a.RedirectURI
|
||||
}
|
||||
func (a *AuthRequest) GetResponseType() ResponseType {
|
||||
return a.ResponseType
|
||||
}
|
||||
func (a *AuthRequest) GetState() string {
|
||||
return a.State
|
||||
}
|
||||
|
||||
type TokenRequest interface {
|
||||
// GrantType GrantType `schema:"grant_type"`
|
||||
GrantType() GrantType
|
||||
}
|
||||
|
||||
type TokenRequestType GrantType
|
||||
|
||||
type AccessTokenRequest struct {
|
||||
Code string `schema:"code"`
|
||||
RedirectURI string `schema:"redirect_uri"`
|
||||
ClientID string `schema:"client_id"`
|
||||
ClientSecret string `schema:"client_secret"`
|
||||
CodeVerifier string `schema:"code_verifier"`
|
||||
}
|
||||
|
||||
func (a *AccessTokenRequest) GrantType() GrantType {
|
||||
return GrantTypeCode
|
||||
}
|
||||
|
||||
type AccessTokenResponse struct {
|
||||
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
|
||||
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
|
||||
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
|
||||
}
|
||||
|
||||
type TokenExchangeRequest struct {
|
||||
subjectToken string `schema:"subject_token"`
|
||||
subjectTokenType string `schema:"subject_token_type"`
|
||||
actorToken string `schema:"actor_token"`
|
||||
actorTokenType string `schema:"actor_token_type"`
|
||||
resource []string `schema:"resource"`
|
||||
audience []string `schema:"audience"`
|
||||
Scope []string `schema:"scope"`
|
||||
requestedTokenType string `schema:"requested_token_type"`
|
||||
}
|
||||
|
||||
type Scopes []string
|
||||
|
||||
func (s *Scopes) UnmarshalText(text []byte) error {
|
||||
scopes := strings.Split(string(text), " ")
|
||||
*s = Scopes(scopes)
|
||||
return nil
|
||||
}
|
||||
|
||||
type ResponseType string
|
||||
|
||||
type Display string
|
||||
|
||||
func (d *Display) UnmarshalText(text []byte) error {
|
||||
var ok bool
|
||||
display := string(text)
|
||||
*d, ok = displayValues[display]
|
||||
if !ok {
|
||||
return errors.New("")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Prompt string
|
||||
|
||||
type Locales []language.Tag
|
||||
|
||||
func (l *Locales) UnmarshalText(text []byte) error {
|
||||
locales := strings.Split(string(text), " ")
|
||||
for _, locale := range locales {
|
||||
tag, err := language.Parse(locale)
|
||||
if err == nil && !tag.IsRoot() {
|
||||
*l = append(*l, tag)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type GrantType string
|
33
pkg/oidc/code_challenge.go
Normal file
33
pkg/oidc/code_challenge.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
CodeChallengeMethodPlain CodeChallengeMethod = "plain"
|
||||
CodeChallengeMethodS256 CodeChallengeMethod = "S256"
|
||||
)
|
||||
|
||||
type CodeChallengeMethod string
|
||||
|
||||
type CodeChallenge struct {
|
||||
Challenge string
|
||||
Method CodeChallengeMethod
|
||||
}
|
||||
|
||||
func NewSHACodeChallenge(code string) string {
|
||||
return utils.HashString(sha256.New(), code)
|
||||
}
|
||||
|
||||
func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool {
|
||||
if c == nil {
|
||||
return false //TODO: ?
|
||||
}
|
||||
if c.Method == CodeChallengeMethodS256 {
|
||||
codeVerifier = NewSHACodeChallenge(codeVerifier)
|
||||
}
|
||||
return codeVerifier == c.Challenge
|
||||
}
|
24
pkg/oidc/discovery.go
Normal file
24
pkg/oidc/discovery.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package oidc
|
||||
|
||||
const (
|
||||
DiscoveryEndpoint = "/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
type DiscoveryConfiguration struct {
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"`
|
||||
TokenEndpoint string `json:"token_endpoint,omitempty"`
|
||||
IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"`
|
||||
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
|
||||
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
|
||||
CheckSessionIframe string `json:"check_session_iframe,omitempty"`
|
||||
JwksURI string `json:"jwks_uri,omitempty"`
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
|
||||
ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
|
||||
SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
|
||||
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"`
|
||||
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
|
||||
ClaimsSupported []string `json:"claims_supported,omitempty"`
|
||||
}
|
33
pkg/oidc/grants/client_credentials.go
Normal file
33
pkg/oidc/grants/client_credentials.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package grants
|
||||
|
||||
import "strings"
|
||||
|
||||
type clientCredentialsGrantBasic struct {
|
||||
grantType string `schema:"grant_type"`
|
||||
scope string `schema:"scope"`
|
||||
}
|
||||
|
||||
type clientCredentialsGrant struct {
|
||||
*clientCredentialsGrantBasic
|
||||
clientID string `schema:"client_id"`
|
||||
clientSecret string `schema:"client_secret"`
|
||||
}
|
||||
|
||||
//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant
|
||||
//sneding client_id and client_secret as basic auth header
|
||||
func ClientCredentialsGrantBasic(scopes ...string) *clientCredentialsGrantBasic {
|
||||
return &clientCredentialsGrantBasic{
|
||||
grantType: "client_credentials",
|
||||
scope: strings.Join(scopes, " "),
|
||||
}
|
||||
}
|
||||
|
||||
//ClientCredentialsGrantValues creates an oauth2 `Client Credentials` Grant
|
||||
//sneding client_id and client_secret as form values
|
||||
func ClientCredentialsGrantValues(clientID, clientSecret string, scopes ...string) *clientCredentialsGrant {
|
||||
return &clientCredentialsGrant{
|
||||
clientCredentialsGrantBasic: ClientCredentialsGrantBasic(scopes...),
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
}
|
||||
}
|
75
pkg/oidc/grants/tokenexchange/tokenexchange.go
Normal file
75
pkg/oidc/grants/tokenexchange/tokenexchange.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package tokenexchange
|
||||
|
||||
const (
|
||||
AccessTokenType = "urn:ietf:params:oauth:token-type:access_token"
|
||||
RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
|
||||
IDTokenType = "urn:ietf:params:oauth:token-type:id_token"
|
||||
JWTTokenType = "urn:ietf:params:oauth:token-type:jwt"
|
||||
DelegationTokenType = AccessTokenType
|
||||
|
||||
TokenExchangeGrantType = "urn:ietf:params:oauth:grant-type:token-exchange"
|
||||
)
|
||||
|
||||
type TokenExchangeRequest struct {
|
||||
grantType string `schema:"grant_type"`
|
||||
subjectToken string `schema:"subject_token"`
|
||||
subjectTokenType string `schema:"subject_token_type"`
|
||||
actorToken string `schema:"actor_token"`
|
||||
actorTokenType string `schema:"actor_token_type"`
|
||||
resource []string `schema:"resource"`
|
||||
audience []string `schema:"audience"`
|
||||
scope []string `schema:"scope"`
|
||||
requestedTokenType string `schema:"requested_token_type"`
|
||||
}
|
||||
|
||||
func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest {
|
||||
t := &TokenExchangeRequest{
|
||||
grantType: TokenExchangeGrantType,
|
||||
subjectToken: subjectToken,
|
||||
subjectTokenType: subjectTokenType,
|
||||
requestedTokenType: AccessTokenType,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type TokenExchangeOption func(*TokenExchangeRequest)
|
||||
|
||||
func WithActorToken(token, tokenType string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.actorToken = token
|
||||
req.actorTokenType = tokenType
|
||||
}
|
||||
}
|
||||
|
||||
func WithAudience(audience []string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.audience = audience
|
||||
}
|
||||
}
|
||||
|
||||
func WithGrantType(grantType string) TokenExchangeOption {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.grantType = grantType
|
||||
}
|
||||
}
|
||||
|
||||
func WithRequestedTokenType(tokenType string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.requestedTokenType = tokenType
|
||||
}
|
||||
}
|
||||
|
||||
func WithResource(resource []string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.resource = resource
|
||||
}
|
||||
}
|
||||
|
||||
func WithScope(scope []string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.scope = scope
|
||||
}
|
||||
}
|
22
pkg/oidc/keyset.go
Normal file
22
pkg/oidc/keyset.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
|
||||
// of JSON web tokens. This is expected to be backed by a remote key set through
|
||||
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
|
||||
type KeySet interface {
|
||||
// VerifySignature parses the JSON web token, verifies the signature, and returns
|
||||
// the raw payload. Header and claim fields are validated by other parts of the
|
||||
// package. For example, the KeySet does not need to check values such as signature
|
||||
// algorithm, issuer, and audience since the IDTokenVerifier validates these values
|
||||
// independently.
|
||||
//
|
||||
// If VerifySignature makes HTTP requests to verify the token, it's expected to
|
||||
// use any HTTP client associated with the context through ClientContext.
|
||||
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
|
||||
}
|
196
pkg/oidc/token.go
Normal file
196
pkg/oidc/token.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
type Tokens struct {
|
||||
*oauth2.Token
|
||||
IDTokenClaims *IDTokenClaims
|
||||
IDToken string
|
||||
}
|
||||
|
||||
type AccessTokenClaims struct {
|
||||
Issuer string
|
||||
Subject string
|
||||
Audiences []string
|
||||
Expiration time.Time
|
||||
IssuedAt time.Time
|
||||
NotBefore time.Time
|
||||
JWTID string
|
||||
AuthorizedParty string
|
||||
Nonce string
|
||||
AuthTime time.Time
|
||||
CodeHash string
|
||||
AuthenticationContextClassReference string
|
||||
AuthenticationMethodsReferences []string
|
||||
SessionID string
|
||||
Scopes []string
|
||||
ClientID string
|
||||
AccessTokenUseNumber int
|
||||
}
|
||||
|
||||
type IDTokenClaims struct {
|
||||
Issuer string
|
||||
Subject string
|
||||
Audiences []string
|
||||
Expiration time.Time
|
||||
NotBefore time.Time
|
||||
IssuedAt time.Time
|
||||
JWTID string
|
||||
UpdatedAt time.Time
|
||||
AuthorizedParty string
|
||||
Nonce string
|
||||
AuthTime time.Time
|
||||
AccessTokenHash string
|
||||
CodeHash string
|
||||
AuthenticationContextClassReference string
|
||||
AuthenticationMethodsReferences []string
|
||||
ClientID string
|
||||
|
||||
Signature jose.SignatureAlgorithm //TODO: ???
|
||||
}
|
||||
|
||||
type jsonToken struct {
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Audiences []string `json:"aud,omitempty"`
|
||||
Expiration int64 `json:"exp,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
IssuedAt int64 `json:"iat,omitempty"`
|
||||
JWTID string `json:"jti,omitempty"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||
AuthorizedParty string `json:"azp,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
CodeHash string `json:"c_hash,omitempty"`
|
||||
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
||||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
Actor interface{} `json:"act,omitempty"` //TODO: impl
|
||||
Scopes string `json:"scope,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
|
||||
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
|
||||
}
|
||||
|
||||
func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) {
|
||||
j := jsonToken{
|
||||
Issuer: t.Issuer,
|
||||
Subject: t.Subject,
|
||||
Audiences: t.Audiences,
|
||||
Expiration: timeToJSON(t.Expiration),
|
||||
NotBefore: timeToJSON(t.NotBefore),
|
||||
IssuedAt: timeToJSON(t.IssuedAt),
|
||||
JWTID: t.JWTID,
|
||||
AuthorizedParty: t.AuthorizedParty,
|
||||
Nonce: t.Nonce,
|
||||
AuthTime: timeToJSON(t.AuthTime),
|
||||
CodeHash: t.CodeHash,
|
||||
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
|
||||
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
||||
SessionID: t.SessionID,
|
||||
Scopes: strings.Join(t.Scopes, " "),
|
||||
ClientID: t.ClientID,
|
||||
AccessTokenUseNumber: t.AccessTokenUseNumber,
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error {
|
||||
var j jsonToken
|
||||
if err := json.Unmarshal(b, &j); err != nil {
|
||||
return err
|
||||
}
|
||||
audience := j.Audiences
|
||||
if len(audience) == 1 {
|
||||
audience = strings.Split(audience[0], " ")
|
||||
}
|
||||
t.Issuer = j.Issuer
|
||||
t.Subject = j.Subject
|
||||
t.Audiences = audience
|
||||
t.Expiration = time.Unix(j.Expiration, 0).UTC()
|
||||
t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
|
||||
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
|
||||
t.JWTID = j.JWTID
|
||||
t.AuthorizedParty = j.AuthorizedParty
|
||||
t.Nonce = j.Nonce
|
||||
t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
|
||||
t.CodeHash = j.CodeHash
|
||||
t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
|
||||
t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
|
||||
t.SessionID = j.SessionID
|
||||
t.Scopes = strings.Split(j.Scopes, " ")
|
||||
t.ClientID = j.ClientID
|
||||
t.AccessTokenUseNumber = j.AccessTokenUseNumber
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
|
||||
j := jsonToken{
|
||||
Issuer: t.Issuer,
|
||||
Subject: t.Subject,
|
||||
Audiences: t.Audiences,
|
||||
Expiration: timeToJSON(t.Expiration),
|
||||
NotBefore: timeToJSON(t.NotBefore),
|
||||
IssuedAt: timeToJSON(t.IssuedAt),
|
||||
JWTID: t.JWTID,
|
||||
UpdatedAt: timeToJSON(t.UpdatedAt),
|
||||
AuthorizedParty: t.AuthorizedParty,
|
||||
Nonce: t.Nonce,
|
||||
AuthTime: timeToJSON(t.AuthTime),
|
||||
AccessTokenHash: t.AccessTokenHash,
|
||||
CodeHash: t.CodeHash,
|
||||
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
|
||||
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
||||
ClientID: t.ClientID,
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
|
||||
var i jsonToken
|
||||
if err := json.Unmarshal(b, &i); err != nil {
|
||||
return err
|
||||
}
|
||||
audience := i.Audiences
|
||||
if len(audience) == 1 {
|
||||
audience = strings.Split(audience[0], " ")
|
||||
}
|
||||
t.Issuer = i.Issuer
|
||||
t.Subject = i.Subject
|
||||
t.Audiences = audience
|
||||
t.Expiration = time.Unix(i.Expiration, 0).UTC()
|
||||
t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC()
|
||||
t.AuthTime = time.Unix(i.AuthTime, 0).UTC()
|
||||
t.Nonce = i.Nonce
|
||||
t.AuthenticationContextClassReference = i.AuthenticationContextClassReference
|
||||
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
|
||||
t.AuthorizedParty = i.AuthorizedParty
|
||||
t.AccessTokenHash = i.AccessTokenHash
|
||||
t.CodeHash = i.CodeHash
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return utils.HashString(hash, claim), nil
|
||||
}
|
||||
|
||||
func timeToJSON(t time.Time) int64 {
|
||||
if t.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
120
pkg/oidc/userinfo.go
Normal file
120
pkg/oidc/userinfo.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
type Userinfo struct {
|
||||
Subject string
|
||||
Address *UserinfoAddress
|
||||
UserinfoProfile
|
||||
UserinfoEmail
|
||||
UserinfoPhone
|
||||
|
||||
claims map[string]interface{}
|
||||
}
|
||||
|
||||
type UserinfoPhone struct {
|
||||
PhoneNumber string
|
||||
PhoneNumberVerified bool
|
||||
}
|
||||
type UserinfoProfile struct {
|
||||
Name string
|
||||
GivenName string
|
||||
FamilyName string
|
||||
MiddleName string
|
||||
Nickname string
|
||||
Profile string
|
||||
Picture string
|
||||
Website string
|
||||
Gender Gender
|
||||
Birthdate string
|
||||
Zoneinfo string
|
||||
Locale language.Tag
|
||||
UpdatedAt time.Time
|
||||
PreferredUsername string
|
||||
}
|
||||
|
||||
type Gender string
|
||||
|
||||
type UserinfoAddress struct {
|
||||
Formatted string
|
||||
StreetAddress string
|
||||
Locality string
|
||||
Region string
|
||||
PostalCode string
|
||||
Country string
|
||||
}
|
||||
|
||||
type UserinfoEmail struct {
|
||||
Email string
|
||||
EmailVerified bool
|
||||
}
|
||||
|
||||
func marshalUserinfoProfile(i UserinfoProfile, claims map[string]interface{}) {
|
||||
claims["name"] = i.Name
|
||||
claims["given_name"] = i.GivenName
|
||||
claims["family_name"] = i.FamilyName
|
||||
claims["middle_name"] = i.MiddleName
|
||||
claims["nickname"] = i.Nickname
|
||||
claims["profile"] = i.Profile
|
||||
claims["picture"] = i.Picture
|
||||
claims["website"] = i.Website
|
||||
claims["gender"] = i.Gender
|
||||
claims["birthdate"] = i.Birthdate
|
||||
claims["Zoneinfo"] = i.Zoneinfo
|
||||
claims["locale"] = i.Locale.String()
|
||||
claims["updated_at"] = i.UpdatedAt.UTC().Unix()
|
||||
claims["preferred_username"] = i.PreferredUsername
|
||||
}
|
||||
|
||||
func marshalUserinfoEmail(i UserinfoEmail, claims map[string]interface{}) {
|
||||
if i.Email != "" {
|
||||
claims["email"] = i.Email
|
||||
}
|
||||
if i.EmailVerified {
|
||||
claims["email_verified"] = i.EmailVerified
|
||||
}
|
||||
}
|
||||
|
||||
func marshalUserinfoAddress(i *UserinfoAddress, claims map[string]interface{}) {
|
||||
if i == nil {
|
||||
return
|
||||
}
|
||||
address := make(map[string]interface{})
|
||||
if i.Formatted != "" {
|
||||
address["formatted"] = i.Formatted
|
||||
}
|
||||
if i.StreetAddress != "" {
|
||||
address["street_address"] = i.StreetAddress
|
||||
}
|
||||
claims["address"] = address
|
||||
}
|
||||
|
||||
func marshalUserinfoPhone(i UserinfoPhone, claims map[string]interface{}) {
|
||||
claims["phone_number"] = i.PhoneNumber
|
||||
claims["phone_number_verified"] = i.PhoneNumberVerified
|
||||
}
|
||||
|
||||
func (i *Userinfo) MarshalJSON() ([]byte, error) {
|
||||
claims := i.claims
|
||||
if claims == nil {
|
||||
claims = make(map[string]interface{})
|
||||
}
|
||||
claims["sub"] = i.Subject
|
||||
marshalUserinfoAddress(i.Address, claims)
|
||||
marshalUserinfoEmail(i.UserinfoEmail, claims)
|
||||
marshalUserinfoPhone(i.UserinfoPhone, claims)
|
||||
marshalUserinfoProfile(i.UserinfoProfile, claims)
|
||||
return json.Marshal(claims)
|
||||
}
|
||||
|
||||
func (i *Userinfo) UnmmarshalJSON(data []byte) error {
|
||||
if err := json.Unmarshal(data, i); err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, i.claims)
|
||||
}
|
198
pkg/op/authrequest.go
Normal file
198
pkg/op/authrequest.go
Normal file
|
@ -0,0 +1,198 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type Authorizer interface {
|
||||
Storage() Storage
|
||||
Decoder() *schema.Decoder
|
||||
Encoder() *schema.Encoder
|
||||
Signer() Signer
|
||||
Crypto() Crypto
|
||||
Issuer() string
|
||||
}
|
||||
|
||||
type ValidationAuthorizer interface {
|
||||
Authorizer
|
||||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error
|
||||
}
|
||||
|
||||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
authReq := new(oidc.AuthRequest)
|
||||
err = authorizer.Decoder().Decode(authReq, r.Form)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
validation := ValidateAuthRequest
|
||||
if validater, ok := authorizer.(ValidationAuthorizer); ok {
|
||||
validation = validater.ValidateAuthRequest
|
||||
}
|
||||
if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
RedirectToLogin(req.GetID(), client, w, r)
|
||||
}
|
||||
|
||||
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error {
|
||||
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil {
|
||||
return err
|
||||
}
|
||||
// if NeedsExistingSession(authReq) {
|
||||
// session, err := storage.CheckSession(authReq.IDTokenHint)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateAuthReqScopes(scopes []string) error {
|
||||
if len(scopes) == 0 {
|
||||
return ErrInvalidRequest("scope missing")
|
||||
}
|
||||
if !utils.Contains(scopes, oidc.ScopeOpenID) {
|
||||
return ErrInvalidRequest("scope openid missing")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
|
||||
if uri == "" {
|
||||
return ErrInvalidRequestRedirectURI("redirect_uri must not be empty")
|
||||
}
|
||||
client, err := storage.GetClientByClientID(ctx, client_id)
|
||||
if err != nil {
|
||||
return ErrServerError(err.Error())
|
||||
}
|
||||
if !utils.Contains(client.RedirectURIs(), uri) {
|
||||
return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
|
||||
}
|
||||
if strings.HasPrefix(uri, "https://") {
|
||||
return nil
|
||||
}
|
||||
if responseType == oidc.ResponseTypeCode {
|
||||
if strings.HasPrefix(uri, "http://") && IsConfidentialType(client) {
|
||||
return nil
|
||||
}
|
||||
if client.ApplicationType() == ApplicationTypeNative {
|
||||
return nil
|
||||
}
|
||||
return ErrInvalidRequest("redirect_uri not allowed")
|
||||
} else {
|
||||
if client.ApplicationType() != ApplicationTypeNative {
|
||||
return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
|
||||
}
|
||||
if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) {
|
||||
return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateAuthReqResponseType(responseType oidc.ResponseType) error {
|
||||
if responseType == "" {
|
||||
return ErrInvalidRequest("response_type empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) {
|
||||
login := client.LoginURL(authReqID)
|
||||
http.Redirect(w, r, login, http.StatusFound)
|
||||
}
|
||||
|
||||
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
params := mux.Vars(r)
|
||||
id := params["id"]
|
||||
|
||||
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
if !authReq.Done() {
|
||||
AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
AuthResponse(authReq, authorizer, w, r)
|
||||
}
|
||||
|
||||
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
AuthResponseCode(w, r, authReq, authorizer)
|
||||
return
|
||||
}
|
||||
AuthResponseToken(w, r, authReq, authorizer, client)
|
||||
return
|
||||
}
|
||||
|
||||
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
|
||||
if authReq.GetState() != "" {
|
||||
callback = callback + "&state=" + authReq.GetState()
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
||||
func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) {
|
||||
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
|
||||
resp, err := CreateTokenResponse(authReq, client, authorizer, createAccessToken, "")
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
||||
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
|
||||
return crypto.Encrypt(authReq.GetID())
|
||||
}
|
296
pkg/op/authrequest_test.go
Normal file
296
pkg/op/authrequest_test.go
Normal file
|
@ -0,0 +1,296 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
"github.com/caos/oidc/pkg/op/mock"
|
||||
)
|
||||
|
||||
func TestAuthorize(t *testing.T) {
|
||||
// testCallback := func(t *testing.T, clienID string) callbackHandler {
|
||||
// return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) {
|
||||
// // require.Equal(t, clientID, client.)
|
||||
// }
|
||||
// }
|
||||
// testErr := func(t *testing.T, expected error) errorHandler {
|
||||
// return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
|
||||
// require.Equal(t, expected, err)
|
||||
// }
|
||||
// }
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
authorizer op.Authorizer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
{
|
||||
"parsing fails",
|
||||
args{
|
||||
httptest.NewRecorder(),
|
||||
&http.Request{Method: "POST", Body: nil},
|
||||
mock.NewAuthorizerExpectValid(t, true),
|
||||
// testCallback(t, ""),
|
||||
// testErr(t, ErrInvalidRequest("cannot parse form")),
|
||||
},
|
||||
},
|
||||
{
|
||||
"decoding fails",
|
||||
args{
|
||||
httptest.NewRecorder(),
|
||||
func() *http.Request {
|
||||
r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo"))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return r
|
||||
}(),
|
||||
mock.NewAuthorizerExpectValid(t, true),
|
||||
// testCallback(t, ""),
|
||||
// testErr(t, ErrInvalidRequest("cannot parse auth request")),
|
||||
},
|
||||
},
|
||||
// {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthRequest(t *testing.T) {
|
||||
type args struct {
|
||||
authRequest *oidc.AuthRequest
|
||||
storage op.Storage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
//TODO:
|
||||
// {
|
||||
// "oauth2 spec"
|
||||
// }
|
||||
{
|
||||
"scope missing fails",
|
||||
args{&oidc.AuthRequest{}, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"scope openid missing fails",
|
||||
args{&oidc.AuthRequest{Scopes: []string{"profile"}}, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"response_type missing fails",
|
||||
args{&oidc.AuthRequest{Scopes: []string{"openid"}}, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"client_id missing fails",
|
||||
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"redirect_uri missing fails",
|
||||
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, nil},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthReqScopes(t *testing.T) {
|
||||
type args struct {
|
||||
scopes []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"scopes missing fails", args{}, true,
|
||||
},
|
||||
{
|
||||
"scope openid missing fails", args{[]string{"email"}}, true,
|
||||
},
|
||||
{
|
||||
"scope ok", args{[]string{"openid"}}, false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthReqRedirectURI(t *testing.T) {
|
||||
type args struct {
|
||||
uri string
|
||||
clientID string
|
||||
responseType oidc.ResponseType
|
||||
storage op.OPStorage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"empty fails",
|
||||
args{"", "", oidc.ResponseTypeCode, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"unregistered fails",
|
||||
args{"https://unregistered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"storage error fails",
|
||||
args{"https://registered.com/callback", "non_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectInvalidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"code flow registered http not confidential fails",
|
||||
args{"http://registered.com/callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"code flow registered http confidential ok",
|
||||
args{"http://registered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"code flow registered custom not native fails",
|
||||
args{"custom://callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"code flow registered custom native ok",
|
||||
args{"http://registered.com/callback", "native_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"implicit flow registered ok",
|
||||
args{"https://registered.com/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"implicit flow registered http localhost native ok",
|
||||
args{"http://localhost:9999/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"implicit flow registered http localhost user agent fails",
|
||||
args{"http://localhost:9999/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"implicit flow http non localhost fails",
|
||||
args{"http://registered.com/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"implicit flow custom fails",
|
||||
args{"custom://callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := op.ValidateAuthReqRedirectURI(nil, tt.args.uri, tt.args.clientID, tt.args.responseType, tt.args.storage); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateRedirectURI() error = %v, wantErr %v", err.Error(), tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectToLogin(t *testing.T) {
|
||||
type args struct {
|
||||
authReqID string
|
||||
client op.Client
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
{
|
||||
"redirect ok",
|
||||
args{
|
||||
"id",
|
||||
mock.NewClientExpectAny(t, op.ApplicationTypeNative),
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("GET", "/authorize", nil),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.RedirectToLogin(tt.args.authReqID, tt.args.client, tt.args.w, tt.args.r)
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
require.Equal(t, "/login?id=id", rec.Header().Get("location"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeCallback(t *testing.T) {
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
authorizer op.Authorizer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthResponse(t *testing.T) {
|
||||
type args struct {
|
||||
authReq op.AuthRequest
|
||||
authorizer op.Authorizer
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r)
|
||||
})
|
||||
}
|
||||
}
|
33
pkg/op/client.go
Normal file
33
pkg/op/client.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package op
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
ApplicationTypeWeb ApplicationType = iota
|
||||
ApplicationTypeUserAgent
|
||||
ApplicationTypeNative
|
||||
|
||||
AccessTokenTypeBearer AccessTokenType = iota
|
||||
AccessTokenTypeJWT
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
GetID() string
|
||||
RedirectURIs() []string
|
||||
ApplicationType() ApplicationType
|
||||
GetAuthMethod() AuthMethod
|
||||
LoginURL(string) string
|
||||
AccessTokenType() AccessTokenType
|
||||
AccessTokenLifetime() time.Duration
|
||||
IDTokenLifetime() time.Duration
|
||||
}
|
||||
|
||||
func IsConfidentialType(c Client) bool {
|
||||
return c.ApplicationType() == ApplicationTypeWeb
|
||||
}
|
||||
|
||||
type ApplicationType int
|
||||
|
||||
type AuthMethod string
|
||||
|
||||
type AccessTokenType int
|
54
pkg/op/config.go
Normal file
54
pkg/op/config.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Configuration interface {
|
||||
Issuer() string
|
||||
AuthorizationEndpoint() Endpoint
|
||||
TokenEndpoint() Endpoint
|
||||
UserinfoEndpoint() Endpoint
|
||||
KeysEndpoint() Endpoint
|
||||
|
||||
AuthMethodPostSupported() bool
|
||||
|
||||
Port() string
|
||||
}
|
||||
|
||||
func ValidateIssuer(issuer string) error {
|
||||
if issuer == "" {
|
||||
return errors.New("missing issuer")
|
||||
}
|
||||
u, err := url.Parse(issuer)
|
||||
if err != nil {
|
||||
return errors.New("invalid url for issuer")
|
||||
}
|
||||
if u.Host == "" {
|
||||
return errors.New("host for issuer missing")
|
||||
}
|
||||
if u.Scheme != "https" {
|
||||
if !devLocalAllowed(u) {
|
||||
return errors.New("scheme for issuer must be `https`")
|
||||
}
|
||||
}
|
||||
if u.Fragment != "" || len(u.Query()) > 0 {
|
||||
return errors.New("no fragments or query allowed for issuer")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func devLocalAllowed(url *url.URL) bool {
|
||||
_, b := os.LookupEnv("CAOS_OIDC_DEV")
|
||||
if !b {
|
||||
return b
|
||||
}
|
||||
return url.Scheme == "http" &&
|
||||
url.Host == "localhost" ||
|
||||
url.Host == "127.0.0.1" ||
|
||||
url.Host == "::1" ||
|
||||
strings.HasPrefix(url.Host, "localhost:")
|
||||
}
|
94
pkg/op/config_test.go
Normal file
94
pkg/op/config_test.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package op
|
||||
|
||||
import "testing"
|
||||
|
||||
import "os"
|
||||
|
||||
func TestValidateIssuer(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"missing issuer fails",
|
||||
args{""},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid url for issuer fails",
|
||||
args{":issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid url for issuer fails",
|
||||
args{":issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host for issuer missing fails",
|
||||
args{"https:///issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host for not https fails",
|
||||
args{"http://issuer.com"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with fragment fails",
|
||||
args{"https://issuer.com/#issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with query fails",
|
||||
args{"https://issuer.com?issuer=me"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with https ok",
|
||||
args{"https://issuer.com"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"localhost with http ok",
|
||||
args{"http://localhost:9999"},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIssuerDevLocalAllowed(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"localhost with http ok",
|
||||
args{"http://localhost:9999"},
|
||||
false,
|
||||
},
|
||||
}
|
||||
os.Setenv("CAOS_OIDC_DEV", "")
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
26
pkg/op/crypto.go
Normal file
26
pkg/op/crypto.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type Crypto interface {
|
||||
Encrypt(string) (string, error)
|
||||
Decrypt(string) (string, error)
|
||||
}
|
||||
|
||||
type aesCrypto struct {
|
||||
key string
|
||||
}
|
||||
|
||||
func NewAESCrypto(key [32]byte) Crypto {
|
||||
return &aesCrypto{key: string(key[:32])}
|
||||
}
|
||||
|
||||
func (c *aesCrypto) Encrypt(s string) (string, error) {
|
||||
return utils.EncryptAES(s, c.key)
|
||||
}
|
||||
|
||||
func (c *aesCrypto) Decrypt(s string) (string, error) {
|
||||
return utils.DecryptAES(s, c.key)
|
||||
}
|
224
pkg/op/default_op.go
Normal file
224
pkg/op/default_op.go
Normal file
|
@ -0,0 +1,224 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAuthorizationEndpoint = "authorize"
|
||||
defaulTokenEndpoint = "oauth/token"
|
||||
defaultIntrospectEndpoint = "introspect"
|
||||
defaultUserinfoEndpoint = "userinfo"
|
||||
defaultKeysEndpoint = "keys"
|
||||
|
||||
AuthMethodBasic AuthMethod = "client_secret_basic"
|
||||
AuthMethodPost = "client_secret_post"
|
||||
AuthMethodNone = "none"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultEndpoints = &endpoints{
|
||||
Authorization: defaultAuthorizationEndpoint,
|
||||
Token: defaulTokenEndpoint,
|
||||
IntrospectionEndpoint: defaultIntrospectEndpoint,
|
||||
Userinfo: defaultUserinfoEndpoint,
|
||||
JwksURI: defaultKeysEndpoint,
|
||||
}
|
||||
)
|
||||
|
||||
type DefaultOP struct {
|
||||
config *Config
|
||||
endpoints *endpoints
|
||||
discoveryConfig *oidc.DiscoveryConfiguration
|
||||
storage Storage
|
||||
signer Signer
|
||||
crypto Crypto
|
||||
http *http.Server
|
||||
decoder *schema.Decoder
|
||||
encoder *schema.Encoder
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Issuer string
|
||||
CryptoKey [32]byte
|
||||
// ScopesSupported: oidc.SupportedScopes,
|
||||
// ResponseTypesSupported: responseTypes,
|
||||
// GrantTypesSupported: oidc.SupportedGrantTypes,
|
||||
// ClaimsSupported: oidc.SupportedClaims,
|
||||
// IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm},
|
||||
// SubjectTypesSupported: []string{"public"},
|
||||
// TokenEndpointAuthMethodsSupported:
|
||||
Port string
|
||||
}
|
||||
|
||||
type endpoints struct {
|
||||
Authorization Endpoint
|
||||
Token Endpoint
|
||||
IntrospectionEndpoint Endpoint
|
||||
Userinfo Endpoint
|
||||
EndSessionEndpoint Endpoint
|
||||
CheckSessionIframe Endpoint
|
||||
JwksURI Endpoint
|
||||
}
|
||||
|
||||
type DefaultOPOpts func(o *DefaultOP) error
|
||||
|
||||
func WithCustomAuthEndpoint(endpoint Endpoint) DefaultOPOpts {
|
||||
return func(o *DefaultOP) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
o.endpoints.Authorization = endpoint
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomTokenEndpoint(endpoint Endpoint) DefaultOPOpts {
|
||||
return func(o *DefaultOP) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
o.endpoints.Token = endpoint
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
|
||||
return func(o *DefaultOP) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
o.endpoints.Userinfo = endpoint
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
|
||||
err := ValidateIssuer(config.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &DefaultOP{
|
||||
config: config,
|
||||
storage: storage,
|
||||
endpoints: DefaultEndpoints,
|
||||
}
|
||||
|
||||
p.signer, err = NewDefaultSigner(ctx, storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, optFunc := range opOpts {
|
||||
if err := optFunc(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
p.discoveryConfig = CreateDiscoveryConfig(p, p.signer)
|
||||
|
||||
router := CreateRouter(p)
|
||||
p.http = &http.Server{
|
||||
Addr: ":" + config.Port,
|
||||
Handler: router,
|
||||
}
|
||||
p.decoder = schema.NewDecoder()
|
||||
p.decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
p.encoder = schema.NewEncoder()
|
||||
|
||||
p.crypto = NewAESCrypto(config.CryptoKey)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Issuer() string {
|
||||
return p.config.Issuer
|
||||
}
|
||||
|
||||
func (p *DefaultOP) AuthorizationEndpoint() Endpoint {
|
||||
return p.endpoints.Authorization
|
||||
}
|
||||
|
||||
func (p *DefaultOP) TokenEndpoint() Endpoint {
|
||||
return Endpoint(p.endpoints.Token)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) UserinfoEndpoint() Endpoint {
|
||||
return Endpoint(p.endpoints.Userinfo)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) KeysEndpoint() Endpoint {
|
||||
return Endpoint(p.endpoints.JwksURI)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) AuthMethodPostSupported() bool {
|
||||
return true //TODO: config
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Port() string {
|
||||
return p.config.Port
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HttpHandler() *http.Server {
|
||||
return p.http
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||
Discover(w, p.discoveryConfig)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Decoder() *schema.Decoder {
|
||||
return p.decoder
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Encoder() *schema.Encoder {
|
||||
return p.encoder
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Storage() Storage {
|
||||
return p.storage
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Signer() Signer {
|
||||
return p.signer
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Crypto() Crypto {
|
||||
return p.crypto
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleKeys(w http.ResponseWriter, r *http.Request) {
|
||||
Keys(w, r, p)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||
Authorize(w, r, p)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) {
|
||||
AuthorizeCallback(w, r, p)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) {
|
||||
reqType := r.FormValue("grant_type")
|
||||
if reqType == "" {
|
||||
ExchangeRequestError(w, r, ErrInvalidRequest("grant_type missing"))
|
||||
return
|
||||
}
|
||||
if reqType == string(oidc.GrantTypeCode) {
|
||||
CodeExchange(w, r, p)
|
||||
return
|
||||
}
|
||||
TokenExchange(w, r, p)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
|
||||
Userinfo(w, r, p)
|
||||
}
|
49
pkg/op/default_op_test.go
Normal file
49
pkg/op/default_op_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestDefaultOP_HandleDiscovery(t *testing.T) {
|
||||
type fields struct {
|
||||
config *Config
|
||||
endpoints *endpoints
|
||||
discoveryConfig *oidc.DiscoveryConfiguration
|
||||
storage Storage
|
||||
http *http.Server
|
||||
}
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want string
|
||||
wantCode int
|
||||
}{
|
||||
{"OK", fields{config: nil, endpoints: nil, discoveryConfig: &oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"}}, args{httptest.NewRecorder(), nil}, `{"issuer":"https://issuer.com"}`, 200},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &DefaultOP{
|
||||
config: tt.fields.config,
|
||||
endpoints: tt.fields.endpoints,
|
||||
discoveryConfig: tt.fields.discoveryConfig,
|
||||
storage: tt.fields.storage,
|
||||
http: tt.fields.http,
|
||||
}
|
||||
p.HandleDiscovery(tt.args.w, tt.args.r)
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
require.Equal(t, tt.want, rec.Body.String())
|
||||
require.Equal(t, tt.wantCode, rec.Code)
|
||||
})
|
||||
}
|
||||
}
|
119
pkg/op/discovery.go
Normal file
119
pkg/op/discovery.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
|
||||
utils.MarshalJSON(w, config)
|
||||
}
|
||||
|
||||
func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration {
|
||||
return &oidc.DiscoveryConfiguration{
|
||||
Issuer: c.Issuer(),
|
||||
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
|
||||
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
|
||||
// IntrospectionEndpoint: c.Intro().Absolute(c.Issuer()),
|
||||
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
|
||||
// EndSessionEndpoint: c.TokenEndpoint().Absolute(c.Issuer())(c.EndSessionEndpoint),
|
||||
// CheckSessionIframe: c.TokenEndpoint().Absolute(c.Issuer())(c.CheckSessionIframe),
|
||||
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
|
||||
ScopesSupported: Scopes(c),
|
||||
ResponseTypesSupported: ResponseTypes(c),
|
||||
GrantTypesSupported: GrantTypes(c),
|
||||
ClaimsSupported: SupportedClaims(c),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
|
||||
SubjectTypesSupported: SubjectTypes(c),
|
||||
TokenEndpointAuthMethodsSupported: AuthMethods(c),
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
ScopeOpenID = "openid"
|
||||
ScopeProfile = "profile"
|
||||
ScopeEmail = "email"
|
||||
ScopePhone = "phone"
|
||||
ScopeAddress = "address"
|
||||
)
|
||||
|
||||
var DefaultSupportedScopes = []string{
|
||||
ScopeOpenID,
|
||||
ScopeProfile,
|
||||
ScopeEmail,
|
||||
ScopePhone,
|
||||
ScopeAddress,
|
||||
}
|
||||
|
||||
func Scopes(c Configuration) []string {
|
||||
return DefaultSupportedScopes //TODO: config
|
||||
}
|
||||
|
||||
func ResponseTypes(c Configuration) []string {
|
||||
return []string{
|
||||
"code",
|
||||
"id_token",
|
||||
// "code token",
|
||||
// "code id_token",
|
||||
"id_token token",
|
||||
// "code id_token token"
|
||||
}
|
||||
}
|
||||
|
||||
func GrantTypes(c Configuration) []string {
|
||||
return []string{
|
||||
"client_credentials",
|
||||
"authorization_code",
|
||||
// "password",
|
||||
"urn:ietf:params:oauth:grant-type:token-exchange",
|
||||
}
|
||||
}
|
||||
|
||||
func SupportedClaims(c Configuration) []string {
|
||||
return []string{ //TODO: config
|
||||
"sub",
|
||||
"aud",
|
||||
"exp",
|
||||
"iat",
|
||||
"iss",
|
||||
"auth_time",
|
||||
"nonce",
|
||||
"acr",
|
||||
"amr",
|
||||
"c_hash",
|
||||
"at_hash",
|
||||
"act",
|
||||
"scopes",
|
||||
"client_id",
|
||||
"azp",
|
||||
"preferred_username",
|
||||
"name",
|
||||
"family_name",
|
||||
"given_name",
|
||||
"locale",
|
||||
"email",
|
||||
"email_verified",
|
||||
"phone_number",
|
||||
"phone_number_verified",
|
||||
}
|
||||
}
|
||||
|
||||
func SigAlgorithms(s Signer) []string {
|
||||
return []string{string(s.SignatureAlgorithm())}
|
||||
}
|
||||
|
||||
func SubjectTypes(c Configuration) []string {
|
||||
return []string{"public"} //TODO: config
|
||||
}
|
||||
|
||||
func AuthMethods(c Configuration) []string {
|
||||
authMethods := []string{
|
||||
string(AuthMethodBasic),
|
||||
}
|
||||
if c.AuthMethodPostSupported() {
|
||||
authMethods = append(authMethods, string(AuthMethodPost))
|
||||
}
|
||||
return authMethods
|
||||
}
|
235
pkg/op/discovery_test.go
Normal file
235
pkg/op/discovery_test.go
Normal file
|
@ -0,0 +1,235 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
"github.com/caos/oidc/pkg/op/mock"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestDiscover(t *testing.T) {
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
config *oidc.DiscoveryConfiguration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
{
|
||||
"OK",
|
||||
args{
|
||||
httptest.NewRecorder(),
|
||||
&oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.Discover(tt.args.w, tt.args.config)
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDiscoveryConfig(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
s op.Signer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *oidc.DiscoveryConfiguration
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_scopes(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"default Scopes",
|
||||
args{},
|
||||
op.DefaultSupportedScopes,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.Scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("scopes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ResponseTypes(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.ResponseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrantTypes(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.GrantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedClaims(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.SupportedClaims(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SupportedClaims() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SigAlgorithms(t *testing.T) {
|
||||
m := mock.NewMockSigner(gomock.NewController((t)))
|
||||
type args struct {
|
||||
s op.Signer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"",
|
||||
args{func() op.Signer {
|
||||
m.EXPECT().SignatureAlgorithm().Return(jose.RS256)
|
||||
return m
|
||||
}()},
|
||||
[]string{"RS256"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.SigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubjectTypes(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"none",
|
||||
args{},
|
||||
[]string{"public"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.SubjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AuthMethods(t *testing.T) {
|
||||
m := mock.NewMockConfiguration(gomock.NewController((t)))
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"imlicit basic",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().AuthMethodPostSupported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
[]string{string(op.AuthMethodBasic)},
|
||||
},
|
||||
{
|
||||
"basic and post",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().AuthMethodPostSupported().Return(true)
|
||||
return m
|
||||
}()},
|
||||
[]string{string(op.AuthMethodBasic), string(op.AuthMethodPost)},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.AuthMethods(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("authMethods() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
25
pkg/op/endpoint.go
Normal file
25
pkg/op/endpoint.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package op
|
||||
|
||||
import "strings"
|
||||
|
||||
type Endpoint string
|
||||
|
||||
func (e Endpoint) Relative() string {
|
||||
return relativeEndpoint(string(e))
|
||||
}
|
||||
|
||||
func (e Endpoint) Absolute(host string) string {
|
||||
return absoluteEndpoint(host, string(e))
|
||||
}
|
||||
|
||||
func (e Endpoint) Validate() error {
|
||||
return nil //TODO:
|
||||
}
|
||||
|
||||
func absoluteEndpoint(host, endpoint string) string {
|
||||
return strings.TrimSuffix(host, "/") + relativeEndpoint(endpoint)
|
||||
}
|
||||
|
||||
func relativeEndpoint(endpoint string) string {
|
||||
return "/" + strings.TrimPrefix(endpoint, "/")
|
||||
}
|
95
pkg/op/endpoint_test.go
Normal file
95
pkg/op/endpoint_test.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func TestEndpoint_Relative(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"without starting /",
|
||||
op.Endpoint("test"),
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
"with starting /",
|
||||
op.Endpoint("/test"),
|
||||
"/test",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.e.Relative(); got != tt.want {
|
||||
t.Errorf("Endpoint.Relative() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndpoint_Absolute(t *testing.T) {
|
||||
type args struct {
|
||||
host string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"no /",
|
||||
op.Endpoint("test"),
|
||||
args{"https://host"},
|
||||
"https://host/test",
|
||||
},
|
||||
{
|
||||
"endpoint without /",
|
||||
op.Endpoint("test"),
|
||||
args{"https://host/"},
|
||||
"https://host/test",
|
||||
},
|
||||
{
|
||||
"host without /",
|
||||
op.Endpoint("/test"),
|
||||
args{"https://host"},
|
||||
"https://host/test",
|
||||
},
|
||||
{
|
||||
"both /",
|
||||
op.Endpoint("/test"),
|
||||
args{"https://host/"},
|
||||
"https://host/test",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.e.Absolute(tt.args.host); got != tt.want {
|
||||
t.Errorf("Endpoint.Absolute() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: impl test
|
||||
func TestEndpoint_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
99
pkg/op/error.go
Normal file
99
pkg/op/error.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
InvalidRequest errorType = "invalid_request"
|
||||
ServerError errorType = "server_error"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidRequest = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: InvalidRequest,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
ErrInvalidRequestRedirectURI = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: InvalidRequest,
|
||||
Description: description,
|
||||
redirectDisabled: true,
|
||||
}
|
||||
}
|
||||
ErrServerError = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: ServerError,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
type errorType string
|
||||
|
||||
type ErrAuthRequest interface {
|
||||
GetRedirectURI() string
|
||||
GetResponseType() oidc.ResponseType
|
||||
GetState() string
|
||||
}
|
||||
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
|
||||
if authReq == nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
e, ok := err.(*OAuthError)
|
||||
if !ok {
|
||||
e = new(OAuthError)
|
||||
e.ErrorType = ServerError
|
||||
e.Description = err.Error()
|
||||
}
|
||||
e.state = authReq.GetState()
|
||||
if authReq.GetRedirectURI() == "" || e.redirectDisabled {
|
||||
http.Error(w, e.Description, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
params, err := utils.URLEncodeResponse(e, encoder)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
url := authReq.GetRedirectURI()
|
||||
responseType := authReq.GetResponseType()
|
||||
if responseType == "" || responseType == oidc.ResponseTypeCode {
|
||||
url += "?" + params
|
||||
} else {
|
||||
url += "#" + params
|
||||
}
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
e, ok := err.(*OAuthError)
|
||||
if !ok {
|
||||
e = new(OAuthError)
|
||||
e.ErrorType = ServerError
|
||||
e.Description = err.Error()
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
utils.MarshalJSON(w, e)
|
||||
}
|
||||
|
||||
type OAuthError struct {
|
||||
ErrorType errorType `json:"error" schema:"error"`
|
||||
Description string `json:"description" schema:"description"`
|
||||
state string `json:"state" schema:"state"`
|
||||
redirectDisabled bool
|
||||
}
|
||||
|
||||
func (e *OAuthError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
|
||||
}
|
19
pkg/op/keys.go
Normal file
19
pkg/op/keys.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type KeyProvider interface {
|
||||
Storage() Storage
|
||||
}
|
||||
|
||||
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
|
||||
keySet, err := k.Storage().GetKeySet(r.Context())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
utils.MarshalJSON(w, keySet)
|
||||
}
|
119
pkg/op/mock/authorizer.mock.go
Normal file
119
pkg/op/mock/authorizer.mock.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Authorizer)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
schema "github.com/gorilla/schema"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockAuthorizer is a mock of Authorizer interface
|
||||
type MockAuthorizer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockAuthorizerMockRecorder
|
||||
}
|
||||
|
||||
// MockAuthorizerMockRecorder is the mock recorder for MockAuthorizer
|
||||
type MockAuthorizerMockRecorder struct {
|
||||
mock *MockAuthorizer
|
||||
}
|
||||
|
||||
// NewMockAuthorizer creates a new mock instance
|
||||
func NewMockAuthorizer(ctrl *gomock.Controller) *MockAuthorizer {
|
||||
mock := &MockAuthorizer{ctrl: ctrl}
|
||||
mock.recorder = &MockAuthorizerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockAuthorizer) EXPECT() *MockAuthorizerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Crypto mocks base method
|
||||
func (m *MockAuthorizer) Crypto() op.Crypto {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Crypto")
|
||||
ret0, _ := ret[0].(op.Crypto)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Crypto indicates an expected call of Crypto
|
||||
func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Crypto", reflect.TypeOf((*MockAuthorizer)(nil).Crypto))
|
||||
}
|
||||
|
||||
// Decoder mocks base method
|
||||
func (m *MockAuthorizer) Decoder() *schema.Decoder {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Decoder")
|
||||
ret0, _ := ret[0].(*schema.Decoder)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Decoder indicates an expected call of Decoder
|
||||
func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decoder", reflect.TypeOf((*MockAuthorizer)(nil).Decoder))
|
||||
}
|
||||
|
||||
// Encoder mocks base method
|
||||
func (m *MockAuthorizer) Encoder() *schema.Encoder {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Encoder")
|
||||
ret0, _ := ret[0].(*schema.Encoder)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Encoder indicates an expected call of Encoder
|
||||
func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
|
||||
}
|
||||
|
||||
// Issuer mocks base method
|
||||
func (m *MockAuthorizer) Issuer() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Issuer")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Issuer indicates an expected call of Issuer
|
||||
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
|
||||
}
|
||||
|
||||
// Signer mocks base method
|
||||
func (m *MockAuthorizer) Signer() op.Signer {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Signer")
|
||||
ret0, _ := ret[0].(op.Signer)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Signer indicates an expected call of Signer
|
||||
func (mr *MockAuthorizerMockRecorder) Signer() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockAuthorizer)(nil).Signer))
|
||||
}
|
||||
|
||||
// Storage mocks base method
|
||||
func (m *MockAuthorizer) Storage() op.Storage {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Storage")
|
||||
ret0, _ := ret[0].(op.Storage)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Storage indicates an expected call of Storage
|
||||
func (mr *MockAuthorizerMockRecorder) Storage() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Storage", reflect.TypeOf((*MockAuthorizer)(nil).Storage))
|
||||
}
|
89
pkg/op/mock/authorizer.mock.impl.go
Normal file
89
pkg/op/mock/authorizer.mock.impl.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/schema"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
oidc "github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func NewAuthorizer(t *testing.T) op.Authorizer {
|
||||
return NewMockAuthorizer(gomock.NewController(t))
|
||||
}
|
||||
|
||||
func NewAuthorizerExpectValid(t *testing.T, wantErr bool) op.Authorizer {
|
||||
m := NewAuthorizer(t)
|
||||
ExpectDecoder(m)
|
||||
ExpectEncoder(m)
|
||||
ExpectSigner(m, t)
|
||||
ExpectStorage(m, t)
|
||||
// ExpectErrorHandler(m, t, wantErr)
|
||||
return m
|
||||
}
|
||||
|
||||
// func NewAuthorizerExpectDecoderFails(t *testing.T) op.Authorizer {
|
||||
// m := NewAuthorizer(t)
|
||||
// ExpectDecoderFails(m)
|
||||
// ExpectEncoder(m)
|
||||
// ExpectSigner(m, t)
|
||||
// ExpectStorage(m, t)
|
||||
// ExpectErrorHandler(m, t)
|
||||
// return m
|
||||
// }
|
||||
|
||||
func ExpectDecoder(a op.Authorizer) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Decoder().AnyTimes().Return(schema.NewDecoder())
|
||||
}
|
||||
|
||||
func ExpectEncoder(a op.Authorizer) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Encoder().AnyTimes().Return(schema.NewEncoder())
|
||||
}
|
||||
|
||||
func ExpectSigner(a op.Authorizer, t *testing.T) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Signer().DoAndReturn(
|
||||
func() op.Signer {
|
||||
return &Sig{}
|
||||
})
|
||||
}
|
||||
|
||||
// func ExpectErrorHandler(a op.Authorizer, t *testing.T, wantErr bool) {
|
||||
// mockA := a.(*MockAuthorizer)
|
||||
// mockA.EXPECT().ErrorHandler().AnyTimes().
|
||||
// Return(func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
|
||||
// if wantErr {
|
||||
// require.Error(t, err)
|
||||
// return
|
||||
// }
|
||||
// require.NoError(t, err)
|
||||
// })
|
||||
// }
|
||||
|
||||
type Sig struct{}
|
||||
|
||||
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return jose.HS256
|
||||
}
|
||||
|
||||
func ExpectStorage(a op.Authorizer, t *testing.T) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t))
|
||||
}
|
||||
|
||||
// func NewMockSignerAny(t *testing.T) op.Signer {
|
||||
// m := NewMockSigner(gomock.NewController(t))
|
||||
// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil)
|
||||
// return m
|
||||
// }
|
29
pkg/op/mock/client.go
Normal file
29
pkg/op/mock/client.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func NewClient(t *testing.T) op.Client {
|
||||
return NewMockClient(gomock.NewController(t))
|
||||
}
|
||||
|
||||
func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
|
||||
c := NewClient(t)
|
||||
m := c.(*MockClient)
|
||||
m.EXPECT().RedirectURIs().AnyTimes().Return([]string{
|
||||
"https://registered.com/callback",
|
||||
"http://registered.com/callback",
|
||||
"http://localhost:9999/callback",
|
||||
"custom://callback"})
|
||||
m.EXPECT().ApplicationType().AnyTimes().Return(appType)
|
||||
m.EXPECT().LoginURL(gomock.Any()).AnyTimes().DoAndReturn(
|
||||
func(id string) string {
|
||||
return "login?id=" + id
|
||||
})
|
||||
return c
|
||||
}
|
147
pkg/op/mock/client.mock.go
Normal file
147
pkg/op/mock/client.mock.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Client)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
)
|
||||
|
||||
// MockClient is a mock of Client interface
|
||||
type MockClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientMockRecorder
|
||||
}
|
||||
|
||||
// MockClientMockRecorder is the mock recorder for MockClient
|
||||
type MockClientMockRecorder struct {
|
||||
mock *MockClient
|
||||
}
|
||||
|
||||
// NewMockClient creates a new mock instance
|
||||
func NewMockClient(ctrl *gomock.Controller) *MockClient {
|
||||
mock := &MockClient{ctrl: ctrl}
|
||||
mock.recorder = &MockClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClient) EXPECT() *MockClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AccessTokenLifetime mocks base method
|
||||
func (m *MockClient) AccessTokenLifetime() time.Duration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AccessTokenLifetime")
|
||||
ret0, _ := ret[0].(time.Duration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AccessTokenLifetime indicates an expected call of AccessTokenLifetime
|
||||
func (mr *MockClientMockRecorder) AccessTokenLifetime() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenLifetime", reflect.TypeOf((*MockClient)(nil).AccessTokenLifetime))
|
||||
}
|
||||
|
||||
// AccessTokenType mocks base method
|
||||
func (m *MockClient) AccessTokenType() op.AccessTokenType {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AccessTokenType")
|
||||
ret0, _ := ret[0].(op.AccessTokenType)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AccessTokenType indicates an expected call of AccessTokenType
|
||||
func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
|
||||
}
|
||||
|
||||
// ApplicationType mocks base method
|
||||
func (m *MockClient) ApplicationType() op.ApplicationType {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ApplicationType")
|
||||
ret0, _ := ret[0].(op.ApplicationType)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ApplicationType indicates an expected call of ApplicationType
|
||||
func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
|
||||
}
|
||||
|
||||
// GetAuthMethod mocks base method
|
||||
func (m *MockClient) GetAuthMethod() op.AuthMethod {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthMethod")
|
||||
ret0, _ := ret[0].(op.AuthMethod)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAuthMethod indicates an expected call of GetAuthMethod
|
||||
func (mr *MockClientMockRecorder) GetAuthMethod() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthMethod", reflect.TypeOf((*MockClient)(nil).GetAuthMethod))
|
||||
}
|
||||
|
||||
// GetID mocks base method
|
||||
func (m *MockClient) GetID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetID indicates an expected call of GetID
|
||||
func (mr *MockClientMockRecorder) GetID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
|
||||
}
|
||||
|
||||
// IDTokenLifetime mocks base method
|
||||
func (m *MockClient) IDTokenLifetime() time.Duration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IDTokenLifetime")
|
||||
ret0, _ := ret[0].(time.Duration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IDTokenLifetime indicates an expected call of IDTokenLifetime
|
||||
func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime))
|
||||
}
|
||||
|
||||
// LoginURL mocks base method
|
||||
func (m *MockClient) LoginURL(arg0 string) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoginURL", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LoginURL indicates an expected call of LoginURL
|
||||
func (mr *MockClientMockRecorder) LoginURL(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockClient)(nil).LoginURL), arg0)
|
||||
}
|
||||
|
||||
// RedirectURIs mocks base method
|
||||
func (m *MockClient) RedirectURIs() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RedirectURIs")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RedirectURIs indicates an expected call of RedirectURIs
|
||||
func (mr *MockClientMockRecorder) RedirectURIs() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIs", reflect.TypeOf((*MockClient)(nil).RedirectURIs))
|
||||
}
|
132
pkg/op/mock/configuration.mock.go
Normal file
132
pkg/op/mock/configuration.mock.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Configuration)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockConfiguration is a mock of Configuration interface
|
||||
type MockConfiguration struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockConfigurationMockRecorder
|
||||
}
|
||||
|
||||
// MockConfigurationMockRecorder is the mock recorder for MockConfiguration
|
||||
type MockConfigurationMockRecorder struct {
|
||||
mock *MockConfiguration
|
||||
}
|
||||
|
||||
// NewMockConfiguration creates a new mock instance
|
||||
func NewMockConfiguration(ctrl *gomock.Controller) *MockConfiguration {
|
||||
mock := &MockConfiguration{ctrl: ctrl}
|
||||
mock.recorder = &MockConfigurationMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockConfiguration) EXPECT() *MockConfigurationMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AuthMethodPostSupported mocks base method
|
||||
func (m *MockConfiguration) AuthMethodPostSupported() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthMethodPostSupported")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AuthMethodPostSupported indicates an expected call of AuthMethodPostSupported
|
||||
func (mr *MockConfigurationMockRecorder) AuthMethodPostSupported() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethodPostSupported", reflect.TypeOf((*MockConfiguration)(nil).AuthMethodPostSupported))
|
||||
}
|
||||
|
||||
// AuthorizationEndpoint mocks base method
|
||||
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AuthorizationEndpoint indicates an expected call of AuthorizationEndpoint
|
||||
func (mr *MockConfigurationMockRecorder) AuthorizationEndpoint() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).AuthorizationEndpoint))
|
||||
}
|
||||
|
||||
// Issuer mocks base method
|
||||
func (m *MockConfiguration) Issuer() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Issuer")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Issuer indicates an expected call of Issuer
|
||||
func (mr *MockConfigurationMockRecorder) Issuer() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockConfiguration)(nil).Issuer))
|
||||
}
|
||||
|
||||
// KeysEndpoint mocks base method
|
||||
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeysEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// KeysEndpoint indicates an expected call of KeysEndpoint
|
||||
func (mr *MockConfigurationMockRecorder) KeysEndpoint() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeysEndpoint", reflect.TypeOf((*MockConfiguration)(nil).KeysEndpoint))
|
||||
}
|
||||
|
||||
// Port mocks base method
|
||||
func (m *MockConfiguration) Port() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Port")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Port indicates an expected call of Port
|
||||
func (mr *MockConfigurationMockRecorder) Port() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Port", reflect.TypeOf((*MockConfiguration)(nil).Port))
|
||||
}
|
||||
|
||||
// TokenEndpoint mocks base method
|
||||
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TokenEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// TokenEndpoint indicates an expected call of TokenEndpoint
|
||||
func (mr *MockConfigurationMockRecorder) TokenEndpoint() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpoint", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpoint))
|
||||
}
|
||||
|
||||
// UserinfoEndpoint mocks base method
|
||||
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserinfoEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UserinfoEndpoint indicates an expected call of UserinfoEndpoint
|
||||
func (mr *MockConfigurationMockRecorder) UserinfoEndpoint() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserinfoEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserinfoEndpoint))
|
||||
}
|
7
pkg/op/mock/generate.go
Normal file
7
pkg/op/mock/generate.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/caos/oidc/pkg/op Storage
|
||||
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/caos/oidc/pkg/op Authorizer
|
||||
//go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client
|
||||
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration
|
||||
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer
|
79
pkg/op/mock/signer.mock.go
Normal file
79
pkg/op/mock/signer.mock.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Signer)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
oidc "github.com/caos/oidc/pkg/oidc"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
go_jose_v2 "gopkg.in/square/go-jose.v2"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockSigner is a mock of Signer interface
|
||||
type MockSigner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSignerMockRecorder
|
||||
}
|
||||
|
||||
// MockSignerMockRecorder is the mock recorder for MockSigner
|
||||
type MockSignerMockRecorder struct {
|
||||
mock *MockSigner
|
||||
}
|
||||
|
||||
// NewMockSigner creates a new mock instance
|
||||
func NewMockSigner(ctrl *gomock.Controller) *MockSigner {
|
||||
mock := &MockSigner{ctrl: ctrl}
|
||||
mock.recorder = &MockSignerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockSigner) EXPECT() *MockSignerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// SignAccessToken mocks base method
|
||||
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SignAccessToken indicates an expected call of SignAccessToken
|
||||
func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
|
||||
}
|
||||
|
||||
// SignIDToken mocks base method
|
||||
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignIDToken", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SignIDToken indicates an expected call of SignIDToken
|
||||
func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0)
|
||||
}
|
||||
|
||||
// SignatureAlgorithm mocks base method
|
||||
func (m *MockSigner) SignatureAlgorithm() go_jose_v2.SignatureAlgorithm {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignatureAlgorithm")
|
||||
ret0, _ := ret[0].(go_jose_v2.SignatureAlgorithm)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SignatureAlgorithm indicates an expected call of SignatureAlgorithm
|
||||
func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
|
||||
}
|
170
pkg/op/mock/storage.mock.go
Normal file
170
pkg/op/mock/storage.mock.go
Normal file
|
@ -0,0 +1,170 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Storage)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
oidc "github.com/caos/oidc/pkg/oidc"
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
go_jose_v2 "gopkg.in/square/go-jose.v2"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockStorage is a mock of Storage interface
|
||||
type MockStorage struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStorageMockRecorder
|
||||
}
|
||||
|
||||
// MockStorageMockRecorder is the mock recorder for MockStorage
|
||||
type MockStorageMockRecorder struct {
|
||||
mock *MockStorage
|
||||
}
|
||||
|
||||
// NewMockStorage creates a new mock instance
|
||||
func NewMockStorage(ctrl *gomock.Controller) *MockStorage {
|
||||
mock := &MockStorage{ctrl: ctrl}
|
||||
mock.recorder = &MockStorageMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AuthRequestByID mocks base method
|
||||
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AuthRequestByID indicates an expected call of AuthRequestByID
|
||||
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1)
|
||||
}
|
||||
|
||||
// AuthorizeClientIDSecret mocks base method
|
||||
func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
|
||||
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// CreateAuthRequest mocks base method
|
||||
func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateAuthRequest indicates an expected call of CreateAuthRequest
|
||||
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// DeleteAuthRequest mocks base method
|
||||
func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest
|
||||
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetClientByClientID mocks base method
|
||||
func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.Client)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetClientByClientID indicates an expected call of GetClientByClientID
|
||||
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetKeySet mocks base method
|
||||
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKeySet", arg0)
|
||||
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKeySet indicates an expected call of GetKeySet
|
||||
func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
|
||||
}
|
||||
|
||||
// GetSigningKey mocks base method
|
||||
func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSigningKey", arg0)
|
||||
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSigningKey indicates an expected call of GetSigningKey
|
||||
func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0)
|
||||
}
|
||||
|
||||
// GetUserinfoFromScopes mocks base method
|
||||
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 []string) (*oidc.Userinfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1)
|
||||
ret0, _ := ret[0].(*oidc.Userinfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
|
||||
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1)
|
||||
}
|
||||
|
||||
// SaveKeyPair mocks base method
|
||||
func (m *MockStorage) SaveKeyPair(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveKeyPair", arg0)
|
||||
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SaveKeyPair indicates an expected call of SaveKeyPair
|
||||
func (mr *MockStorageMockRecorder) SaveKeyPair(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveKeyPair), arg0)
|
||||
}
|
142
pkg/op/mock/storage.mock.impl.go
Normal file
142
pkg/op/mock/storage.mock.impl.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func NewStorage(t *testing.T) op.Storage {
|
||||
return NewMockStorage(gomock.NewController(t))
|
||||
}
|
||||
|
||||
func NewMockStorageExpectValidClientID(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectValidClientID(m)
|
||||
return m
|
||||
}
|
||||
|
||||
func NewMockStorageExpectInvalidClientID(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectInvalidClientID(m)
|
||||
return m
|
||||
}
|
||||
|
||||
func NewMockStorageAny(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
mockS := m.(*MockStorage)
|
||||
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).AnyTimes().Return(&ConfClient{}, nil)
|
||||
mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
|
||||
return m
|
||||
}
|
||||
|
||||
func NewMockStorageSigningKeyError(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectSigningKeyError(m)
|
||||
return m
|
||||
}
|
||||
|
||||
func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectSigningKeyInvalid(m)
|
||||
return m
|
||||
}
|
||||
func NewMockStorageSigningKey(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectSigningKey(m)
|
||||
return m
|
||||
}
|
||||
|
||||
func ExpectInvalidClientID(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).Return(nil, errors.New("client not found"))
|
||||
}
|
||||
|
||||
func ExpectValidClientID(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, id string) (op.Client, error) {
|
||||
var appType op.ApplicationType
|
||||
var authMethod op.AuthMethod
|
||||
var accessTokenType op.AccessTokenType
|
||||
switch id {
|
||||
case "web_client":
|
||||
appType = op.ApplicationTypeWeb
|
||||
authMethod = op.AuthMethodBasic
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
case "native_client":
|
||||
appType = op.ApplicationTypeNative
|
||||
authMethod = op.AuthMethodNone
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
case "useragent_client":
|
||||
appType = op.ApplicationTypeUserAgent
|
||||
authMethod = op.AuthMethodBasic
|
||||
accessTokenType = op.AccessTokenTypeJWT
|
||||
}
|
||||
return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func ExpectSigningKeyError(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetSigningKey(gomock.Any()).Return(nil, errors.New("error"))
|
||||
}
|
||||
|
||||
func ExpectSigningKeyInvalid(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{}, nil)
|
||||
}
|
||||
|
||||
func ExpectSigningKey(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, nil)
|
||||
}
|
||||
|
||||
type ConfClient struct {
|
||||
id string
|
||||
appType op.ApplicationType
|
||||
authMethod op.AuthMethod
|
||||
accessTokenType op.AccessTokenType
|
||||
}
|
||||
|
||||
func (c *ConfClient) RedirectURIs() []string {
|
||||
return []string{
|
||||
"https://registered.com/callback",
|
||||
"http://registered.com/callback",
|
||||
"http://localhost:9999/callback",
|
||||
"custom://callback",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConfClient) LoginURL(id string) string {
|
||||
return "login?id=" + id
|
||||
}
|
||||
|
||||
func (c *ConfClient) ApplicationType() op.ApplicationType {
|
||||
return c.appType
|
||||
}
|
||||
|
||||
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
||||
return c.authMethod
|
||||
}
|
||||
|
||||
func (c *ConfClient) GetID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *ConfClient) AccessTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) IDTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
|
||||
return c.accessTokenType
|
||||
}
|
51
pkg/op/op.go
Normal file
51
pkg/op/op.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
type OpenIDProvider interface {
|
||||
Configuration
|
||||
HandleDiscovery(w http.ResponseWriter, r *http.Request)
|
||||
HandleAuthorize(w http.ResponseWriter, r *http.Request)
|
||||
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
|
||||
HandleExchange(w http.ResponseWriter, r *http.Request)
|
||||
HandleUserinfo(w http.ResponseWriter, r *http.Request)
|
||||
HandleKeys(w http.ResponseWriter, r *http.Request)
|
||||
HttpHandler() *http.Server
|
||||
}
|
||||
|
||||
func CreateRouter(o OpenIDProvider) *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery)
|
||||
router.HandleFunc(o.AuthorizationEndpoint().Relative(), o.HandleAuthorize)
|
||||
router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", o.HandleAuthorizeCallback)
|
||||
router.HandleFunc(o.TokenEndpoint().Relative(), o.HandleExchange)
|
||||
router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo)
|
||||
router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
|
||||
return router
|
||||
}
|
||||
|
||||
func Start(ctx context.Context, o OpenIDProvider) {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err := o.HttpHandler().Shutdown(ctx)
|
||||
if err != nil {
|
||||
logrus.Error("graceful shutdown of oidc server failed")
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := o.HttpHandler().ListenAndServe()
|
||||
if err != nil {
|
||||
logrus.Panicf("oidc server serve failed: %v", err)
|
||||
}
|
||||
}()
|
||||
logrus.Infof("oidc server is listening on %s", o.Port())
|
||||
}
|
13
pkg/op/session.go
Normal file
13
pkg/op/session.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package op
|
||||
|
||||
import "github.com/caos/oidc/pkg/oidc"
|
||||
|
||||
func NeedsExistingSession(authRequest *oidc.AuthRequest) bool {
|
||||
if authRequest == nil {
|
||||
return true
|
||||
}
|
||||
if authRequest.Prompt == oidc.PromptNone {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
78
pkg/op/signer.go
Normal file
78
pkg/op/signer.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
type Signer interface {
|
||||
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
|
||||
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
|
||||
SignatureAlgorithm() jose.SignatureAlgorithm
|
||||
}
|
||||
|
||||
type idTokenSigner struct {
|
||||
signer jose.Signer
|
||||
storage AuthStorage
|
||||
algorithm jose.SignatureAlgorithm
|
||||
}
|
||||
|
||||
func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) {
|
||||
s := &idTokenSigner{
|
||||
storage: storage,
|
||||
}
|
||||
if err := s.initialize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) initialize(ctx context.Context) error {
|
||||
var key *jose.SigningKey
|
||||
var err error
|
||||
key, err = s.storage.GetSigningKey(ctx)
|
||||
if err != nil {
|
||||
key, err = s.storage.SaveKeyPair(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.signer, err = jose.NewSigner(*key, &jose.SignerOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.algorithm = key.Algorithm
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Sign(payload)
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Sign(payload)
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) Sign(payload []byte) (string, error) {
|
||||
result, err := s.signer.Sign(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result.CompactSerialize()
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return s.algorithm
|
||||
}
|
95
pkg/op/signer_test.go
Normal file
95
pkg/op/signer_test.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// func TestNewDefaultSigner(t *testing.T) {
|
||||
// type args struct {
|
||||
// storage Storage
|
||||
// }
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// args args
|
||||
// want Signer
|
||||
// wantErr bool
|
||||
// }{
|
||||
// {
|
||||
// "err initialize storage fails",
|
||||
// args{mock.NewMockStorageSigningKeyError(t)},
|
||||
// nil,
|
||||
// true,
|
||||
// },
|
||||
// {
|
||||
// "err initialize storage fails",
|
||||
// args{mock.NewMockStorageSigningKeyInvalid(t)},
|
||||
// nil,
|
||||
// true,
|
||||
// },
|
||||
// {
|
||||
// "initialize ok",
|
||||
// args{mock.NewMockStorageSigningKey(t)},
|
||||
// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)},
|
||||
// false,
|
||||
// },
|
||||
// }
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// got, err := op.NewDefaultSigner(tt.args.storage)
|
||||
// if (err != nil) != tt.wantErr {
|
||||
// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
// return
|
||||
// }
|
||||
// if !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
func Test_idTokenSigner_Sign(t *testing.T) {
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
type fields struct {
|
||||
signer jose.Signer
|
||||
storage Storage
|
||||
}
|
||||
type args struct {
|
||||
payload []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"ok",
|
||||
fields{signer, nil},
|
||||
args{[]byte("test")},
|
||||
"eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww",
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &idTokenSigner{
|
||||
signer: tt.fields.signer,
|
||||
storage: tt.fields.storage,
|
||||
}
|
||||
got, err := s.Sign(tt.args.payload)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
48
pkg/op/storage.go
Normal file
48
pkg/op/storage.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
type AuthStorage interface {
|
||||
CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error)
|
||||
AuthRequestByID(context.Context, string) (AuthRequest, error)
|
||||
DeleteAuthRequest(context.Context, string) error
|
||||
|
||||
GetSigningKey(context.Context) (*jose.SigningKey, error)
|
||||
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
|
||||
SaveKeyPair(context.Context) (*jose.SigningKey, error)
|
||||
}
|
||||
|
||||
type OPStorage interface {
|
||||
GetClientByClientID(context.Context, string) (Client, error)
|
||||
AuthorizeClientIDSecret(context.Context, string, string) error
|
||||
GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error)
|
||||
}
|
||||
|
||||
type Storage interface {
|
||||
AuthStorage
|
||||
OPStorage
|
||||
}
|
||||
|
||||
type AuthRequest interface {
|
||||
GetID() string
|
||||
GetACR() string
|
||||
GetAMR() []string
|
||||
GetAudience() []string
|
||||
GetAuthTime() time.Time
|
||||
GetClientID() string
|
||||
GetCodeChallenge() *oidc.CodeChallenge
|
||||
GetNonce() string
|
||||
GetRedirectURI() string
|
||||
GetResponseType() oidc.ResponseType
|
||||
GetScopes() []string
|
||||
GetState() string
|
||||
GetSubject() string
|
||||
Done() bool
|
||||
}
|
93
pkg/op/token.go
Normal file
93
pkg/op/token.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
type TokenCreator interface {
|
||||
Issuer() string
|
||||
Signer() Signer
|
||||
Storage() Storage
|
||||
Crypto() Crypto
|
||||
}
|
||||
|
||||
func CreateTokenResponse(authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) {
|
||||
var accessToken string
|
||||
if createAccessToken {
|
||||
var err error
|
||||
accessToken, err = CreateAccessToken(authReq, client, creator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
idToken, err := CreateIDToken(creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Signer())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exp := uint64(client.AccessTokenLifetime().Seconds())
|
||||
return &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: exp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CreateAccessToken(authReq AuthRequest, client Client, creator TokenCreator) (string, error) {
|
||||
if client.AccessTokenType() == AccessTokenTypeJWT {
|
||||
return CreateJWT(creator.Issuer(), authReq, client, creator.Signer())
|
||||
}
|
||||
return CreateBearerToken(authReq, creator.Crypto())
|
||||
}
|
||||
|
||||
func CreateBearerToken(authReq AuthRequest, crypto Crypto) (string, error) {
|
||||
return crypto.Encrypt(authReq.GetID())
|
||||
}
|
||||
|
||||
func CreateJWT(issuer string, authReq AuthRequest, client Client, signer Signer) (string, error) {
|
||||
now := time.Now().UTC()
|
||||
nbf := now
|
||||
exp := now.Add(client.AccessTokenLifetime())
|
||||
claims := &oidc.AccessTokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: authReq.GetSubject(),
|
||||
Audiences: authReq.GetAudience(),
|
||||
Expiration: exp,
|
||||
IssuedAt: now,
|
||||
NotBefore: nbf,
|
||||
}
|
||||
return signer.SignAccessToken(claims)
|
||||
}
|
||||
|
||||
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
|
||||
var err error
|
||||
exp := time.Now().UTC().Add(validity)
|
||||
claims := &oidc.IDTokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: authReq.GetSubject(),
|
||||
Audiences: authReq.GetAudience(),
|
||||
Expiration: exp,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
AuthTime: authReq.GetAuthTime(),
|
||||
Nonce: authReq.GetNonce(),
|
||||
AuthenticationContextClassReference: authReq.GetACR(),
|
||||
AuthenticationMethodsReferences: authReq.GetAMR(),
|
||||
AuthorizedParty: authReq.GetClientID(),
|
||||
}
|
||||
if accessToken != "" {
|
||||
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if code != "" {
|
||||
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return signer.SignIDToken(claims)
|
||||
}
|
151
pkg/op/tokenrequest.go
Normal file
151
pkg/op/tokenrequest.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type Exchanger interface {
|
||||
Issuer() string
|
||||
Storage() Storage
|
||||
Decoder() *schema.Decoder
|
||||
Signer() Signer
|
||||
Crypto() Crypto
|
||||
AuthMethodPostSupported() bool
|
||||
}
|
||||
|
||||
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
}
|
||||
if tokenReq.Code == "" {
|
||||
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
|
||||
return
|
||||
}
|
||||
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
resp, err := CreateTokenResponse(authReq, client, exchanger, true, tokenReq.Code)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
utils.MarshalJSON(w, resp)
|
||||
}
|
||||
|
||||
func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return nil, ErrInvalidRequest("error parsing form")
|
||||
}
|
||||
tokenReq := new(oidc.AccessTokenRequest)
|
||||
err = decoder.Decode(tokenReq, r.Form)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidRequest("error decoding form")
|
||||
}
|
||||
clientID, clientSecret, ok := r.BasicAuth()
|
||||
if ok {
|
||||
tokenReq.ClientID = clientID
|
||||
tokenReq.ClientSecret = clientSecret
|
||||
|
||||
}
|
||||
return tokenReq, nil
|
||||
}
|
||||
|
||||
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if client.GetID() != authReq.GetClientID() {
|
||||
return nil, nil, ErrInvalidRequest("invalid auth code")
|
||||
}
|
||||
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
|
||||
return nil, nil, ErrInvalidRequest("redirect_uri does no correspond")
|
||||
}
|
||||
return authReq, client, nil
|
||||
}
|
||||
|
||||
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if client.GetAuthMethod() == AuthMethodNone {
|
||||
authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger)
|
||||
return authReq, client, err
|
||||
}
|
||||
if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
|
||||
return nil, nil, errors.New("basic not supported")
|
||||
}
|
||||
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
return authReq, client, nil
|
||||
}
|
||||
|
||||
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
|
||||
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
|
||||
}
|
||||
|
||||
func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||
if tokenReq.CodeVerifier == "" {
|
||||
return nil, ErrInvalidRequest("code_challenge required")
|
||||
}
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
if !oidc.VerifyCodeChallenge(authReq.GetCodeChallenge(), tokenReq.CodeVerifier) {
|
||||
return nil, ErrInvalidRequest("code_challenge invalid")
|
||||
}
|
||||
return authReq, nil
|
||||
}
|
||||
|
||||
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
||||
id, err := crypto.Decrypt(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return storage.AuthRequestByID(ctx, id)
|
||||
}
|
||||
|
||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
tokenRequest, err := ParseTokenExchangeRequest(w, r)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
|
||||
return nil, errors.New("Unimplemented") //TODO: impl
|
||||
}
|
||||
|
||||
func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
|
||||
return errors.New("Unimplemented") //TODO: impl
|
||||
}
|
28
pkg/op/userinfo.go
Normal file
28
pkg/op/userinfo.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type UserinfoProvider interface {
|
||||
Storage() Storage
|
||||
}
|
||||
|
||||
func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoProvider) {
|
||||
scopes, err := ScopesFromAccessToken(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(r.Context(), scopes)
|
||||
if err != nil {
|
||||
utils.MarshalJSON(w, err)
|
||||
return
|
||||
}
|
||||
utils.MarshalJSON(w, info)
|
||||
}
|
||||
|
||||
func ScopesFromAccessToken(w http.ResponseWriter, r *http.Request) ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
287
pkg/rp/default_rp.go
Normal file
287
pkg/rp/default_rp.go
Normal file
|
@ -0,0 +1,287 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc/grants"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
grants_tx "github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
idTokenKey = "id_token"
|
||||
stateParam = "state"
|
||||
pkceCode = "pkce"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
|
||||
http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
|
||||
}
|
||||
)
|
||||
|
||||
//DefaultRP impements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface
|
||||
type DefaultRP struct {
|
||||
endpoints Endpoints
|
||||
|
||||
oauthConfig oauth2.Config
|
||||
config *Config
|
||||
pkce bool
|
||||
|
||||
httpClient *http.Client
|
||||
cookieHandler *utils.CookieHandler
|
||||
|
||||
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
|
||||
|
||||
verifier Verifier
|
||||
}
|
||||
|
||||
//NewDefaultRP creates `DefaultRP` with the given
|
||||
//Config and possible configOptions
|
||||
//it will run discovery on the provided issuer
|
||||
//if no verifier is provided using the options the `DefaultVerifier` is set
|
||||
func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExchangeRP, error) {
|
||||
p := &DefaultRP{
|
||||
config: rpConfig,
|
||||
httpClient: utils.DefaultHTTPClient,
|
||||
}
|
||||
|
||||
for _, optFunc := range rpOpts {
|
||||
optFunc(p)
|
||||
}
|
||||
|
||||
if err := p.discover(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.errorHandler == nil {
|
||||
p.errorHandler = DefaultErrorHandler
|
||||
}
|
||||
|
||||
if p.verifier == nil {
|
||||
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL))
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
//DefaultRPOpts is the type for providing dynamic options to the DefaultRP
|
||||
type DefaultRPOpts func(p *DefaultRP)
|
||||
|
||||
//WithCookieHandler set a `CookieHandler` for securing the various redirects
|
||||
func WithCookieHandler(cookieHandler *utils.CookieHandler) DefaultRPOpts {
|
||||
return func(p *DefaultRP) {
|
||||
p.cookieHandler = cookieHandler
|
||||
}
|
||||
}
|
||||
|
||||
//WithPKCE sets the RP to use PKCE (oauth2 code challenge)
|
||||
//it also sets a `CookieHandler` for securing the various redirects
|
||||
//and exchanging the code challenge
|
||||
func WithPKCE(cookieHandler *utils.CookieHandler) DefaultRPOpts {
|
||||
return func(p *DefaultRP) {
|
||||
p.pkce = true
|
||||
p.cookieHandler = cookieHandler
|
||||
}
|
||||
}
|
||||
|
||||
//WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier
|
||||
func WithHTTPClient(client *http.Client) DefaultRPOpts {
|
||||
return func(p *DefaultRP) {
|
||||
p.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//wrapping the oauth2 `AuthCodeURL`
|
||||
//returning the url of the auth request
|
||||
func (p *DefaultRP) AuthURL(state string, opts ...AuthURLOpt) string {
|
||||
authOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
authOpts = append(authOpts, opt()...)
|
||||
}
|
||||
return p.oauthConfig.AuthCodeURL(state, authOpts...)
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//extending the `AuthURL` method with a http redirect handler
|
||||
func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
opts := make([]AuthURLOpt, 0)
|
||||
if err := p.trySetStateCookie(w, state); err != nil {
|
||||
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if p.pkce {
|
||||
codeChallenge, err := p.generateAndStoreCodeChallenge(w)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
opts = append(opts, WithCodeChallenge(codeChallenge))
|
||||
}
|
||||
http.Redirect(w, r, p.AuthURL(state, opts...), http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DefaultRP) generateAndStoreCodeChallenge(w http.ResponseWriter) (string, error) {
|
||||
var codeVerifier string
|
||||
codeVerifier = "s"
|
||||
if err := p.cookieHandler.SetCookie(w, pkceCode, codeVerifier); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return oidc.NewSHACodeChallenge(codeVerifier), nil
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//handling the oauth2 code exchange, extracting and validating the id_token
|
||||
//returning it paresed together with the oauth2 tokens (access, refresh)
|
||||
func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
|
||||
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
codeOpts = append(codeOpts, opt()...)
|
||||
}
|
||||
|
||||
token, err := p.oauthConfig.Exchange(ctx, code, codeOpts...)
|
||||
if err != nil {
|
||||
return nil, err //TODO: our error
|
||||
}
|
||||
idTokenString, ok := token.Extra(idTokenKey).(string)
|
||||
if !ok {
|
||||
//TODO: implement
|
||||
}
|
||||
|
||||
idToken, err := p.verifier.Verify(ctx, token.AccessToken, idTokenString)
|
||||
if err != nil {
|
||||
return nil, err //TODO: err
|
||||
}
|
||||
|
||||
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//extending the `CodeExchange` method with callback function
|
||||
func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := p.tryReadStateCookie(w, r)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
params := r.URL.Query()
|
||||
if params.Get("error") != "" {
|
||||
p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state)
|
||||
return
|
||||
}
|
||||
codeOpts := make([]CodeExchangeOpt, 0)
|
||||
if p.pkce {
|
||||
codeVerifier, err := p.cookieHandler.CheckCookie(r, pkceCode)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
|
||||
}
|
||||
tokens, err := p.CodeExchange(r.Context(), params.Get("code"), codeOpts...)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
callback(w, r, tokens, state)
|
||||
}
|
||||
}
|
||||
|
||||
// func (p *DefaultRP) Introspect(ctx context.Context, accessToken string) (oidc.TokenIntrospectResponse, error) {
|
||||
// // req := &http.Request{}
|
||||
// // resp, err := p.httpClient.Do(req)
|
||||
// // if err != nil {
|
||||
|
||||
// // }
|
||||
// // p.endpoints.IntrospectURL
|
||||
// return nil, nil
|
||||
// }
|
||||
|
||||
func (p *DefaultRP) Userinfo() {}
|
||||
|
||||
//ClientCredentials is the `RelayingParty` interface implementation
|
||||
//handling the oauth2 client credentials grant
|
||||
func (p *DefaultRP) ClientCredentials(ctx context.Context, scopes ...string) (newToken *oauth2.Token, err error) {
|
||||
return p.callTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...))
|
||||
}
|
||||
|
||||
//TokenExchange is the `TokenExchangeRP` interface implementation
|
||||
//handling the oauth2 token exchange (draft)
|
||||
func (p *DefaultRP) TokenExchange(ctx context.Context, request *grants_tx.TokenExchangeRequest) (newToken *oauth2.Token, err error) {
|
||||
return p.callTokenEndpoint(request)
|
||||
}
|
||||
|
||||
//DelegationTokenExchange is the `TokenExchangeRP` interface implementation
|
||||
//handling the oauth2 token exchange for a delegation token (draft)
|
||||
func (p *DefaultRP) DelegationTokenExchange(ctx context.Context, subjectToken string, reqOpts ...grants_tx.TokenExchangeOption) (newToken *oauth2.Token, err error) {
|
||||
return p.TokenExchange(ctx, DelegationTokenRequest(subjectToken, reqOpts...))
|
||||
}
|
||||
|
||||
func (p *DefaultRP) discover() error {
|
||||
wellKnown := strings.TrimSuffix(p.config.Issuer, "/") + oidc.DiscoveryEndpoint
|
||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
discoveryConfig := new(oidc.DiscoveryConfiguration)
|
||||
err = utils.HttpRequest(p.httpClient, req, &discoveryConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.endpoints = GetEndpoints(discoveryConfig)
|
||||
p.oauthConfig = oauth2.Config{
|
||||
ClientID: p.config.ClientID,
|
||||
ClientSecret: p.config.ClientSecret,
|
||||
Endpoint: p.endpoints.Endpoint,
|
||||
RedirectURL: p.config.CallbackURL,
|
||||
Scopes: p.config.Scopes,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DefaultRP) callTokenEndpoint(request interface{}) (newToken *oauth2.Token, err error) {
|
||||
req, err := utils.FormRequest(p.endpoints.TokenURL, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(p.config.ClientID + ":" + p.config.ClientSecret))
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
token := new(oauth2.Token)
|
||||
if err := utils.HttpRequest(p.httpClient, req, token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error {
|
||||
if p.cookieHandler != nil {
|
||||
if err := p.cookieHandler.SetCookie(w, stateParam, state); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DefaultRP) tryReadStateCookie(w http.ResponseWriter, r *http.Request) (state string, err error) {
|
||||
if p.cookieHandler == nil {
|
||||
return r.FormValue(stateParam), nil
|
||||
}
|
||||
state, err = p.cookieHandler.CheckQueryCookie(r, stateParam)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
p.cookieHandler.DeleteCookie(w, stateParam)
|
||||
return state, nil
|
||||
}
|
363
pkg/rp/default_verifier.go
Normal file
363
pkg/rp/default_verifier.go
Normal file
|
@ -0,0 +1,363 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
//DefaultVerifier implements the `Verifier` interface
|
||||
type DefaultVerifier struct {
|
||||
config *verifierConfig
|
||||
keySet oidc.KeySet
|
||||
}
|
||||
|
||||
//ConfFunc is the type for providing dynamic options to the DefaultVerfifier
|
||||
type ConfFunc func(*verifierConfig)
|
||||
|
||||
//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
|
||||
type ACRVerifier func(string) error
|
||||
|
||||
//NewDefaultVerifier creates `DefaultVerifier` with the given
|
||||
//issuer, clientID, keyset and possible configOptions
|
||||
func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ...ConfFunc) Verifier {
|
||||
conf := &verifierConfig{
|
||||
issuer: issuer,
|
||||
clientID: clientID,
|
||||
iat: &iatConfig{
|
||||
// offset: time.Duration(500 * time.Millisecond),
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range confOpts {
|
||||
if opt != nil {
|
||||
opt(conf)
|
||||
}
|
||||
}
|
||||
return &DefaultVerifier{config: conf, keySet: keySet}
|
||||
}
|
||||
|
||||
//WithIgnoreIssuedAt will turn off iat claim verification
|
||||
func WithIgnoreIssuedAt() func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.iat.ignore = true
|
||||
}
|
||||
}
|
||||
|
||||
//WithIssuedAtOffset mitigates the risk of iat to be in the future
|
||||
//because of clock skews with the ability to add an offset to the current time
|
||||
func WithIssuedAtOffset(offset time.Duration) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.iat.offset = offset
|
||||
}
|
||||
}
|
||||
|
||||
//WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
|
||||
func WithIssuedAtMaxAge(maxAge time.Duration) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.iat.maxAge = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
//WithNonce TODO: ?
|
||||
func WithNonce(nonce string) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.nonce = nonce
|
||||
}
|
||||
}
|
||||
|
||||
//WithACRVerifier sets the verifier for the acr claim
|
||||
func WithACRVerifier(verifier ACRVerifier) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.acr = verifier
|
||||
}
|
||||
}
|
||||
|
||||
//WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
|
||||
func WithAuthTimeMaxAge(maxAge time.Duration) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.maxAge = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
//WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
|
||||
func WithSupportedSigningAlgorithms(algs ...string) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.supportedSignAlgs = algs
|
||||
}
|
||||
}
|
||||
|
||||
type verifierConfig struct {
|
||||
issuer string
|
||||
clientID string
|
||||
nonce string
|
||||
iat *iatConfig
|
||||
acr ACRVerifier
|
||||
maxAge time.Duration
|
||||
supportedSignAlgs []string
|
||||
|
||||
// httpClient *http.Client
|
||||
|
||||
now time.Time
|
||||
}
|
||||
|
||||
type iatConfig struct {
|
||||
ignore bool
|
||||
offset time.Duration
|
||||
maxAge time.Duration
|
||||
}
|
||||
|
||||
//DefaultACRVerifier implements `ACRVerifier` returning an error
|
||||
//if non of the provided values matches the acr claim
|
||||
func DefaultACRVerifier(possibleValues []string) ACRVerifier {
|
||||
return func(acr string) error {
|
||||
if !utils.Contains(possibleValues, acr) {
|
||||
return ErrAcrInvalid(possibleValues, acr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
//Verify implements the `Verify` method of the `Verifier` interface
|
||||
//according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
//and https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||
func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) {
|
||||
v.config.now = time.Now().UTC()
|
||||
idToken, err := v.VerifyIDToken(ctx, idTokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := v.verifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { //TODO: sig from token
|
||||
return nil, err
|
||||
}
|
||||
return idToken, nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) now() time.Time {
|
||||
if v.config.now.IsZero() {
|
||||
v.config.now = time.Now().UTC().Round(time.Second)
|
||||
}
|
||||
return v.config.now
|
||||
}
|
||||
|
||||
//VerifyIDToken: https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) {
|
||||
//1. if encrypted --> decrypt
|
||||
decrypted, err := v.decryptToken(idTokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, payload, err := v.parseToken(decrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// token, err := jwt.ParseWithClaims(decrypted, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
//2, check issuer (exact match)
|
||||
if err := v.checkIssuer(claims.Issuer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//3. check aud (aud must contain client_id, all aud strings must be allowed)
|
||||
if err = v.checkAudience(claims.Audiences); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = v.checkAuthorizedParty(claims.Audiences, claims.AuthorizedParty); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//6. check signature by keys
|
||||
//7. check alg default is rs256
|
||||
//8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
|
||||
claims.Signature, err = v.checkSignature(ctx, decrypted, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//9. check exp before now
|
||||
if err = v.checkExpiration(claims.Expiration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//10. check iat duration is optional (can be checked)
|
||||
if err = v.checkIssuedAt(claims.IssuedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//11. check nonce (check if optional possible) id_token.nonce == sentNonce
|
||||
if err = v.checkNonce(claims.Nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//12. if acr requested check acr
|
||||
if err = v.checkAuthorizationContextClassReference(claims.AuthenticationContextClassReference); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//13. if auth_time requested check if auth_time is less than max age
|
||||
if err = v.checkAuthTime(claims.AuthTime); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) parseToken(tokenString string) (*oidc.IDTokenClaims, []byte, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, nil, ValidationError("token contains an invalid number of segments") //TODO: err NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
|
||||
}
|
||||
idToken := new(oidc.IDTokenClaims)
|
||||
err = json.Unmarshal(payload, idToken)
|
||||
return idToken, payload, err
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkIssuer(issuer string) error {
|
||||
if v.config.issuer != issuer {
|
||||
return ErrIssuerInvalid(v.config.issuer, issuer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkAudience(audiences []string) error {
|
||||
if !utils.Contains(audiences, v.config.clientID) {
|
||||
return ErrAudienceMissingClientID(v.config.clientID)
|
||||
}
|
||||
|
||||
//TODO: check aud trusted
|
||||
return nil
|
||||
}
|
||||
|
||||
//4. if multiple aud strings --> check if azp
|
||||
//5. if azp --> check azp == client_id
|
||||
func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error {
|
||||
if len(audiences) > 1 {
|
||||
if authorizedParty == "" {
|
||||
return ErrAzpMissing()
|
||||
}
|
||||
}
|
||||
if authorizedParty != "" && authorizedParty != v.config.clientID {
|
||||
return ErrAzpInvalid(authorizedParty, v.config.clientID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString string, payload []byte) (jose.SignatureAlgorithm, error) {
|
||||
jws, err := jose.ParseSigned(idTokenString)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
return "", nil //TODO: error
|
||||
}
|
||||
if len(jws.Signatures) > 1 {
|
||||
return "", nil //TODO: error
|
||||
}
|
||||
sig := jws.Signatures[0]
|
||||
supportedSigAlgs := v.config.supportedSignAlgs
|
||||
if len(supportedSigAlgs) == 0 {
|
||||
supportedSigAlgs = []string{"RS256"}
|
||||
}
|
||||
if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) {
|
||||
return "", fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
|
||||
}
|
||||
|
||||
signedPayload, err := v.keySet.VerifySignature(ctx, jws)
|
||||
if err != nil {
|
||||
return "", err
|
||||
//TODO:
|
||||
}
|
||||
|
||||
if !bytes.Equal(signedPayload, payload) {
|
||||
return "", ErrSignatureInvalidPayload() //TODO: err
|
||||
}
|
||||
return jose.SignatureAlgorithm(sig.Header.Algorithm), nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkExpiration(expiration time.Time) error {
|
||||
expiration = expiration.Round(time.Second)
|
||||
if !v.now().Before(expiration) {
|
||||
return ErrExpInvalid(expiration)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkIssuedAt(issuedAt time.Time) error {
|
||||
if v.config.iat.ignore {
|
||||
return nil
|
||||
}
|
||||
issuedAt = issuedAt.Round(time.Second)
|
||||
offset := v.now().Add(v.config.iat.offset).Round(time.Second)
|
||||
if issuedAt.After(offset) {
|
||||
return ErrIatInFuture(issuedAt, offset)
|
||||
}
|
||||
if v.config.iat.maxAge == 0 {
|
||||
return nil
|
||||
}
|
||||
maxAge := v.now().Add(-v.config.iat.maxAge).Round(time.Second)
|
||||
if issuedAt.Before(maxAge) {
|
||||
return ErrIatToOld(maxAge, issuedAt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (v *DefaultVerifier) checkNonce(nonce string) error {
|
||||
if v.config.nonce == "" {
|
||||
return nil
|
||||
}
|
||||
if v.config.nonce != nonce {
|
||||
return ErrNonceInvalid(v.config.nonce, nonce)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (v *DefaultVerifier) checkAuthorizationContextClassReference(acr string) error {
|
||||
if v.config.acr != nil {
|
||||
return v.config.acr(acr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (v *DefaultVerifier) checkAuthTime(authTime time.Time) error {
|
||||
if v.config.maxAge == 0 {
|
||||
return nil
|
||||
}
|
||||
if authTime.IsZero() {
|
||||
return ErrAuthTimeNotPresent()
|
||||
}
|
||||
authTime = authTime.Round(time.Second)
|
||||
maxAge := v.now().Add(-v.config.maxAge).Round(time.Second)
|
||||
if authTime.Before(maxAge) {
|
||||
return ErrAuthTimeToOld(maxAge, authTime)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) {
|
||||
return tokenString, nil //TODO: impl
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
||||
if atHash == "" {
|
||||
return nil //TODO: return error
|
||||
}
|
||||
|
||||
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if actual != atHash {
|
||||
return nil //TODO: error
|
||||
}
|
||||
return nil
|
||||
}
|
13
pkg/rp/delegation.go
Normal file
13
pkg/rp/delegation.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
//DelegationTokenRequest is an implementation of TokenExchangeRequest
|
||||
//it exchanges a "urn:ietf:params:oauth:token-type:access_token" with an optional
|
||||
//"urn:ietf:params:oauth:token-type:access_token" actor token for a
|
||||
//"urn:ietf:params:oauth:token-type:access_token" delegation token
|
||||
func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
|
||||
return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
|
||||
}
|
58
pkg/rp/error.go
Normal file
58
pkg/rp/error.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIssuerInvalid = func(expected, actual string) *validationError {
|
||||
return ValidationError("Issuer does not match. Expected: %s, got: %s", expected, actual)
|
||||
}
|
||||
ErrAudienceMissingClientID = func(clientID string) *validationError {
|
||||
return ValidationError("Audience is not valid. Audience must contain client_id (%s)", clientID)
|
||||
}
|
||||
ErrAzpMissing = func() *validationError {
|
||||
return ValidationError("Authorized Party is not set. If Token is valid for multiple audiences, azp must not be empty")
|
||||
}
|
||||
ErrAzpInvalid = func(azp, clientID string) *validationError {
|
||||
return ValidationError("Authorized Party is not valid. azp (%s) must be equal to client_id (%s)", azp, clientID)
|
||||
}
|
||||
ErrExpInvalid = func(exp time.Time) *validationError {
|
||||
return ValidationError("Token has expired %v", exp)
|
||||
}
|
||||
ErrIatInFuture = func(exp, now time.Time) *validationError {
|
||||
return ValidationError("IssuedAt of token is in the future (%v, now with offset: %v)", exp, now)
|
||||
}
|
||||
ErrIatToOld = func(maxAge, iat time.Time) *validationError {
|
||||
return ValidationError("IssuedAt of token must not be older than %v, but was %v (%v to old)", maxAge, iat, maxAge.Sub(iat))
|
||||
}
|
||||
ErrNonceInvalid = func(expected, actual string) *validationError {
|
||||
return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual)
|
||||
}
|
||||
ErrAcrInvalid = func(expected []string, actual string) *validationError {
|
||||
return ValidationError("acr is invalid. Expected one of: %v, got: %s", expected, actual)
|
||||
}
|
||||
|
||||
ErrAuthTimeNotPresent = func() *validationError {
|
||||
return ValidationError("claim `auth_time` of token is missing")
|
||||
}
|
||||
ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError {
|
||||
return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime))
|
||||
}
|
||||
ErrSignatureInvalidPayload = func() *validationError {
|
||||
return ValidationError("Signature does not match Payload")
|
||||
}
|
||||
)
|
||||
|
||||
func ValidationError(message string, args ...interface{}) *validationError {
|
||||
return &validationError{fmt.Sprintf(message, args...)} //TODO: impl
|
||||
}
|
||||
|
||||
type validationError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (v *validationError) Error() string {
|
||||
return v.message
|
||||
}
|
166
pkg/rp/jwks.go
Normal file
166
pkg/rp/jwks.go
Normal file
|
@ -0,0 +1,166 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet {
|
||||
return &remoteKeySet{httpClient: client, jwksURL: jwksURL}
|
||||
}
|
||||
|
||||
type remoteKeySet struct {
|
||||
jwksURL string
|
||||
httpClient *http.Client
|
||||
|
||||
// guard all other fields
|
||||
mu sync.Mutex
|
||||
|
||||
// inflight suppresses parallel execution of updateKeys and allows
|
||||
// multiple goroutines to wait for its result.
|
||||
inflight *inflight
|
||||
|
||||
// A set of cached keys and their expiry.
|
||||
cachedKeys []jose.JSONWebKey
|
||||
}
|
||||
|
||||
// inflight is used to wait on some in-flight request from multiple goroutines.
|
||||
type inflight struct {
|
||||
doneCh chan struct{}
|
||||
|
||||
keys []jose.JSONWebKey
|
||||
err error
|
||||
}
|
||||
|
||||
func newInflight() *inflight {
|
||||
return &inflight{doneCh: make(chan struct{})}
|
||||
}
|
||||
|
||||
// wait returns a channel that multiple goroutines can receive on. Once it returns
|
||||
// a value, the inflight request is done and result() can be inspected.
|
||||
func (i *inflight) wait() <-chan struct{} {
|
||||
return i.doneCh
|
||||
}
|
||||
|
||||
// done can only be called by a single goroutine. It records the result of the
|
||||
// inflight request and signals other goroutines that the result is safe to
|
||||
// inspect.
|
||||
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
|
||||
i.keys = keys
|
||||
i.err = err
|
||||
close(i.doneCh)
|
||||
}
|
||||
|
||||
// result cannot be called until the wait() channel has returned a value.
|
||||
func (i *inflight) result() ([]jose.JSONWebKey, error) {
|
||||
return i.keys, i.err
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
// We don't support JWTs signed with multiple signatures.
|
||||
keyID := ""
|
||||
for _, sig := range jws.Signatures {
|
||||
keyID = sig.Header.KeyID
|
||||
break
|
||||
}
|
||||
|
||||
keys := r.keysFromCache()
|
||||
payload, err, ok := checkKey(keyID, keys, jws)
|
||||
if ok {
|
||||
return payload, err
|
||||
}
|
||||
|
||||
keys, err = r.keysFromRemote(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching keys %v", err)
|
||||
}
|
||||
|
||||
payload, err, ok = checkKey(keyID, keys, jws)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid kid")
|
||||
}
|
||||
return payload, err
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.cachedKeys
|
||||
}
|
||||
|
||||
// keysFromRemote syncs the key set from the remote set, records the values in the
|
||||
// cache, and returns the key set.
|
||||
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
|
||||
// Need to lock to inspect the inflight request field.
|
||||
r.mu.Lock()
|
||||
// If there's not a current inflight request, create one.
|
||||
if r.inflight == nil {
|
||||
r.inflight = newInflight()
|
||||
|
||||
// This goroutine has exclusive ownership over the current inflight
|
||||
// request. It releases the resource by nil'ing the inflight field
|
||||
// once the goroutine is done.
|
||||
go r.updateKeys(ctx)
|
||||
}
|
||||
inflight := r.inflight
|
||||
r.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-inflight.wait():
|
||||
return inflight.result()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) updateKeys(ctx context.Context) {
|
||||
// Sync keys and finish inflight when that's done.
|
||||
keys, err := r.fetchRemoteKeys(ctx)
|
||||
|
||||
r.inflight.done(keys, err)
|
||||
|
||||
// Lock to update the keys and indicate that there is no longer an
|
||||
// inflight request.
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if err == nil {
|
||||
r.cachedKeys = keys
|
||||
}
|
||||
|
||||
// Free inflight so a different request can run.
|
||||
r.inflight = nil
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
|
||||
req, err := http.NewRequest("GET", r.jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: can't create request: %v", err)
|
||||
}
|
||||
|
||||
keySet := new(jose.JSONWebKeySet)
|
||||
if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil {
|
||||
return nil, fmt.Errorf("oidc: failed to get keys: %v", err)
|
||||
}
|
||||
|
||||
return keySet.Keys, nil
|
||||
}
|
||||
|
||||
func checkKey(keyID string, keys []jose.JSONWebKey, jws *jose.JSONWebSignature) ([]byte, error, bool) {
|
||||
for _, key := range keys {
|
||||
if keyID == "" || key.KeyID == keyID {
|
||||
payload, err := jws.Verify(&key)
|
||||
return payload, err, true
|
||||
}
|
||||
}
|
||||
return nil, nil, false
|
||||
}
|
105
pkg/rp/relaying_party.go
Normal file
105
pkg/rp/relaying_party.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
//RelayingParty declares the minimal interface for oidc clients
|
||||
type RelayingParty interface {
|
||||
|
||||
//AuthURL returns the authorization endpoint with a given state
|
||||
AuthURL(state string, opts ...AuthURLOpt) string
|
||||
|
||||
//AuthURLHandler should implement the AuthURL func as http.HandlerFunc
|
||||
//(redirecting to the auth endpoint)
|
||||
AuthURLHandler(state string) http.HandlerFunc
|
||||
|
||||
//CodeExchange implements the OIDC Token Request (oauth2 Authorization Code Grant)
|
||||
//returning an `Access Token` and `ID Token Claims`
|
||||
CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (*oidc.Tokens, error)
|
||||
|
||||
//CodeExchangeHandler extends the CodeExchange func,
|
||||
//calling the provided callback func on success with additional returned `state`
|
||||
CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string)) http.HandlerFunc
|
||||
|
||||
//ClientCredentials implements the oauth2 Client Credentials Grant
|
||||
//requesting an `Access Token` for the client itself, without user context
|
||||
ClientCredentials(ctx context.Context, scopes ...string) (*oauth2.Token, error)
|
||||
|
||||
//Introspects calls the Introspect Endpoint
|
||||
//for validating an (access) token
|
||||
// Introspect(ctx context.Context, token string) (TokenIntrospectResponse, error)
|
||||
|
||||
//Userinfo implements the OIDC Userinfo call
|
||||
//returning the info of the user for the requested scopes of an access token
|
||||
Userinfo()
|
||||
}
|
||||
|
||||
//PasswortGrantRP extends the `RelayingParty` interface with the oauth2 `Password Grant`
|
||||
//
|
||||
//This interface is separated from the standard `RelayingParty` interface as the `password grant`
|
||||
//is part of the oauth2 and therefore OIDC specification, but should only be used when there's no
|
||||
//other possibility, so IMHO never ever. Ever.
|
||||
type PasswortGrantRP interface {
|
||||
RelayingParty
|
||||
|
||||
//PasswordGrant implements the oauth2 `Password Grant`,
|
||||
//requesting an access token with the users `username` and `password`
|
||||
PasswordGrant(context.Context, string, string) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
CallbackURL string
|
||||
Issuer string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type OptionFunc func(RelayingParty)
|
||||
|
||||
type Endpoints struct {
|
||||
oauth2.Endpoint
|
||||
IntrospectURL string
|
||||
UserinfoURL string
|
||||
JKWsURL string
|
||||
}
|
||||
|
||||
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||
return Endpoints{
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: discoveryConfig.AuthorizationEndpoint,
|
||||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
TokenURL: discoveryConfig.TokenEndpoint,
|
||||
},
|
||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||
JKWsURL: discoveryConfig.JwksURI,
|
||||
}
|
||||
}
|
||||
|
||||
type AuthURLOpt func() []oauth2.AuthCodeOption
|
||||
|
||||
//WithCodeChallenge sets the `code_challenge` params in the auth request
|
||||
func WithCodeChallenge(codeChallenge string) AuthURLOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type CodeExchangeOpt func() []oauth2.AuthCodeOption
|
||||
|
||||
//WithCodeVerifier sets the `code_verifier` param in the token request
|
||||
func WithCodeVerifier(codeVerifier string) CodeExchangeOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("code_verifier", codeVerifier)}
|
||||
}
|
||||
}
|
27
pkg/rp/tockenexchange.go
Normal file
27
pkg/rp/tockenexchange.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
//TokenExchangeRP extends the `RelayingParty` interface for the *draft* oauth2 `Token Exchange`
|
||||
type TokenExchangeRP interface {
|
||||
RelayingParty
|
||||
|
||||
//TokenExchange implement the `Token Echange Grant` exchanging some token for an other
|
||||
TokenExchange(context.Context, *tokenexchange.TokenExchangeRequest) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
//DelegationTokenExchangeRP extends the `TokenExchangeRP` interface
|
||||
//for the specific `delegation token` request
|
||||
type DelegationTokenExchangeRP interface {
|
||||
TokenExchangeRP
|
||||
|
||||
//DelegationTokenExchange implement the `Token Exchange Grant`
|
||||
//providing an access token in request for a `delegation` token for a given resource / audience
|
||||
DelegationTokenExchange(context.Context, string, ...tokenexchange.TokenExchangeOption) (*oauth2.Token, error)
|
||||
}
|
15
pkg/rp/verifier.go
Normal file
15
pkg/rp/verifier.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
//Verifier implement the Token Response Validation as defined in OIDC specification
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||
type Verifier interface {
|
||||
|
||||
//Verify checks the access_token and id_token and returns the `id token claims`
|
||||
Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error)
|
||||
}
|
110
pkg/utils/cookie.go
Normal file
110
pkg/utils/cookie.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
)
|
||||
|
||||
type CookieHandler struct {
|
||||
securecookie *securecookie.SecureCookie
|
||||
secureOnly bool
|
||||
sameSite http.SameSite
|
||||
maxAge int
|
||||
domain string
|
||||
}
|
||||
|
||||
func NewCookieHandler(hashKey, encryptKey []byte, opts ...CookieHandlerOpt) *CookieHandler {
|
||||
c := &CookieHandler{
|
||||
securecookie: securecookie.New(hashKey, encryptKey),
|
||||
secureOnly: true,
|
||||
sameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type CookieHandlerOpt func(*CookieHandler)
|
||||
|
||||
func WithUnsecure() CookieHandlerOpt {
|
||||
return func(c *CookieHandler) {
|
||||
c.secureOnly = false
|
||||
}
|
||||
}
|
||||
|
||||
func WithSameSite(sameSite http.SameSite) CookieHandlerOpt {
|
||||
return func(c *CookieHandler) {
|
||||
c.sameSite = sameSite
|
||||
}
|
||||
}
|
||||
|
||||
func WithMaxAge(maxAge int) CookieHandlerOpt {
|
||||
return func(c *CookieHandler) {
|
||||
c.maxAge = maxAge
|
||||
c.securecookie.MaxAge(maxAge)
|
||||
}
|
||||
}
|
||||
|
||||
func WithDomain(domain string) CookieHandlerOpt {
|
||||
return func(c *CookieHandler) {
|
||||
c.domain = domain
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CookieHandler) CheckCookie(r *http.Request, name string) (string, error) {
|
||||
cookie, err := r.Cookie(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var value string
|
||||
if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string, error) {
|
||||
value, err := c.CheckCookie(r, name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if value != r.FormValue(name) {
|
||||
return "", errors.New(name + " does not compare")
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error {
|
||||
encoded, err := c.securecookie.Encode(name, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: encoded,
|
||||
Domain: c.domain,
|
||||
Path: "/",
|
||||
MaxAge: c.maxAge,
|
||||
HttpOnly: true,
|
||||
Secure: c.secureOnly,
|
||||
SameSite: c.sameSite,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Domain: c.domain,
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: c.secureOnly,
|
||||
SameSite: c.sameSite,
|
||||
})
|
||||
}
|
70
pkg/utils/crypto.go
Normal file
70
pkg/utils/crypto.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
func EncryptAES(data string, key string) (string, error) {
|
||||
encrypted, err := EncryptBytesAES([]byte(data), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
func EncryptBytesAES(plainText []byte, key string) ([]byte, error) {
|
||||
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cipherText := make([]byte, aes.BlockSize+len(plainText))
|
||||
iv := cipherText[:aes.BlockSize]
|
||||
if _, err = io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream := cipher.NewCFBEncrypter(block, iv)
|
||||
stream.XORKeyStream(cipherText[aes.BlockSize:], plainText)
|
||||
|
||||
return cipherText, nil
|
||||
}
|
||||
|
||||
func DecryptAES(data string, key string) (string, error) {
|
||||
text, err := base64.URLEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
decrypted, err := DecryptBytesAES(text, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(decrypted), nil
|
||||
}
|
||||
|
||||
func DecryptBytesAES(cipherText []byte, key string) ([]byte, error) {
|
||||
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(cipherText) < aes.BlockSize {
|
||||
err = errors.New("Ciphertext block size is too short!")
|
||||
return nil, err
|
||||
}
|
||||
iv := cipherText[:aes.BlockSize]
|
||||
cipherText = cipherText[aes.BlockSize:]
|
||||
|
||||
stream := cipher.NewCFBDecrypter(block, iv)
|
||||
stream.XORKeyStream(cipherText, cipherText)
|
||||
|
||||
return cipherText, err
|
||||
}
|
30
pkg/utils/hash.go
Normal file
30
pkg/utils/hash.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
|
||||
switch sigAlgorithm {
|
||||
case jose.RS256, jose.ES256, jose.PS256:
|
||||
return sha256.New(), nil
|
||||
case jose.RS384, jose.ES384, jose.PS384:
|
||||
return sha512.New384(), nil
|
||||
case jose.RS512, jose.ES512, jose.PS512:
|
||||
return sha512.New(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
|
||||
}
|
||||
}
|
||||
|
||||
func HashString(hash hash.Hash, s string) string {
|
||||
hash.Write([]byte(s)) // hash documents that Write will never return an error
|
||||
sum := hash.Sum(nil)[:hash.Size()/2]
|
||||
return base64.RawURLEncoding.EncodeToString(sum)
|
||||
}
|
67
pkg/utils/http.go
Normal file
67
pkg/utils/http.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultHTTPClient = &http.Client{
|
||||
Timeout: time.Duration(30 * time.Second),
|
||||
}
|
||||
)
|
||||
|
||||
func FormRequest(endpoint string, request interface{}) (*http.Request, error) {
|
||||
form := make(map[string][]string)
|
||||
encoder := schema.NewEncoder()
|
||||
if err := encoder.Encode(request, form); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body := strings.NewReader(url.Values(form).Encode())
|
||||
req, err := http.NewRequest("POST", endpoint, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func HttpRequest(client *http.Client, req *http.Request, response interface{}) error {
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read response body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("http status not ok: %s %s", resp.Status, body)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal response: %v %s", err, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) {
|
||||
values := make(map[string][]string)
|
||||
err := encoder.Encode(resp, values)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
v := url.Values(values)
|
||||
return v.Encode(), nil
|
||||
}
|
21
pkg/utils/marshal.go
Normal file
21
pkg/utils/marshal.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func MarshalJSON(w http.ResponseWriter, i interface{}) {
|
||||
b, err := json.Marshal(i)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("content-type", "application/json")
|
||||
_, err = w.Write(b)
|
||||
if err != nil {
|
||||
logrus.Error("error writing response")
|
||||
}
|
||||
}
|
10
pkg/utils/strings.go
Normal file
10
pkg/utils/strings.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package utils
|
||||
|
||||
func Contains(list []string, needle string) bool {
|
||||
for _, item := range list {
|
||||
if item == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
48
pkg/utils/strings_test.go
Normal file
48
pkg/utils/strings_test.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
type args struct {
|
||||
list []string
|
||||
needle string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
"empty list false",
|
||||
args{[]string{}, "needle"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"list not containing false",
|
||||
args{[]string{"list"}, "needle"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"list not containing empty needle false",
|
||||
args{[]string{"list", "needle"}, ""},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"list containing true",
|
||||
args{[]string{"list", "needle"}, "needle"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"list containing empty needle true",
|
||||
args{[]string{"list", "needle", ""}, ""},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := Contains(tt.args.list, tt.args.needle); got != tt.want {
|
||||
t.Errorf("Contains() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue