refactor: use struct types for claim related types

BREAKING change.
The following types are changed from interface to struct type:

- AccessTokenClaims
- IDTokenClaims
- IntrospectionResponse
- UserInfo and related types.

The following methods of OPStorage now take a pointer to a struct type,
instead of an interface:

- SetUserinfoFromScopes
- SetUserinfoFromToken
- SetIntrospectionFromToken

The following functions are now generic, so that type-safe extension
of Claims is now possible:

- op.VerifyIDTokenHint
- op.VerifyAccessToken
- rp.VerifyTokens
- rp.VerifyIDToken
This commit is contained in:
Tim Möhlmann 2023-02-17 16:50:28 +02:00
parent 11682a2cc8
commit 85bd99873d
40 changed files with 857 additions and 1291 deletions

View file

@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: ['1.16', '1.17', '1.18', '1.19', '1.20'] go: ['1.18', '1.19', '1.20']
name: Go ${{ matrix.go }} test name: Go ${{ matrix.go }} test
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3

View file

@ -98,9 +98,7 @@ Versions that also build are marked with :warning:.
| Version | Supported | | Version | Supported |
|---------|--------------------| |---------|--------------------|
| <1.16 | :x: | | <1.18 | :x: |
| 1.16 | :warning: |
| 1.17 | :warning: |
| 1.18 | :warning: | | 1.18 | :warning: |
| 1.19 | :white_check_mark: | | 1.19 | :white_check_mark: |
| 1.20 | :white_check_mark: | | 1.20 | :white_check_mark: |

View file

@ -76,7 +76,7 @@ func main() {
params := mux.Vars(r) params := mux.Vars(r)
requestedClaim := params["claim"] requestedClaim := params["claim"]
requestedValue := params["value"] requestedValue := params["value"]
value, ok := resp.GetClaim(requestedClaim).(string) value, ok := resp.Claims[requestedClaim].(string)
if !ok || value == "" || value != requestedValue { if !ok || value == "" || value != requestedValue {
http.Error(w, "claim does not match", http.StatusForbidden) http.Error(w, "claim does not match", http.StatusForbidden)
return return

View file

@ -60,7 +60,7 @@ func main() {
http.Handle("/login", rp.AuthURLHandler(state, provider)) http.Handle("/login", rp.AuthURLHandler(state, provider))
// for demonstration purposes the returned userinfo response is written as JSON object onto response // for demonstration purposes the returned userinfo response is written as JSON object onto response
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) { marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info *oidc.UserInfo) {
data, err := json.Marshal(info) data, err := json.Marshal(info)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)

View file

@ -429,13 +429,13 @@ func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientS
// SetUserinfoFromScopes implements the op.Storage interface // SetUserinfoFromScopes implements the op.Storage interface
// it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check // it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, userID, clientID string, scopes []string) error { func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
return s.setUserinfo(ctx, userinfo, userID, clientID, scopes) return s.setUserinfo(ctx, userinfo, userID, clientID, scopes)
} }
// SetUserinfoFromToken implements the op.Storage interface // SetUserinfoFromToken implements the op.Storage interface
// it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function // it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function
func (s *Storage) SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, tokenID, subject, origin string) error { func (s *Storage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {
token, ok := func() (*Token, bool) { token, ok := func() (*Token, bool) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -463,7 +463,7 @@ func (s *Storage) SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserIn
// SetIntrospectionFromToken implements the op.Storage interface // SetIntrospectionFromToken implements the op.Storage interface
// it will be called for the introspection endpoint, so we read the token and pass the information from that to the private function // it will be called for the introspection endpoint, so we read the token and pass the information from that to the private function
func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspection oidc.IntrospectionResponse, tokenID, subject, clientID string) error { func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
token, ok := func() (*Token, bool) { token, ok := func() (*Token, bool) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -480,14 +480,17 @@ func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspection o
// this will automatically be done by the library if you don't return an error // this will automatically be done by the library if you don't return an error
// you can also return further information about the user / associated token // you can also return further information about the user / associated token
// e.g. the userinfo (equivalent to userinfo endpoint) // e.g. the userinfo (equivalent to userinfo endpoint)
err := s.setUserinfo(ctx, introspection, subject, clientID, token.Scopes)
userInfo := new(oidc.UserInfo)
err := s.setUserinfo(ctx, userInfo, subject, clientID, token.Scopes)
if err != nil { if err != nil {
return err return err
} }
introspection.SetUserInfo(userInfo)
//...and also the requested scopes... //...and also the requested scopes...
introspection.SetScopes(token.Scopes) introspection.Scope = token.Scopes
//...and the client the token was issued to //...and the client the token was issued to
introspection.SetClientID(token.ApplicationID) introspection.ClientID = token.ApplicationID
return nil return nil
} }
} }
@ -608,7 +611,7 @@ func (s *Storage) accessToken(applicationID, refreshTokenID, subject string, aud
} }
// setUserinfo sets the info based on the user, scopes and if necessary the clientID // setUserinfo sets the info based on the user, scopes and if necessary the clientID
func (s *Storage) setUserinfo(ctx context.Context, userInfo oidc.UserInfoSetter, userID, clientID string, scopes []string) (err error) { func (s *Storage) setUserinfo(ctx context.Context, userInfo *oidc.UserInfo, userID, clientID string, scopes []string) (err error) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
user := s.userStore.GetUserByID(userID) user := s.userStore.GetUserByID(userID)
@ -618,17 +621,20 @@ func (s *Storage) setUserinfo(ctx context.Context, userInfo oidc.UserInfoSetter,
for _, scope := range scopes { for _, scope := range scopes {
switch scope { switch scope {
case oidc.ScopeOpenID: case oidc.ScopeOpenID:
userInfo.SetSubject(user.ID) userInfo.Subject = user.ID
case oidc.ScopeEmail: case oidc.ScopeEmail:
userInfo.SetEmail(user.Email, user.EmailVerified) userInfo.Email = user.Email
userInfo.EmailVerified = oidc.Bool(user.EmailVerified)
//user.Email, user.EmailVerified
case oidc.ScopeProfile: case oidc.ScopeProfile:
userInfo.SetPreferredUsername(user.Username) userInfo.PreferredUsername = user.Username
userInfo.SetName(user.FirstName + " " + user.LastName) userInfo.Name = user.FirstName + " " + user.LastName
userInfo.SetFamilyName(user.LastName) userInfo.FamilyName = user.LastName
userInfo.SetGivenName(user.FirstName) userInfo.GivenName = user.FirstName
userInfo.SetLocale(user.PreferredLanguage) userInfo.Locale = oidc.NewLocale(user.PreferredLanguage)
case oidc.ScopePhone: case oidc.ScopePhone:
userInfo.SetPhone(user.Phone, user.PhoneVerified) userInfo.PhoneNumber = user.Phone
userInfo.PhoneNumberVerified = user.PhoneVerified
case CustomScope: case CustomScope:
// you can also have a custom scope and assert public or custom claims based on that // you can also have a custom scope and assert public or custom claims based on that
userInfo.AppendClaims(CustomClaim, customClaim(clientID)) userInfo.AppendClaims(CustomClaim, customClaim(clientID))
@ -698,7 +704,7 @@ func (s *Storage) GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context,
// SetUserinfoFromScopesForTokenExchange implements the op.TokenExchangeStorage interface // SetUserinfoFromScopesForTokenExchange implements the op.TokenExchangeStorage interface
// it will be called for the creation of an id_token - we are using the same private function as for other flows, // it will be called for the creation of an id_token - we are using the same private function as for other flows,
// plus adding token exchange specific claims related to delegation or impersonation // plus adding token exchange specific claims related to delegation or impersonation
func (s *Storage) SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo oidc.UserInfoSetter, request op.TokenExchangeRequest) error { func (s *Storage) SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo *oidc.UserInfo, request op.TokenExchangeRequest) error {
err := s.setUserinfo(ctx, userinfo, request.GetSubject(), request.GetClientID(), request.GetScopes()) err := s.setUserinfo(ctx, userinfo, request.GetSubject(), request.GetClientID(), request.GetScopes())
if err != nil { if err != nil {
return err return err

View file

@ -198,7 +198,7 @@ func (s *multiStorage) AuthorizeClientIDSecret(ctx context.Context, clientID, cl
// SetUserinfoFromScopes implements the op.Storage interface // SetUserinfoFromScopes implements the op.Storage interface
// it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check // it will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, userID, clientID string, scopes []string) error { func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
storage, err := s.storageFromContext(ctx) storage, err := s.storageFromContext(ctx)
if err != nil { if err != nil {
return err return err
@ -208,7 +208,7 @@ func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo oidc.
// SetUserinfoFromToken implements the op.Storage interface // SetUserinfoFromToken implements the op.Storage interface
// it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function // it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function
func (s *multiStorage) SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, tokenID, subject, origin string) error { func (s *multiStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {
storage, err := s.storageFromContext(ctx) storage, err := s.storageFromContext(ctx)
if err != nil { if err != nil {
return err return err
@ -218,7 +218,7 @@ func (s *multiStorage) SetUserinfoFromToken(ctx context.Context, userinfo oidc.U
// SetIntrospectionFromToken implements the op.Storage interface // SetIntrospectionFromToken implements the op.Storage interface
// it will be called for the introspection endpoint, so we read the token and pass the information from that to the private function // it will be called for the introspection endpoint, so we read the token and pass the information from that to the private function
func (s *multiStorage) SetIntrospectionFromToken(ctx context.Context, introspection oidc.IntrospectionResponse, tokenID, subject, clientID string) error { func (s *multiStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
storage, err := s.storageFromContext(ctx) storage, err := s.storageFromContext(ctx)
if err != nil { if err != nil {
return err return err

21
go.mod
View file

@ -1,22 +1,35 @@
module github.com/zitadel/oidc/v2 module github.com/zitadel/oidc/v2
go 1.16 go 1.18
require ( require (
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.2 // indirect
github.com/google/go-github/v31 v31.0.0 github.com/google/go-github/v31 v31.0.0
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/gorilla/schema v1.2.0 github.com/gorilla/schema v1.2.0
github.com/gorilla/securecookie v1.1.1 github.com/gorilla/securecookie v1.1.1
github.com/jeremija/gosubmit v0.2.7 github.com/jeremija/gosubmit v0.2.7
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/rs/cors v1.8.3 github.com/rs/cors v1.8.3
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.1
golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43
golang.org/x/text v0.6.0 golang.org/x/text v0.6.0
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect
gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/square/go-jose.v2 v2.6.0
) )
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/golang/protobuf v1.4.2 // indirect
github.com/google/go-cmp v0.5.2 // indirect
github.com/google/go-querystring v1.0.0 // indirect
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect
golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f // indirect
google.golang.org/appengine v1.6.6 // indirect
google.golang.org/protobuf v1.25.0 // indirect
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

9
go.sum
View file

@ -146,7 +146,6 @@ github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
@ -190,7 +189,6 @@ golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -217,7 +215,6 @@ golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@ -237,7 +234,6 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -266,19 +262,15 @@ golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k=
golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@ -325,7 +317,6 @@ golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc
golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View file

@ -176,8 +176,8 @@ func SignedJWTProfileAssertion(clientID string, audience []string, expiration ti
Issuer: clientID, Issuer: clientID,
Subject: clientID, Subject: clientID,
Audience: audience, Audience: audience,
ExpiresAt: oidc.Time(exp), ExpiresAt: oidc.FromTime(exp),
IssuedAt: oidc.Time(iat), IssuedAt: oidc.FromTime(iat),
}, signer) }, signer)
} }

View file

@ -238,19 +238,19 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
} }
var email string var email string
redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) { redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info *oidc.UserInfo) {
require.NotNil(t, tokens, "tokens") require.NotNil(t, tokens, "tokens")
require.NotNil(t, info, "info") require.NotNil(t, info, "info")
t.Log("access token", tokens.AccessToken) t.Log("access token", tokens.AccessToken)
t.Log("refresh token", tokens.RefreshToken) t.Log("refresh token", tokens.RefreshToken)
t.Log("id token", tokens.IDToken) t.Log("id token", tokens.IDToken)
t.Log("email", info.GetEmail()) t.Log("email", info.Email)
accessToken = tokens.AccessToken accessToken = tokens.AccessToken
refreshToken = tokens.RefreshToken refreshToken = tokens.RefreshToken
idToken = tokens.IDToken idToken = tokens.IDToken
email = info.GetEmail() email = info.Email
http.Redirect(w, r, targetURL, http.StatusFound) http.Redirect(w, r, targetURL, 302)
} }
rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get) rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get)
@ -261,7 +261,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
} }
}() }()
require.Less(t, capturedW.Code, 400, "token exchange response code") require.Less(t, capturedW.Code, 400, "token exchange response code")
require.Less(t, capturedW.Code, 400, "token exchange response code")
//nolint:bodyclose //nolint:bodyclose
resp = capturedW.Result() resp = capturedW.Result()

View file

@ -394,7 +394,7 @@ func CodeExchange(ctx context.Context, code string, rp RelyingParty, opts ...Cod
return nil, errors.New("id_token missing") return nil, errors.New("id_token missing")
} }
idToken, err := VerifyTokens(ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) idToken, err := VerifyTokens[*oidc.IDTokenClaims](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -445,14 +445,14 @@ func CodeExchangeHandler(callback CodeExchangeCallback, rp RelyingParty) http.Ha
} }
} }
type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info oidc.UserInfo) type CodeExchangeUserinfoCallback func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, provider RelyingParty, info *oidc.UserInfo)
// UserinfoCallback wraps the callback function of the CodeExchangeHandler // UserinfoCallback wraps the callback function of the CodeExchangeHandler
// and calls the userinfo endpoint with the access token // and calls the userinfo endpoint with the access token
// on success it will pass the userinfo into its callback function as well // on success it will pass the userinfo into its callback function as well
func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback { func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp RelyingParty) {
info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, rp)
if err != nil { if err != nil {
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
return return
@ -462,17 +462,17 @@ func UserinfoCallback(f CodeExchangeUserinfoCallback) CodeExchangeCallback {
} }
// Userinfo will call the OIDC Userinfo Endpoint with the provided token // Userinfo will call the OIDC Userinfo Endpoint with the provided token
func Userinfo(token, tokenType, subject string, rp RelyingParty) (oidc.UserInfo, error) { func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) {
req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil) req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("authorization", tokenType+" "+token) req.Header.Set("authorization", tokenType+" "+token)
userinfo := oidc.NewUserInfo() userinfo := new(oidc.UserInfo)
if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil { if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
return nil, err return nil, err
} }
if userinfo.GetSubject() != subject { if userinfo.Subject != subject {
return nil, ErrUserInfoSubNotMatching return nil, ErrUserInfoSubNotMatching
} }
return userinfo, nil return userinfo, nil

View file

@ -21,69 +21,71 @@ type IDTokenVerifier interface {
// VerifyTokens implement the Token Response Validation as defined in OIDC specification // VerifyTokens implement the Token Response Validation as defined in OIDC specification
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
func VerifyTokens(ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (oidc.IDTokenClaims, error) { func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idTokenString string, v IDTokenVerifier) (claims C, err error) {
idToken, err := VerifyIDToken(ctx, idTokenString, v) var nilClaims C
idToken, err := VerifyIDToken[C](ctx, idTokenString, v)
if err != nil { if err != nil {
return nil, err return nilClaims, err
} }
if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil { if err := VerifyAccessToken(accessToken, idToken.GetAccessTokenHash(), idToken.GetSignatureAlgorithm()); err != nil {
return nil, err return nilClaims, err
} }
return idToken, nil return idToken, nil
} }
// VerifyIDToken validates the id token according to // VerifyIDToken validates the id token according to
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDToken(ctx context.Context, token string, v IDTokenVerifier) (oidc.IDTokenClaims, error) { func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) {
claims := oidc.EmptyIDTokenClaims() var nilClaims C
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
if err != nil { if err != nil {
return nil, err return nilClaims, err
} }
payload, err := oidc.ParseToken(decrypted, claims) payload, err := oidc.ParseToken(decrypted, &claims)
if err != nil { if err != nil {
return nil, err return nilClaims, err
} }
if err := oidc.CheckSubject(claims); err != nil { if err := oidc.CheckSubject(claims); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil { if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckAudience(claims, v.ClientID()); err != nil { if err = oidc.CheckAudience(claims, v.ClientID()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil { if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil { if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
return nil, err return nilClaims, err
} }
return claims, nil return claims, nil
} }

View file

@ -112,7 +112,7 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option {
} }
} }
func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.IntrospectionResponse, error) { func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) {
authFn, err := rp.AuthFn() authFn, err := rp.AuthFn()
if err != nil { if err != nil {
return nil, err return nil, err
@ -121,7 +121,7 @@ func Introspect(ctx context.Context, rp ResourceServer, token string) (oidc.Intr
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp := oidc.NewIntrospectionResponse() resp := new(oidc.IntrospectionResponse)
if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil { if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil {
return nil, err return nil, err
} }

84
pkg/oidc/claims.go Normal file
View file

@ -0,0 +1,84 @@
package oidc
// Some expirimental stuff, no sure yet if it can be used
// or deleted before final PR.
/*
// CustomClaims allows the joining of any type
// with Registered fields and a map of custom Claims.
type CustomClaims[R any] struct {
Registered R
Claims map[string]any
}
func (c *CustomClaims[_]) AppendClaims(k string, v any) {
if c.Claims == nil {
c.Claims = make(map[string]any)
}
c.Claims[k] = v
}
// MarshalJSON implements the json.Marshaller interface.
// The Registered and Claims map are merged into a
// single JSON object. Registered fields overwrite
// custom Claims.
func (c *CustomClaims[_]) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims(&c.Registered, c.Claims)
}
// UnmashalJSON implements the json.Unmarshaller interface.
// Matching values from the JSON document are set in Registered.
// The map Claims will contain all claims from the JSON document.
func (c *CustomClaims[_]) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, &c.Registered, &c.Claims)
}
// CustomTokenClaims allows the joining of a Claims
// type with registered fields and a map of custom Claims.
// CustomTokenClaims implements the Claims interface,
// and any type that embeds TokenClaims can be used as
// type argument.
type CustomTokenClaims[TC Claims] struct {
Registered TC
Claims map[string]any
}
func (c *CustomTokenClaims[_]) AppendClaims(k string, v any) {
if c.Claims == nil {
c.Claims = make(map[string]any)
}
c.Claims[k] = v
}
// MarshalJSON implements the json.Marshaller interface.
// The Registered and Claims map are merged into a
// single JSON object. Registered fields overwrite
// custom Claims.
func (c *CustomTokenClaims[_]) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims(&c.Registered, c.Claims)
}
// UnmashalJSON implements the json.Unmarshaller interface.
// Matching values from the JSON document are set in Registered.
// The map Claims will contain all claims from the JSON document.
func (c *CustomTokenClaims[_]) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, &c.Registered, &c.Claims)
}
func (c *CustomTokenClaims[_]) GetIssuer() string { return c.Registered.GetIssuer() }
func (c *CustomTokenClaims[_]) GetSubject() string { return c.Registered.GetSubject() }
func (c *CustomTokenClaims[_]) GetAudience() []string { return c.Registered.GetAudience() }
func (c *CustomTokenClaims[_]) GetExpiration() time.Time { return c.Registered.GetExpiration() }
func (c *CustomTokenClaims[_]) GetIssuedAt() time.Time { return c.Registered.GetIssuedAt() }
func (c *CustomTokenClaims[_]) GetNonce() string { return c.Registered.GetNonce() }
func (c *CustomTokenClaims[_]) GetAuthTime() time.Time { return c.Registered.GetAuthTime() }
func (c *CustomTokenClaims[_]) GetAuthorizedParty() string {
return c.Registered.GetAuthorizedParty()
}
func (c *CustomTokenClaims[_]) GetAuthenticationContextClassReference() string {
return c.Registered.GetAuthenticationContextClassReference()
}
func (c *CustomTokenClaims[_]) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
c.Registered.SetSignatureAlgorithm(algorithm)
}
*/

View file

@ -1,13 +1,5 @@
package oidc package oidc
import (
"encoding/json"
"fmt"
"time"
"golang.org/x/text/language"
)
type IntrospectionRequest struct { type IntrospectionRequest struct {
Token string `schema:"token"` Token string `schema:"token"`
} }
@ -17,36 +9,7 @@ type ClientAssertionParams struct {
ClientAssertionType string `schema:"client_assertion_type"` ClientAssertionType string `schema:"client_assertion_type"`
} }
type IntrospectionResponse interface { type IntrospectionResponse struct {
UserInfoSetter
IsActive() bool
SetActive(bool)
SetScopes(scopes []string)
SetClientID(id string)
SetTokenType(tokenType string)
SetExpiration(exp time.Time)
SetIssuedAt(iat time.Time)
SetNotBefore(nbf time.Time)
SetAudience(audience []string)
SetIssuer(issuer string)
SetJWTID(id string)
GetScope() []string
GetClientID() string
GetTokenType() string
GetExpiration() time.Time
GetIssuedAt() time.Time
GetNotBefore() time.Time
GetSubject() string
GetAudience() []string
GetIssuer() string
GetJWTID() string
}
func NewIntrospectionResponse() IntrospectionResponse {
return &introspectionResponse{}
}
type introspectionResponse struct {
Active bool `json:"active"` Active bool `json:"active"`
Scope SpaceDelimitedArray `json:"scope,omitempty"` Scope SpaceDelimitedArray `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"` ClientID string `json:"client_id,omitempty"`
@ -58,323 +21,47 @@ type introspectionResponse struct {
Audience Audience `json:"aud,omitempty"` Audience Audience `json:"aud,omitempty"`
Issuer string `json:"iss,omitempty"` Issuer string `json:"iss,omitempty"`
JWTID string `json:"jti,omitempty"` JWTID string `json:"jti,omitempty"`
userInfoProfile Username string `json:"username,omitempty"`
userInfoEmail UserInfoProfile
userInfoPhone UserInfoEmail
UserInfoPhone
Address UserInfoAddress `json:"address,omitempty"` Address UserInfoAddress `json:"address,omitempty"`
claims map[string]interface{} Claims map[string]any `json:"-"`
} }
func (i *introspectionResponse) IsActive() bool { // GetUserInfo copies all user related fields into a new UserInfo.
return i.Active func (i *IntrospectionResponse) GetUserInfo() *UserInfo {
} return &UserInfo{
Address: i.Address,
func (i *introspectionResponse) GetSubject() string { Subject: i.Subject,
return i.Subject UserInfoProfile: i.UserInfoProfile,
} UserInfoEmail: i.UserInfoEmail,
UserInfoPhone: i.UserInfoPhone,
func (i *introspectionResponse) GetName() string {
return i.Name
}
func (i *introspectionResponse) GetGivenName() string {
return i.GivenName
}
func (i *introspectionResponse) GetFamilyName() string {
return i.FamilyName
}
func (i *introspectionResponse) GetMiddleName() string {
return i.MiddleName
}
func (i *introspectionResponse) GetNickname() string {
return i.Nickname
}
func (i *introspectionResponse) GetProfile() string {
return i.Profile
}
func (i *introspectionResponse) GetPicture() string {
return i.Picture
}
func (i *introspectionResponse) GetWebsite() string {
return i.Website
}
func (i *introspectionResponse) GetGender() Gender {
return i.Gender
}
func (i *introspectionResponse) GetBirthdate() string {
return i.Birthdate
}
func (i *introspectionResponse) GetZoneinfo() string {
return i.Zoneinfo
}
func (i *introspectionResponse) GetLocale() language.Tag {
return i.Locale
}
func (i *introspectionResponse) GetPreferredUsername() string {
return i.PreferredUsername
}
func (i *introspectionResponse) GetEmail() string {
return i.Email
}
func (i *introspectionResponse) IsEmailVerified() bool {
return bool(i.EmailVerified)
}
func (i *introspectionResponse) GetPhoneNumber() string {
return i.PhoneNumber
}
func (i *introspectionResponse) IsPhoneNumberVerified() bool {
return i.PhoneNumberVerified
}
func (i *introspectionResponse) GetAddress() UserInfoAddress {
return i.Address
}
func (i *introspectionResponse) GetClaim(key string) interface{} {
return i.claims[key]
}
func (i *introspectionResponse) GetClaims() map[string]interface{} {
return i.claims
}
func (i *introspectionResponse) GetScope() []string {
return []string(i.Scope)
}
func (i *introspectionResponse) GetClientID() string {
return i.ClientID
}
func (i *introspectionResponse) GetTokenType() string {
return i.TokenType
}
func (i *introspectionResponse) GetExpiration() time.Time {
return time.Time(i.Expiration)
}
func (i *introspectionResponse) GetIssuedAt() time.Time {
return time.Time(i.IssuedAt)
}
func (i *introspectionResponse) GetNotBefore() time.Time {
return time.Time(i.NotBefore)
}
func (i *introspectionResponse) GetAudience() []string {
return []string(i.Audience)
}
func (i *introspectionResponse) GetIssuer() string {
return i.Issuer
}
func (i *introspectionResponse) GetJWTID() string {
return i.JWTID
}
func (i *introspectionResponse) SetActive(active bool) {
i.Active = active
}
func (i *introspectionResponse) SetScopes(scope []string) {
i.Scope = scope
}
func (i *introspectionResponse) SetClientID(id string) {
i.ClientID = id
}
func (i *introspectionResponse) SetTokenType(tokenType string) {
i.TokenType = tokenType
}
func (i *introspectionResponse) SetExpiration(exp time.Time) {
i.Expiration = Time(exp)
}
func (i *introspectionResponse) SetIssuedAt(iat time.Time) {
i.IssuedAt = Time(iat)
}
func (i *introspectionResponse) SetNotBefore(nbf time.Time) {
i.NotBefore = Time(nbf)
}
func (i *introspectionResponse) SetAudience(audience []string) {
i.Audience = audience
}
func (i *introspectionResponse) SetIssuer(issuer string) {
i.Issuer = issuer
}
func (i *introspectionResponse) SetJWTID(id string) {
i.JWTID = id
}
func (i *introspectionResponse) SetSubject(sub string) {
i.Subject = sub
}
func (i *introspectionResponse) SetName(name string) {
i.Name = name
}
func (i *introspectionResponse) SetGivenName(name string) {
i.GivenName = name
}
func (i *introspectionResponse) SetFamilyName(name string) {
i.FamilyName = name
}
func (i *introspectionResponse) SetMiddleName(name string) {
i.MiddleName = name
}
func (i *introspectionResponse) SetNickname(name string) {
i.Nickname = name
}
func (i *introspectionResponse) SetUpdatedAt(date time.Time) {
i.UpdatedAt = Time(date)
}
func (i *introspectionResponse) SetProfile(profile string) {
i.Profile = profile
}
func (i *introspectionResponse) SetPicture(picture string) {
i.Picture = picture
}
func (i *introspectionResponse) SetWebsite(website string) {
i.Website = website
}
func (i *introspectionResponse) SetGender(gender Gender) {
i.Gender = gender
}
func (i *introspectionResponse) SetBirthdate(birthdate string) {
i.Birthdate = birthdate
}
func (i *introspectionResponse) SetZoneinfo(zoneInfo string) {
i.Zoneinfo = zoneInfo
}
func (i *introspectionResponse) SetLocale(locale language.Tag) {
i.Locale = locale
}
func (i *introspectionResponse) SetPreferredUsername(name string) {
i.PreferredUsername = name
}
func (i *introspectionResponse) SetEmail(email string, verified bool) {
i.Email = email
i.EmailVerified = boolString(verified)
}
func (i *introspectionResponse) SetPhone(phone string, verified bool) {
i.PhoneNumber = phone
i.PhoneNumberVerified = verified
}
func (i *introspectionResponse) SetAddress(address UserInfoAddress) {
i.Address = address
}
func (i *introspectionResponse) AppendClaims(key string, value interface{}) {
if i.claims == nil {
i.claims = make(map[string]interface{})
} }
i.claims[key] = value
} }
func (i *introspectionResponse) MarshalJSON() ([]byte, error) { // SetUserInfo copies all relevant fields from UserInfo
type Alias introspectionResponse // into the IntroSpectionResponse.
a := &struct { func (i *IntrospectionResponse) SetUserInfo(u *UserInfo) {
*Alias i.Subject = u.Subject
Expiration int64 `json:"exp,omitempty"` i.Username = i.PreferredUsername
IssuedAt int64 `json:"iat,omitempty"` i.Address = u.Address
NotBefore int64 `json:"nbf,omitempty"` i.UserInfoProfile = u.UserInfoProfile
Locale interface{} `json:"locale,omitempty"` i.UserInfoEmail = u.UserInfoEmail
UpdatedAt int64 `json:"updated_at,omitempty"` i.UserInfoPhone = u.UserInfoPhone
Username string `json:"username,omitempty"`
}{
Alias: (*Alias)(i),
}
if !i.Locale.IsRoot() {
a.Locale = i.Locale
}
if !time.Time(i.UpdatedAt).IsZero() {
a.UpdatedAt = time.Time(i.UpdatedAt).Unix()
}
if !time.Time(i.Expiration).IsZero() {
a.Expiration = time.Time(i.Expiration).Unix()
}
if !time.Time(i.IssuedAt).IsZero() {
a.IssuedAt = time.Time(i.IssuedAt).Unix()
}
if !time.Time(i.NotBefore).IsZero() {
a.NotBefore = time.Time(i.NotBefore).Unix()
}
a.Username = i.PreferredUsername
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if len(i.claims) == 0 {
return b, nil
}
err = json.Unmarshal(b, &i.claims)
if err != nil {
return nil, fmt.Errorf("jws: invalid map of custom claims %v", i.claims)
}
return json.Marshal(i.claims)
} }
func (i *introspectionResponse) UnmarshalJSON(data []byte) error { // introspectionResponseAlias prevents loops on the JSON methods
type Alias introspectionResponse type introspectionResponseAlias IntrospectionResponse
a := &struct {
*Alias
UpdatedAt int64 `json:"update_at,omitempty"`
}{
Alias: (*Alias)(i),
}
if err := json.Unmarshal(data, &a); err != nil {
return err
}
i.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC()) func (i *IntrospectionResponse) MarshalJSON() ([]byte, error) {
//TODO: set the username directly where the IntrospectionResponse is created
// a.Username = i.PreferredUsername
if err := json.Unmarshal(data, &i.claims); err != nil { return mergeAndMarshalClaims((*introspectionResponseAlias)(i), i.Claims)
return err }
}
func (i *IntrospectionResponse) UnmarshalJSON(data []byte) error {
return nil return unmarshalJSONMulti(data, (*introspectionResponseAlias)(i), &i.Claims)
} }

View file

@ -9,7 +9,6 @@ import (
"path" "path"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/text/language" "golang.org/x/text/language"
@ -22,8 +21,10 @@ const dataDir = "regression_data"
// dataDir/<type_name>.json // dataDir/<type_name>.json
func jsonFilename(obj interface{}) string { func jsonFilename(obj interface{}) string {
name := fmt.Sprintf("%T.json", obj) name := fmt.Sprintf("%T.json", obj)
name, _ = strings.CutPrefix(name, "*") return path.Join(
return path.Join(dataDir, name) dataDir,
strings.TrimPrefix(name, "*"),
)
} }
func encodeJSON(t *testing.T, w io.Writer, obj interface{}) { func encodeJSON(t *testing.T, w io.Writer, obj interface{}) {
@ -33,70 +34,86 @@ func encodeJSON(t *testing.T, w io.Writer, obj interface{}) {
} }
var ( var (
accessTokenRegressData = &accessTokenClaims{ accessTokenRegressData = &AccessTokenClaims{
RegisteredAccessTokenClaims: RegisteredAccessTokenClaims{
TokenClaims: TokenClaims{
Issuer: "zitadel", Issuer: "zitadel",
Subject: "hello@me.com", Subject: "hello@me.com",
Audience: Audience{"foo", "bar"}, Audience: Audience{"foo", "bar"},
Expiration: Time(time.Unix(12345, 0)), Expiration: 12345,
IssuedAt: Time(time.Unix(12000, 0)), IssuedAt: 12000,
NotBefore: Time(time.Unix(12000, 0)),
JWTID: "900", JWTID: "900",
AuthorizedParty: "just@me.com", AuthorizedParty: "just@me.com",
Nonce: "6969", Nonce: "6969",
AuthTime: Time(time.Unix(12000, 0)), AuthTime: 12000,
CodeHash: "hashhash",
AuthenticationContextClassReference: "something", AuthenticationContextClassReference: "something",
AuthenticationMethodsReferences: []string{"some", "methods"}, AuthenticationMethodsReferences: []string{"some", "methods"},
ClientID: "777",
SignatureAlg: jose.ES256,
},
NotBefore: 12000,
CodeHash: "hashhash",
SessionID: "666", SessionID: "666",
Scopes: []string{"email", "phone"}, Scopes: []string{"email", "phone"},
ClientID: "777",
AccessTokenUseNumber: 22, AccessTokenUseNumber: 22,
claims: map[string]interface{}{ },
Claims: map[string]interface{}{
"foo": "bar", "foo": "bar",
}, },
signatureAlg: jose.ES256,
} }
idTokenRegressData = &idTokenClaims{ idTokenRegressData = &IDTokenClaims{
RegisteredIDTokenClaims: RegisteredIDTokenClaims{
TokenClaims: TokenClaims{
Issuer: "zitadel", Issuer: "zitadel",
Subject: "hello@me.com",
Audience: Audience{"foo", "bar"}, Audience: Audience{"foo", "bar"},
Expiration: Time(time.Unix(12345, 0)), Expiration: 12345,
NotBefore: Time(time.Unix(12000, 0)), IssuedAt: 12000,
IssuedAt: Time(time.Unix(12000, 0)),
JWTID: "900", JWTID: "900",
AuthorizedParty: "just@me.com", AuthorizedParty: "just@me.com",
Nonce: "6969", Nonce: "6969",
AuthTime: Time(time.Unix(12000, 0)), AuthTime: 12000,
AccessTokenHash: "acthashhash",
CodeHash: "hashhash",
AuthenticationContextClassReference: "something", AuthenticationContextClassReference: "something",
AuthenticationMethodsReferences: []string{"some", "methods"}, AuthenticationMethodsReferences: []string{"some", "methods"},
ClientID: "777", ClientID: "777",
UserInfo: userInfoRegressData, SignatureAlg: jose.ES256,
signatureAlg: jose.ES256, },
NotBefore: 12000,
AccessTokenHash: "acthashhash",
CodeHash: "hashhash",
UserInfoProfile: userInfoRegressData.UserInfoProfile,
UserInfoEmail: userInfoRegressData.UserInfoEmail,
UserInfoPhone: userInfoRegressData.UserInfoPhone,
Address: userInfoRegressData.Address,
},
Claims: map[string]interface{}{
"foo": "bar",
},
} }
introspectionResponseRegressData = &introspectionResponse{ introspectionResponseRegressData = &IntrospectionResponse{
Active: true, Active: true,
Scope: SpaceDelimitedArray{"email", "phone"}, Scope: SpaceDelimitedArray{"email", "phone"},
ClientID: "777", ClientID: "777",
TokenType: "idtoken", TokenType: "idtoken",
Expiration: Time(time.Unix(12345, 0)), Expiration: 12345,
IssuedAt: Time(time.Unix(12000, 0)), IssuedAt: 12000,
NotBefore: Time(time.Unix(12000, 0)), NotBefore: 12000,
Subject: "hello@me.com", Subject: "hello@me.com",
Audience: Audience{"foo", "bar"}, Audience: Audience{"foo", "bar"},
Issuer: "zitadel", Issuer: "zitadel",
JWTID: "900", JWTID: "900",
userInfoProfile: userInfoRegressData.userInfoProfile, Username: "muhlemmer",
userInfoEmail: userInfoRegressData.userInfoEmail, UserInfoProfile: userInfoRegressData.UserInfoProfile,
userInfoPhone: userInfoRegressData.userInfoPhone, UserInfoEmail: userInfoRegressData.UserInfoEmail,
UserInfoPhone: userInfoRegressData.UserInfoPhone,
Address: userInfoRegressData.Address, Address: userInfoRegressData.Address,
claims: map[string]interface{}{ Claims: map[string]interface{}{
"foo": "bar", "foo": "bar",
}, },
} }
userInfoRegressData = &userinfo{ userInfoRegressData = &UserInfo{
Subject: "hello@me.com", Subject: "hello@me.com",
userInfoProfile: userInfoProfile{ UserInfoProfile: UserInfoProfile{
Name: "Tim Möhlmann", Name: "Tim Möhlmann",
GivenName: "Tim", GivenName: "Tim",
FamilyName: "Möhlmann", FamilyName: "Möhlmann",
@ -108,19 +125,19 @@ var (
Gender: "male", Gender: "male",
Birthdate: "1st of April", Birthdate: "1st of April",
Zoneinfo: "Europe/Amsterdam", Zoneinfo: "Europe/Amsterdam",
Locale: language.Dutch, Locale: NewLocale(language.Dutch),
UpdatedAt: Time(time.Unix(1, 1)), UpdatedAt: 1,
PreferredUsername: "muhlemmer", PreferredUsername: "muhlemmer",
}, },
userInfoEmail: userInfoEmail{ UserInfoEmail: UserInfoEmail{
Email: "tim@zitadel.com", Email: "tim@zitadel.com",
EmailVerified: true, EmailVerified: true,
}, },
userInfoPhone: userInfoPhone{ UserInfoPhone: UserInfoPhone{
PhoneNumber: "+1234567890", PhoneNumber: "+1234567890",
PhoneNumberVerified: true, PhoneNumberVerified: true,
}, },
Address: &userInfoAddress{ Address: UserInfoAddress{
Formatted: "Sesame street 666\n666-666, Smallvile\nMoon", Formatted: "Sesame street 666\n666-666, Smallvile\nMoon",
StreetAddress: "Sesame street 666", StreetAddress: "Sesame street 666",
Locality: "Smallvile", Locality: "Smallvile",
@ -128,7 +145,7 @@ var (
PostalCode: "666-666", PostalCode: "666-666",
Country: "Moon", Country: "Moon",
}, },
claims: map[string]interface{}{ Claims: map[string]interface{}{
"foo": "bar", "foo": "bar",
}, },
} }
@ -138,8 +155,8 @@ var (
Issuer: "zitadel", Issuer: "zitadel",
Subject: "hello@me.com", Subject: "hello@me.com",
Audience: Audience{"foo", "bar"}, Audience: Audience{"foo", "bar"},
Expiration: Time(time.Unix(12345, 0)), Expiration: 12345,
IssuedAt: Time(time.Unix(12000, 0)), IssuedAt: 12000,
customClaims: map[string]interface{}{ customClaims: map[string]interface{}{
"foo": "bar", "foo": "bar",
}, },

View file

@ -10,7 +10,6 @@ import (
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/zitadel/oidc/v2/pkg/crypto" "github.com/zitadel/oidc/v2/pkg/crypto"
"github.com/zitadel/oidc/v2/pkg/http"
) )
const ( const (
@ -22,378 +21,154 @@ const (
type Tokens struct { type Tokens struct {
*oauth2.Token *oauth2.Token
IDTokenClaims IDTokenClaims IDTokenClaims *IDTokenClaims
IDToken string IDToken string
} }
type AccessTokenClaims interface { // TokenClaims contains the base Claims used all tokens.
Claims // It implements OpenID Connect Core 1.0, section 2.
GetSubject() string // https://openid.net/specs/openid-connect-core-1_0.html#IDToken
GetTokenID() string // And RFC 9068: JSON Web Token (JWT) Profile for OAuth 2.0 Access Tokens,
SetPrivateClaims(map[string]interface{}) // section 2.2. https://datatracker.ietf.org/doc/html/rfc9068#name-data-structure
GetClaims() map[string]interface{} //
} // TokenClaims implements the Claims interface,
// and can be used to extend larger claim types by embedding.
type IDTokenClaims interface { type TokenClaims struct {
Claims
GetNotBefore() time.Time
GetJWTID() string
GetAccessTokenHash() string
GetCodeHash() string
GetAuthenticationMethodsReferences() []string
GetClientID() string
GetSignatureAlgorithm() jose.SignatureAlgorithm
SetAccessTokenHash(hash string)
SetUserinfo(userinfo UserInfo)
SetCodeHash(hash string)
UserInfo
}
func EmptyAccessTokenClaims() AccessTokenClaims {
return new(accessTokenClaims)
}
func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id, clientID string, skew time.Duration) AccessTokenClaims {
now := time.Now().UTC().Add(-skew)
if len(audience) == 0 {
audience = append(audience, clientID)
}
return &accessTokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: Time(expiration),
IssuedAt: Time(now),
NotBefore: Time(now),
JWTID: id,
}
}
type accessTokenClaims struct {
Issuer string `json:"iss,omitempty"` Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"` Subject string `json:"sub,omitempty"`
Audience Audience `json:"aud,omitempty"` Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"` Expiration Time `json:"exp,omitempty"`
IssuedAt Time `json:"iat,omitempty"` IssuedAt Time `json:"iat,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime Time `json:"auth_time,omitempty"` AuthTime Time `json:"auth_time,omitempty"`
CodeHash string `json:"c_hash,omitempty"` Nonce string `json:"nonce,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"` AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"` AuthenticationMethodsReferences []string `json:"amr,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
ClientID string `json:"client_id,omitempty"`
JWTID string `json:"jti,omitempty"`
// Additional information set by this framework
SignatureAlg jose.SignatureAlgorithm `json:"-"`
}
func (c *TokenClaims) GetIssuer() string { return c.Issuer }
func (c *TokenClaims) GetSubject() string { return c.Subject }
func (c *TokenClaims) GetAudience() []string { return c.Audience }
func (c *TokenClaims) GetExpiration() time.Time { return c.Expiration.AsTime() }
func (c *TokenClaims) GetIssuedAt() time.Time { return c.IssuedAt.AsTime() }
func (c *TokenClaims) GetNonce() string { return c.Nonce }
func (c *TokenClaims) GetAuthTime() time.Time { return c.AuthTime.AsTime() }
func (c *TokenClaims) GetAuthorizedParty() string { return c.AuthorizedParty }
func (c *TokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm { return c.SignatureAlg }
func (c *TokenClaims) GetAuthenticationContextClassReference() string {
return c.AuthenticationContextClassReference
}
func (c *TokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
c.SignatureAlg = algorithm
}
type RegisteredAccessTokenClaims struct {
TokenClaims
NotBefore Time `json:"nbf,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
SessionID string `json:"sid,omitempty"` SessionID string `json:"sid,omitempty"`
Scopes []string `json:"scope,omitempty"` Scopes []string `json:"scope,omitempty"`
ClientID string `json:"client_id,omitempty"`
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
claims map[string]interface{} `json:"-"`
signatureAlg jose.SignatureAlgorithm `json:"-"`
} }
// GetIssuer implements the Claims interface type AccessTokenClaims struct {
func (a *accessTokenClaims) GetIssuer() string { RegisteredAccessTokenClaims
return a.Issuer
Claims map[string]any `json:"-"`
} }
// GetAudience implements the Claims interface func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, id, clientID string, skew time.Duration) *AccessTokenClaims {
func (a *accessTokenClaims) GetAudience() []string { now := time.Now().UTC().Add(-skew)
return a.Audience if len(audience) == 0 {
} audience = append(audience, clientID)
// GetExpiration implements the Claims interface
func (a *accessTokenClaims) GetExpiration() time.Time {
return time.Time(a.Expiration)
}
// GetIssuedAt implements the Claims interface
func (a *accessTokenClaims) GetIssuedAt() time.Time {
return time.Time(a.IssuedAt)
}
// GetNonce implements the Claims interface
func (a *accessTokenClaims) GetNonce() string {
return a.Nonce
}
// GetAuthenticationContextClassReference implements the Claims interface
func (a *accessTokenClaims) GetAuthenticationContextClassReference() string {
return a.AuthenticationContextClassReference
}
// GetAuthTime implements the Claims interface
func (a *accessTokenClaims) GetAuthTime() time.Time {
return time.Time(a.AuthTime)
}
// GetAuthorizedParty implements the Claims interface
func (a *accessTokenClaims) GetAuthorizedParty() string {
return a.AuthorizedParty
}
// SetSignatureAlgorithm implements the Claims interface
func (a *accessTokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
a.signatureAlg = algorithm
}
// GetSubject implements the AccessTokenClaims interface
func (a *accessTokenClaims) GetSubject() string {
return a.Subject
}
// GetTokenID implements the AccessTokenClaims interface
func (a *accessTokenClaims) GetTokenID() string {
return a.JWTID
}
// SetPrivateClaims implements the AccessTokenClaims interface
func (a *accessTokenClaims) SetPrivateClaims(claims map[string]interface{}) {
a.claims = claims
}
// GetClaims implements the AccessTokenClaims interface
func (a *accessTokenClaims) GetClaims() map[string]interface{} {
return a.claims
}
func (a *accessTokenClaims) MarshalJSON() ([]byte, error) {
type Alias accessTokenClaims
s := &struct {
*Alias
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
}{
Alias: (*Alias)(a),
} }
if !time.Time(a.Expiration).IsZero() { return &AccessTokenClaims{
s.Expiration = time.Time(a.Expiration).Unix() RegisteredAccessTokenClaims: RegisteredAccessTokenClaims{
} TokenClaims: TokenClaims{
if !time.Time(a.IssuedAt).IsZero() {
s.IssuedAt = time.Time(a.IssuedAt).Unix()
}
if !time.Time(a.NotBefore).IsZero() {
s.NotBefore = time.Time(a.NotBefore).Unix()
}
if !time.Time(a.AuthTime).IsZero() {
s.AuthTime = time.Time(a.AuthTime).Unix()
}
b, err := json.Marshal(s)
if err != nil {
return nil, err
}
if a.claims == nil {
return b, nil
}
info, err := json.Marshal(a.claims)
if err != nil {
return nil, err
}
return http.ConcatenateJSON(b, info)
}
func (a *accessTokenClaims) UnmarshalJSON(data []byte) error {
type Alias accessTokenClaims
if err := json.Unmarshal(data, (*Alias)(a)); err != nil {
return err
}
claims := make(map[string]interface{})
if err := json.Unmarshal(data, &claims); err != nil {
return err
}
a.claims = claims
return nil
}
func EmptyIDTokenClaims() IDTokenClaims {
return new(idTokenClaims)
}
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) IDTokenClaims {
audience = AppendClientIDToAudience(clientID, audience)
return &idTokenClaims{
Issuer: issuer, Issuer: issuer,
Subject: subject,
Audience: audience, Audience: audience,
Expiration: Time(expiration), Expiration: FromTime(expiration),
IssuedAt: Time(time.Now().UTC().Add(-skew)), IssuedAt: FromTime(now),
AuthTime: Time(authTime.Add(-skew)), JWTID: id,
},
NotBefore: FromTime(now),
},
}
}
type atcAlias AccessTokenClaims
func (a *AccessTokenClaims) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims((*atcAlias)(a), a.Claims)
}
func (a *AccessTokenClaims) UnmarshalJSON(data []byte) error {
return unmarshalJSONMulti(data, (*atcAlias)(a), &a.Claims)
}
type RegisteredIDTokenClaims struct {
TokenClaims
NotBefore Time `json:"nbf,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
UserInfoProfile
UserInfoEmail
UserInfoPhone
Address UserInfoAddress `json:"address,omitempty"`
}
// GetAccessTokenHash implements the IDTokenClaims interface
func (t *RegisteredIDTokenClaims) GetAccessTokenHash() string {
return t.AccessTokenHash
}
func (t *RegisteredIDTokenClaims) SetUserInfo(i *UserInfo) {
t.Subject = i.Subject
t.UserInfoProfile = i.UserInfoProfile
t.UserInfoEmail = i.UserInfoEmail
t.UserInfoPhone = i.UserInfoPhone
t.Address = i.Address
}
type IDTokenClaims struct {
RegisteredIDTokenClaims
Claims map[string]any `json:"-"`
}
func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) *IDTokenClaims {
audience = AppendClientIDToAudience(clientID, audience)
return &IDTokenClaims{
RegisteredIDTokenClaims: RegisteredIDTokenClaims{
TokenClaims: TokenClaims{
Issuer: issuer,
Subject: subject,
Audience: audience,
Expiration: FromTime(expiration),
IssuedAt: FromTime(time.Now().Add(-skew)),
AuthTime: FromTime(authTime.Add(-skew)),
Nonce: nonce, Nonce: nonce,
AuthenticationContextClassReference: acr, AuthenticationContextClassReference: acr,
AuthenticationMethodsReferences: amr, AuthenticationMethodsReferences: amr,
AuthorizedParty: clientID, AuthorizedParty: clientID,
UserInfo: &userinfo{Subject: subject}, },
},
} }
} }
type idTokenClaims struct { type itcAlias IDTokenClaims
Issuer string `json:"iss,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiration Time `json:"exp,omitempty"`
NotBefore Time `json:"nbf,omitempty"`
IssuedAt Time `json:"iat,omitempty"`
JWTID string `json:"jti,omitempty"`
AuthorizedParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime Time `json:"auth_time,omitempty"`
AccessTokenHash string `json:"at_hash,omitempty"`
CodeHash string `json:"c_hash,omitempty"`
AuthenticationContextClassReference string `json:"acr,omitempty"`
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
ClientID string `json:"client_id,omitempty"`
UserInfo `json:"-"`
signatureAlg jose.SignatureAlgorithm func (i *IDTokenClaims) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims((*itcAlias)(i), i.Claims)
} }
// GetIssuer implements the Claims interface func (i *IDTokenClaims) UnmarshalJSON(data []byte) error {
func (t *idTokenClaims) GetIssuer() string { return unmarshalJSONMulti(data, (*itcAlias)(i), &i.Claims)
return t.Issuer
}
// GetAudience implements the Claims interface
func (t *idTokenClaims) GetAudience() []string {
return t.Audience
}
// GetExpiration implements the Claims interface
func (t *idTokenClaims) GetExpiration() time.Time {
return time.Time(t.Expiration)
}
// GetIssuedAt implements the Claims interface
func (t *idTokenClaims) GetIssuedAt() time.Time {
return time.Time(t.IssuedAt)
}
// GetNonce implements the Claims interface
func (t *idTokenClaims) GetNonce() string {
return t.Nonce
}
// GetAuthenticationContextClassReference implements the Claims interface
func (t *idTokenClaims) GetAuthenticationContextClassReference() string {
return t.AuthenticationContextClassReference
}
// GetAuthTime implements the Claims interface
func (t *idTokenClaims) GetAuthTime() time.Time {
return time.Time(t.AuthTime)
}
// GetAuthorizedParty implements the Claims interface
func (t *idTokenClaims) GetAuthorizedParty() string {
return t.AuthorizedParty
}
// SetSignatureAlgorithm implements the Claims interface
func (t *idTokenClaims) SetSignatureAlgorithm(alg jose.SignatureAlgorithm) {
t.signatureAlg = alg
}
// GetNotBefore implements the IDTokenClaims interface
func (t *idTokenClaims) GetNotBefore() time.Time {
return time.Time(t.NotBefore)
}
// GetJWTID implements the IDTokenClaims interface
func (t *idTokenClaims) GetJWTID() string {
return t.JWTID
}
// GetAccessTokenHash implements the IDTokenClaims interface
func (t *idTokenClaims) GetAccessTokenHash() string {
return t.AccessTokenHash
}
// GetCodeHash implements the IDTokenClaims interface
func (t *idTokenClaims) GetCodeHash() string {
return t.CodeHash
}
// GetAuthenticationMethodsReferences implements the IDTokenClaims interface
func (t *idTokenClaims) GetAuthenticationMethodsReferences() []string {
return t.AuthenticationMethodsReferences
}
// GetClientID implements the IDTokenClaims interface
func (t *idTokenClaims) GetClientID() string {
return t.ClientID
}
// GetSignatureAlgorithm implements the IDTokenClaims interface
func (t *idTokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
return t.signatureAlg
}
// SetAccessTokenHash implements the IDTokenClaims interface
func (t *idTokenClaims) SetAccessTokenHash(hash string) {
t.AccessTokenHash = hash
}
// SetUserinfo implements the IDTokenClaims interface
func (t *idTokenClaims) SetUserinfo(info UserInfo) {
t.UserInfo = info
}
// SetCodeHash implements the IDTokenClaims interface
func (t *idTokenClaims) SetCodeHash(hash string) {
t.CodeHash = hash
}
func (t *idTokenClaims) MarshalJSON() ([]byte, error) {
type Alias idTokenClaims
a := &struct {
*Alias
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
}{
Alias: (*Alias)(t),
}
if !time.Time(t.Expiration).IsZero() {
a.Expiration = time.Time(t.Expiration).Unix()
}
if !time.Time(t.IssuedAt).IsZero() {
a.IssuedAt = time.Time(t.IssuedAt).Unix()
}
if !time.Time(t.NotBefore).IsZero() {
a.NotBefore = time.Time(t.NotBefore).Unix()
}
if !time.Time(t.AuthTime).IsZero() {
a.AuthTime = time.Time(t.AuthTime).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if t.UserInfo == nil {
return b, nil
}
info, err := json.Marshal(t.UserInfo)
if err != nil {
return nil, err
}
return http.ConcatenateJSON(b, info)
}
func (t *idTokenClaims) UnmarshalJSON(data []byte) error {
type Alias idTokenClaims
if err := json.Unmarshal(data, (*Alias)(t)); err != nil {
return err
}
userinfo := new(userinfo)
if err := json.Unmarshal(data, userinfo); err != nil {
return err
}
t.UserInfo = userinfo
return nil
} }
type AccessTokenResponse struct { type AccessTokenResponse struct {
@ -502,11 +277,11 @@ func (j *jwtProfileAssertion) GetAudience() []string {
} }
func (j *jwtProfileAssertion) GetExpiration() time.Time { func (j *jwtProfileAssertion) GetExpiration() time.Time {
return time.Time(j.Expiration) return j.Expiration.AsTime()
} }
func (j *jwtProfileAssertion) GetIssuedAt() time.Time { func (j *jwtProfileAssertion) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt) return j.IssuedAt.AsTime()
} }
func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) { func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (JWTProfileAssertionClaims, error) {
@ -563,8 +338,8 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte,
PrivateKeyID: keyID, PrivateKeyID: keyID,
Issuer: userID, Issuer: userID,
Subject: userID, Subject: userID,
IssuedAt: Time(time.Now().UTC()), IssuedAt: FromTime(time.Now().UTC()),
Expiration: Time(time.Now().Add(1 * time.Hour).UTC()), Expiration: FromTime(time.Now().Add(1 * time.Hour).UTC()),
Audience: audience, Audience: audience,
customClaims: make(map[string]interface{}), customClaims: make(map[string]interface{}),
} }

View file

@ -187,12 +187,12 @@ func (j *JWTTokenRequest) GetAudience() []string {
// GetExpiration implements the Claims interface // GetExpiration implements the Claims interface
func (j *JWTTokenRequest) GetExpiration() time.Time { func (j *JWTTokenRequest) GetExpiration() time.Time {
return time.Time(j.ExpiresAt) return j.ExpiresAt.AsTime()
} }
// GetIssuedAt implements the Claims interface // GetIssuedAt implements the Claims interface
func (j *JWTTokenRequest) GetIssuedAt() time.Time { func (j *JWTTokenRequest) GetIssuedAt() time.Time {
return time.Time(j.IssuedAt) return j.ExpiresAt.AsTime()
} }
// GetNonce implements the Claims interface // GetNonce implements the Claims interface

View file

@ -46,6 +46,39 @@ func (d *Display) UnmarshalText(text []byte) error {
type Gender string type Gender string
type Locale struct {
tag language.Tag
}
func NewLocale(tag language.Tag) *Locale {
return &Locale{tag: tag}
}
func (l *Locale) Tag() language.Tag {
if l == nil {
return language.Und
}
return l.tag
}
func (l *Locale) String() string {
return l.Tag().String()
}
func (l *Locale) MarshalJSON() ([]byte, error) {
tag := l.Tag()
if tag.IsRoot() {
return []byte("null"), nil
}
return json.Marshal(tag)
}
func (l *Locale) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &l.tag)
}
type Locales []language.Tag type Locales []language.Tag
func (l *Locales) UnmarshalText(text []byte) error { func (l *Locales) UnmarshalText(text []byte) error {
@ -137,19 +170,18 @@ func NewEncoder() *schema.Encoder {
return e return e
} }
type Time time.Time type Time int64
func (t *Time) UnmarshalJSON(data []byte) error { func (ts Time) AsTime() time.Time {
var i int64 return time.Unix(int64(ts), 0)
if err := json.Unmarshal(data, &i); err != nil {
return err
}
*t = Time(time.Unix(i, 0).UTC())
return nil
} }
func (t *Time) MarshalJSON() ([]byte, error) { func FromTime(tt time.Time) Time {
return json.Marshal(time.Time(*t).UTC().Unix()) return Time(tt.Unix())
}
func NowTime() Time {
return FromTime(time.Now())
} }
type RequestObject struct { type RequestObject struct {
@ -162,5 +194,4 @@ func (r *RequestObject) GetIssuer() string {
return r.Issuer return r.Issuer
} }
func (r *RequestObject) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) { func (*RequestObject) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {}
}

View file

@ -10,6 +10,7 @@ import (
"github.com/gorilla/schema" "github.com/gorilla/schema"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language" "golang.org/x/text/language"
) )
@ -111,6 +112,117 @@ func TestDisplay_UnmarshalText(t *testing.T) {
} }
} }
func TestLocale_Tag(t *testing.T) {
tests := []struct {
name string
l *Locale
want language.Tag
}{
{
name: "nil",
l: nil,
want: language.Und,
},
{
name: "Und",
l: NewLocale(language.Und),
want: language.Und,
},
{
name: "language",
l: NewLocale(language.Afrikaans),
want: language.Afrikaans,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, tt.l.Tag())
})
}
}
func TestLocale_String(t *testing.T) {
tests := []struct {
name string
l *Locale
want language.Tag
}{
{
name: "nil",
l: nil,
want: language.Und,
},
{
name: "Und",
l: NewLocale(language.Und),
want: language.Und,
},
{
name: "language",
l: NewLocale(language.Afrikaans),
want: language.Afrikaans,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want.String(), tt.l.String())
})
}
}
func TestLocale_MarshalJSON(t *testing.T) {
tests := []struct {
name string
l *Locale
want string
wantErr bool
}{
{
name: "nil",
l: nil,
want: "null",
},
{
name: "und",
l: NewLocale(language.Und),
want: "null",
},
{
name: "language",
l: NewLocale(language.Afrikaans),
want: `"af"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := json.Marshal(tt.l)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, string(got))
})
}
}
func TestLocale_UnmarshalJSON(t *testing.T) {
type a struct {
Locale *Locale `json:"locale,omitempty"`
}
want := a{
Locale: NewLocale(language.Afrikaans),
}
const input = `{"locale": "af"}`
var got a
require.NoError(t,
json.Unmarshal([]byte(input), &got),
)
assert.Equal(t, want, got)
}
func TestLocales_UnmarshalText(t *testing.T) { func TestLocales_UnmarshalText(t *testing.T) {
type args struct { type args struct {
text []byte text []byte

View file

@ -1,292 +1,34 @@
package oidc package oidc
import ( type UserInfo struct {
"encoding/json" Subject string `json:"sub,omitempty"`
"fmt"
"time"
"golang.org/x/text/language"
)
type UserInfo interface {
GetSubject() string
UserInfoProfile UserInfoProfile
UserInfoEmail UserInfoEmail
UserInfoPhone UserInfoPhone
GetAddress() UserInfoAddress
GetClaim(key string) interface{}
GetClaims() map[string]interface{}
}
type UserInfoProfile interface {
GetName() string
GetGivenName() string
GetFamilyName() string
GetMiddleName() string
GetNickname() string
GetProfile() string
GetPicture() string
GetWebsite() string
GetGender() Gender
GetBirthdate() string
GetZoneinfo() string
GetLocale() language.Tag
GetPreferredUsername() string
}
type UserInfoEmail interface {
GetEmail() string
IsEmailVerified() bool
}
type UserInfoPhone interface {
GetPhoneNumber() string
IsPhoneNumberVerified() bool
}
type UserInfoAddress interface {
GetFormatted() string
GetStreetAddress() string
GetLocality() string
GetRegion() string
GetPostalCode() string
GetCountry() string
}
type UserInfoSetter interface {
UserInfo
SetSubject(sub string)
UserInfoProfileSetter
SetEmail(email string, verified bool)
SetPhone(phone string, verified bool)
SetAddress(address UserInfoAddress)
AppendClaims(key string, values interface{})
}
type UserInfoProfileSetter interface {
SetName(name string)
SetGivenName(name string)
SetFamilyName(name string)
SetMiddleName(name string)
SetNickname(name string)
SetUpdatedAt(date time.Time)
SetProfile(profile string)
SetPicture(profile string)
SetWebsite(website string)
SetGender(gender Gender)
SetBirthdate(birthdate string)
SetZoneinfo(zoneInfo string)
SetLocale(locale language.Tag)
SetPreferredUsername(name string)
}
func NewUserInfo() UserInfoSetter {
return &userinfo{}
}
type userinfo struct {
Subject string `json:"sub,omitempty"`
userInfoProfile
userInfoEmail
userInfoPhone
Address UserInfoAddress `json:"address,omitempty"` Address UserInfoAddress `json:"address,omitempty"`
claims map[string]interface{} Claims map[string]any `json:"-"`
} }
func (u *userinfo) GetSubject() string { func (u *UserInfo) AppendClaims(k string, v any) {
return u.Subject if u.Claims == nil {
} u.Claims = make(map[string]any)
func (u *userinfo) GetName() string {
return u.Name
}
func (u *userinfo) GetGivenName() string {
return u.GivenName
}
func (u *userinfo) GetFamilyName() string {
return u.FamilyName
}
func (u *userinfo) GetMiddleName() string {
return u.MiddleName
}
func (u *userinfo) GetNickname() string {
return u.Nickname
}
func (u *userinfo) GetProfile() string {
return u.Profile
}
func (u *userinfo) GetPicture() string {
return u.Picture
}
func (u *userinfo) GetWebsite() string {
return u.Website
}
func (u *userinfo) GetGender() Gender {
return u.Gender
}
func (u *userinfo) GetBirthdate() string {
return u.Birthdate
}
func (u *userinfo) GetZoneinfo() string {
return u.Zoneinfo
}
func (u *userinfo) GetLocale() language.Tag {
return u.Locale
}
func (u *userinfo) GetPreferredUsername() string {
return u.PreferredUsername
}
func (u *userinfo) GetEmail() string {
return u.Email
}
func (u *userinfo) IsEmailVerified() bool {
return bool(u.EmailVerified)
}
func (u *userinfo) GetPhoneNumber() string {
return u.PhoneNumber
}
func (u *userinfo) IsPhoneNumberVerified() bool {
return u.PhoneNumberVerified
}
func (u *userinfo) GetAddress() UserInfoAddress {
if u.Address == nil {
return &userInfoAddress{}
} }
return u.Address
u.Claims[k] = v
} }
func (u *userinfo) GetClaim(key string) interface{} { type uiAlias UserInfo
return u.claims[key]
func (u *UserInfo) MarshalJSON() ([]byte, error) {
return mergeAndMarshalClaims((*uiAlias)(u), u.Claims)
} }
func (u *userinfo) GetClaims() map[string]interface{} { func (u *UserInfo) UnmarshalJSON(data []byte) error {
return u.claims return unmarshalJSONMulti(data, (*uiAlias)(u), &u.Claims)
} }
func (u *userinfo) SetSubject(sub string) { type UserInfoProfile struct {
u.Subject = sub
}
func (u *userinfo) SetName(name string) {
u.Name = name
}
func (u *userinfo) SetGivenName(name string) {
u.GivenName = name
}
func (u *userinfo) SetFamilyName(name string) {
u.FamilyName = name
}
func (u *userinfo) SetMiddleName(name string) {
u.MiddleName = name
}
func (u *userinfo) SetNickname(name string) {
u.Nickname = name
}
func (u *userinfo) SetUpdatedAt(date time.Time) {
u.UpdatedAt = Time(date)
}
func (u *userinfo) SetProfile(profile string) {
u.Profile = profile
}
func (u *userinfo) SetPicture(picture string) {
u.Picture = picture
}
func (u *userinfo) SetWebsite(website string) {
u.Website = website
}
func (u *userinfo) SetGender(gender Gender) {
u.Gender = gender
}
func (u *userinfo) SetBirthdate(birthdate string) {
u.Birthdate = birthdate
}
func (u *userinfo) SetZoneinfo(zoneInfo string) {
u.Zoneinfo = zoneInfo
}
func (u *userinfo) SetLocale(locale language.Tag) {
u.Locale = locale
}
func (u *userinfo) SetPreferredUsername(name string) {
u.PreferredUsername = name
}
func (u *userinfo) SetEmail(email string, verified bool) {
u.Email = email
u.EmailVerified = boolString(verified)
}
func (u *userinfo) SetPhone(phone string, verified bool) {
u.PhoneNumber = phone
u.PhoneNumberVerified = verified
}
func (u *userinfo) SetAddress(address UserInfoAddress) {
u.Address = address
}
func (u *userinfo) AppendClaims(key string, value interface{}) {
if u.claims == nil {
u.claims = make(map[string]interface{})
}
u.claims[key] = value
}
func (u *userInfoAddress) GetFormatted() string {
return u.Formatted
}
func (u *userInfoAddress) GetStreetAddress() string {
return u.StreetAddress
}
func (u *userInfoAddress) GetLocality() string {
return u.Locality
}
func (u *userInfoAddress) GetRegion() string {
return u.Region
}
func (u *userInfoAddress) GetPostalCode() string {
return u.PostalCode
}
func (u *userInfoAddress) GetCountry() string {
return u.Country
}
type userInfoProfile struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"` GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"` FamilyName string `json:"family_name,omitempty"`
@ -298,23 +40,23 @@ type userInfoProfile struct {
Gender Gender `json:"gender,omitempty"` Gender Gender `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"` Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"` Zoneinfo string `json:"zoneinfo,omitempty"`
Locale language.Tag `json:"locale,omitempty"` Locale *Locale `json:"locale,omitempty"`
UpdatedAt Time `json:"updated_at,omitempty"` UpdatedAt Time `json:"updated_at,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"` PreferredUsername string `json:"preferred_username,omitempty"`
} }
type userInfoEmail struct { type UserInfoEmail struct {
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
// Handle providers that return email_verified as a string // Handle providers that return email_verified as a string
// https://forums.aws.amazon.com/thread.jspa?messageID=949441&#949441 // https://forums.aws.amazon.com/thread.jspa?messageID=949441&#949441
// https://discuss.elastic.co/t/openid-error-after-authenticating-against-aws-cognito/206018/11 // https://discuss.elastic.co/t/openid-error-after-authenticating-against-aws-cognito/206018/11
EmailVerified boolString `json:"email_verified,omitempty"` EmailVerified Bool `json:"email_verified,omitempty"`
} }
type boolString bool type Bool bool
func (bs *boolString) UnmarshalJSON(data []byte) error { func (bs *Bool) UnmarshalJSON(data []byte) error {
if string(data) == "true" || string(data) == `"true"` { if string(data) == "true" || string(data) == `"true"` {
*bs = true *bs = true
} }
@ -322,12 +64,12 @@ func (bs *boolString) UnmarshalJSON(data []byte) error {
return nil return nil
} }
type userInfoPhone struct { type UserInfoPhone struct {
PhoneNumber string `json:"phone_number,omitempty"` PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified bool `json:"phone_number_verified,omitempty"` PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
} }
type userInfoAddress struct { type UserInfoAddress struct {
Formatted string `json:"formatted,omitempty"` Formatted string `json:"formatted,omitempty"`
StreetAddress string `json:"street_address,omitempty"` StreetAddress string `json:"street_address,omitempty"`
Locality string `json:"locality,omitempty"` Locality string `json:"locality,omitempty"`
@ -336,76 +78,6 @@ type userInfoAddress struct {
Country string `json:"country,omitempty"` Country string `json:"country,omitempty"`
} }
func NewUserInfoAddress(streetAddress, locality, region, postalCode, country, formatted string) UserInfoAddress {
return &userInfoAddress{
StreetAddress: streetAddress,
Locality: locality,
Region: region,
PostalCode: postalCode,
Country: country,
Formatted: formatted,
}
}
func (u *userinfo) MarshalJSON() ([]byte, error) {
type Alias userinfo
a := &struct {
*Alias
Locale interface{} `json:"locale,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
}{
Alias: (*Alias)(u),
}
if !u.Locale.IsRoot() {
a.Locale = u.Locale
}
if !time.Time(u.UpdatedAt).IsZero() {
a.UpdatedAt = time.Time(u.UpdatedAt).Unix()
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if len(u.claims) == 0 {
return b, nil
}
err = json.Unmarshal(b, &u.claims)
if err != nil {
return nil, fmt.Errorf("jws: invalid map of custom claims %v", u.claims)
}
return json.Marshal(u.claims)
}
func (u *userinfo) UnmarshalJSON(data []byte) error {
type Alias userinfo
a := &struct {
Address *userInfoAddress `json:"address,omitempty"`
*Alias
UpdatedAt int64 `json:"update_at,omitempty"`
}{
Alias: (*Alias)(u),
}
if err := json.Unmarshal(data, &a); err != nil {
return err
}
if a.Address != nil {
u.Address = a.Address
}
u.UpdatedAt = Time(time.Unix(a.UpdatedAt, 0).UTC())
if err := json.Unmarshal(data, &u.claims); err != nil {
return err
}
return nil
}
type UserInfoRequest struct { type UserInfoRequest struct {
AccessToken string `schema:"access_token"` AccessToken string `schema:"access_token"`
} }

View file

@ -8,20 +8,33 @@ import (
) )
func TestUserInfoMarshal(t *testing.T) { func TestUserInfoMarshal(t *testing.T) {
userinfo := NewUserInfo() userinfo := &UserInfo{
userinfo.SetSubject("test") Subject: "test",
userinfo.SetAddress(NewUserInfoAddress("Test 789\nPostfach 2", "", "", "", "", "")) Address: UserInfoAddress{
userinfo.SetEmail("test", true) StreetAddress: "Test 789\nPostfach 2",
userinfo.SetPhone("0791234567", true) },
userinfo.SetName("Test") UserInfoEmail: UserInfoEmail{
userinfo.AppendClaims("private_claim", "test") Email: "test",
EmailVerified: true,
},
UserInfoPhone: UserInfoPhone{
PhoneNumber: "0791234567",
PhoneNumberVerified: true,
},
UserInfoProfile: UserInfoProfile{
Name: "Test",
},
Claims: map[string]any{"private_claim": "test"},
}
marshal, err := json.Marshal(userinfo) marshal, err := json.Marshal(userinfo)
out := NewUserInfo()
assert.NoError(t, err) assert.NoError(t, err)
out := new(UserInfo)
assert.NoError(t, json.Unmarshal(marshal, out)) assert.NoError(t, json.Unmarshal(marshal, out))
assert.Equal(t, userinfo.GetAddress(), out.GetAddress()) assert.Equal(t, userinfo, out)
expected, err := json.Marshal(out) expected, err := json.Marshal(out)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, marshal) assert.Equal(t, expected, marshal)
} }
@ -29,14 +42,14 @@ func TestUserInfoMarshal(t *testing.T) {
func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) { func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("unmarsha email_verified from json bool true", func(t *testing.T) { t.Run("unmarshal email_verified from json bool true", func(t *testing.T) {
jsonBool := []byte(`{"email": "my@email.com", "email_verified": true}`) jsonBool := []byte(`{"email": "my@email.com", "email_verified": true}`)
var uie userInfoEmail var uie UserInfoEmail
err := json.Unmarshal(jsonBool, &uie) err := json.Unmarshal(jsonBool, &uie)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, userInfoEmail{ assert.Equal(t, UserInfoEmail{
Email: "my@email.com", Email: "my@email.com",
EmailVerified: true, EmailVerified: true,
}, uie) }, uie)
@ -45,11 +58,11 @@ func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
t.Run("unmarsha email_verified from json string true", func(t *testing.T) { t.Run("unmarsha email_verified from json string true", func(t *testing.T) {
jsonBool := []byte(`{"email": "my@email.com", "email_verified": "true"}`) jsonBool := []byte(`{"email": "my@email.com", "email_verified": "true"}`)
var uie userInfoEmail var uie UserInfoEmail
err := json.Unmarshal(jsonBool, &uie) err := json.Unmarshal(jsonBool, &uie)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, userInfoEmail{ assert.Equal(t, UserInfoEmail{
Email: "my@email.com", Email: "my@email.com",
EmailVerified: true, EmailVerified: true,
}, uie) }, uie)
@ -58,11 +71,11 @@ func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
t.Run("unmarsha email_verified from json bool false", func(t *testing.T) { t.Run("unmarsha email_verified from json bool false", func(t *testing.T) {
jsonBool := []byte(`{"email": "my@email.com", "email_verified": false}`) jsonBool := []byte(`{"email": "my@email.com", "email_verified": false}`)
var uie userInfoEmail var uie UserInfoEmail
err := json.Unmarshal(jsonBool, &uie) err := json.Unmarshal(jsonBool, &uie)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, userInfoEmail{ assert.Equal(t, UserInfoEmail{
Email: "my@email.com", Email: "my@email.com",
EmailVerified: false, EmailVerified: false,
}, uie) }, uie)
@ -71,49 +84,13 @@ func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
t.Run("unmarsha email_verified from json string false", func(t *testing.T) { t.Run("unmarsha email_verified from json string false", func(t *testing.T) {
jsonBool := []byte(`{"email": "my@email.com", "email_verified": "false"}`) jsonBool := []byte(`{"email": "my@email.com", "email_verified": "false"}`)
var uie userInfoEmail var uie UserInfoEmail
err := json.Unmarshal(jsonBool, &uie) err := json.Unmarshal(jsonBool, &uie)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, userInfoEmail{ assert.Equal(t, UserInfoEmail{
Email: "my@email.com", Email: "my@email.com",
EmailVerified: false, EmailVerified: false,
}, uie) }, uie)
}) })
} }
// issue 203 test case.
func Test_userinfo_GetAddress_issue_203(t *testing.T) {
tests := []struct {
name string
data string
}{
{
name: "with address",
data: `{"address":{"street_address":"Test 789\nPostfach 2"},"email":"test","email_verified":true,"name":"Test","phone_number":"0791234567","phone_number_verified":true,"private_claim":"test","sub":"test"}`,
},
{
name: "without address",
data: `{"email":"test","email_verified":true,"name":"Test","phone_number":"0791234567","phone_number_verified":true,"private_claim":"test","sub":"test"}`,
},
{
name: "null address",
data: `{"address":null,"email":"test","email_verified":true,"name":"Test","phone_number":"0791234567","phone_number_verified":true,"private_claim":"test","sub":"test"}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
info := &userinfo{}
err := json.Unmarshal([]byte(tt.data), info)
assert.NoError(t, err)
info.GetAddress().GetCountry() //<- used to panic
// now shortly assure that a marshalling still produces the same as was parsed into the struct
marshal, err := json.Marshal(info)
assert.NoError(t, err)
assert.Equal(t, tt.data, string(marshal))
})
}
}

49
pkg/oidc/util.go Normal file
View file

@ -0,0 +1,49 @@
package oidc
import (
"bytes"
"encoding/json"
"fmt"
)
// mergeAndMarshalClaims merges registered and the custom
// claims map into a single JSON object.
// Registered fields overwrite custom claims.
func mergeAndMarshalClaims(registered any, claims map[string]any) ([]byte, error) {
// Use a buffer for memory re-use, instead off letting
// json allocate a new []byte for every step.
buf := new(bytes.Buffer)
// Marshal the registered claims into JSON
if err := json.NewEncoder(buf).Encode(registered); err != nil {
return nil, fmt.Errorf("oidc registered claims: %w", err)
}
if len(claims) > 0 {
// Merge JSON data into custom claims.
// The full-read action by the decoder resets the buffer
// to zero len, while retaining underlaying cap.
if err := json.NewDecoder(buf).Decode(&claims); err != nil {
return nil, fmt.Errorf("oidc registered claims: %w", err)
}
// Marshal the final result.
if err := json.NewEncoder(buf).Encode(claims); err != nil {
return nil, fmt.Errorf("oidc custom claims: %w", err)
}
}
return buf.Bytes(), nil
}
// unmarshalJSONMulti unmarshals the same JSON data into multiple destinations.
// Each destination must be a pointer, as per json.Unmarshal rules.
// Returns on the first error and destinations may be partly filled with data.
func unmarshalJSONMulti(data []byte, destinations ...any) error {
for _, dst := range destinations {
if err := json.Unmarshal(data, dst); err != nil {
return fmt.Errorf("oidc: %w into %T", err, dst)
}
}
return nil
}

147
pkg/oidc/util_test.go Normal file
View file

@ -0,0 +1,147 @@
package oidc
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type jsonErrorTest struct{}
func (jsonErrorTest) MarshalJSON() ([]byte, error) {
return nil, errors.New("test")
}
func Test_mergeAndMarshalClaims(t *testing.T) {
type args struct {
registered any
claims map[string]any
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "encoder error",
args: args{
registered: jsonErrorTest{},
},
wantErr: true,
},
{
name: "no claims",
args: args{
registered: struct {
Foo string `json:"foo,omitempty"`
}{
Foo: "bar",
},
},
want: "{\"foo\":\"bar\"}\n",
},
{
name: "with claims",
args: args{
registered: struct {
Foo string `json:"foo,omitempty"`
}{
Foo: "bar",
},
claims: map[string]any{
"bar": "foo",
},
},
want: "{\"bar\":\"foo\",\"foo\":\"bar\"}\n",
},
{
name: "registered overwrites custom",
args: args{
registered: struct {
Foo string `json:"foo,omitempty"`
}{
Foo: "bar",
},
claims: map[string]any{
"foo": "Hello, World!",
},
},
want: "{\"foo\":\"bar\"}\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := mergeAndMarshalClaims(tt.args.registered, tt.args.claims)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, string(got))
})
}
}
func Test_unmarshalJSONMulti(t *testing.T) {
type dst struct {
Foo string `json:"foo,omitempty"`
}
type args struct {
data string
destinations []any
}
tests := []struct {
name string
args args
want []any
wantErr bool
}{
{
name: "error",
args: args{
data: "~!~~",
destinations: []any{
&dst{},
&map[string]any{},
},
},
want: []any{
&dst{},
&map[string]any{},
},
wantErr: true,
},
{
name: "success",
args: args{
data: "{\"bar\":\"foo\",\"foo\":\"bar\"}\n",
destinations: []any{
&dst{},
&map[string]any{},
},
},
want: []any{
&dst{Foo: "bar"},
&map[string]any{
"foo": "bar",
"bar": "foo",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := unmarshalJSONMulti([]byte(tt.args.data), tt.args.destinations...)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Equal(t, tt.want, tt.args.destinations)
})
}
}

View file

@ -32,6 +32,12 @@ type ClaimsSignature interface {
SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
} }
type IDClaims interface {
Claims
GetSignatureAlgorithm() jose.SignatureAlgorithm
GetAccessTokenHash() string
}
var ( var (
ErrParse = errors.New("parsing of request failed") ErrParse = errors.New("parsing of request failed")
ErrIssuerInvalid = errors.New("issuer does not match") ErrIssuerInvalid = errors.New("issuer does not match")

View file

@ -371,7 +371,7 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
if idTokenHint == "" { if idTokenHint == "" {
return "", nil return "", nil
} }
claims, err := VerifyIDTokenHint(ctx, idTokenHint, verifier) claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, idTokenHint, verifier)
if err != nil { if err != nil {
return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " + return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " +
"If you have any questions, you may contact the administrator of the application.") "If you have any questions, you may contact the administrator of the application.")

View file

@ -263,7 +263,7 @@ func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *g
} }
// SetIntrospectionFromToken mocks base method. // SetIntrospectionFromToken mocks base method.
func (m *MockStorage) SetIntrospectionFromToken(arg0 context.Context, arg1 oidc.IntrospectionResponse, arg2, arg3, arg4 string) error { func (m *MockStorage) SetIntrospectionFromToken(arg0 context.Context, arg1 *oidc.IntrospectionResponse, arg2, arg3, arg4 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetIntrospectionFromToken", arg0, arg1, arg2, arg3, arg4) ret := m.ctrl.Call(m, "SetIntrospectionFromToken", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -277,7 +277,7 @@ func (mr *MockStorageMockRecorder) SetIntrospectionFromToken(arg0, arg1, arg2, a
} }
// SetUserinfoFromScopes mocks base method. // SetUserinfoFromScopes mocks base method.
func (m *MockStorage) SetUserinfoFromScopes(arg0 context.Context, arg1 oidc.UserInfoSetter, arg2, arg3 string, arg4 []string) error { func (m *MockStorage) SetUserinfoFromScopes(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3 string, arg4 []string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetUserinfoFromScopes", arg0, arg1, arg2, arg3, arg4) ret := m.ctrl.Call(m, "SetUserinfoFromScopes", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -291,7 +291,7 @@ func (mr *MockStorageMockRecorder) SetUserinfoFromScopes(arg0, arg1, arg2, arg3,
} }
// SetUserinfoFromToken mocks base method. // SetUserinfoFromToken mocks base method.
func (m *MockStorage) SetUserinfoFromToken(arg0 context.Context, arg1 oidc.UserInfoSetter, arg2, arg3, arg4 string) error { func (m *MockStorage) SetUserinfoFromToken(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3, arg4 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetUserinfoFromToken", arg0, arg1, arg2, arg3, arg4) ret := m.ctrl.Call(m, "SetUserinfoFromToken", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View file

@ -59,7 +59,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
RedirectURI: ender.DefaultLogoutRedirectURI(), RedirectURI: ender.DefaultLogoutRedirectURI(),
} }
if req.IdTokenHint != "" { if req.IdTokenHint != "" {
claims, err := VerifyIDTokenHint(ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx)) claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx))
if err != nil { if err != nil {
return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err) return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
} }

View file

@ -96,7 +96,7 @@ type TokenExchangeStorage interface {
// SetUserinfoFromTokenExchangeRequest will be called during id token creation. // SetUserinfoFromTokenExchangeRequest will be called during id token creation.
// Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc. // Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc.
SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo oidc.UserInfoSetter, request TokenExchangeRequest) error SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo *oidc.UserInfo, request TokenExchangeRequest) error
} }
// TokenExchangeTokensVerifierStorage is an optional interface used in token exchange process to verify tokens // TokenExchangeTokensVerifierStorage is an optional interface used in token exchange process to verify tokens
@ -111,9 +111,9 @@ var ErrInvalidRefreshToken = errors.New("invalid_refresh_token")
type OPStorage interface { type OPStorage interface {
GetClientByClientID(ctx context.Context, clientID string) (Client, error) GetClientByClientID(ctx context.Context, clientID string) (Client, error)
AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, userID, clientID string, scopes []string) error SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error
SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, tokenID, subject, origin string) error SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error
SetIntrospectionFromToken(ctx context.Context, userinfo oidc.IntrospectionResponse, tokenID, subject, clientID string) error SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error
GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error)
GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error)
ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error)

View file

@ -129,7 +129,7 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetPrivateClaims(privateClaims) claims.Claims = privateClaims
} }
signingKey, err := storage.SigningKey(ctx) signingKey, err := storage.SigningKey(ctx)
if err != nil { if err != nil {
@ -169,7 +169,7 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetAccessTokenHash(atHash) claims.AccessTokenHash = atHash
if !client.IDTokenUserinfoClaimsAssertion() { if !client.IDTokenUserinfoClaimsAssertion() {
scopes = removeUserinfoScopes(scopes) scopes = removeUserinfoScopes(scopes)
} }
@ -178,26 +178,26 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v
tokenExchangeRequest, okReq := request.(TokenExchangeRequest) tokenExchangeRequest, okReq := request.(TokenExchangeRequest)
teStorage, okStorage := storage.(TokenExchangeStorage) teStorage, okStorage := storage.(TokenExchangeStorage)
if okReq && okStorage { if okReq && okStorage {
userInfo := oidc.NewUserInfo() userInfo := new(oidc.UserInfo)
err := teStorage.SetUserinfoFromTokenExchangeRequest(ctx, userInfo, tokenExchangeRequest) err := teStorage.SetUserinfoFromTokenExchangeRequest(ctx, userInfo, tokenExchangeRequest)
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetUserinfo(userInfo) claims.SetUserInfo(userInfo)
} else if len(scopes) > 0 { } else if len(scopes) > 0 {
userInfo := oidc.NewUserInfo() userInfo := new(oidc.UserInfo)
err := storage.SetUserinfoFromScopes(ctx, userInfo, request.GetSubject(), request.GetClientID(), scopes) err := storage.SetUserinfoFromScopes(ctx, userInfo, request.GetSubject(), request.GetClientID(), scopes)
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetUserinfo(userInfo) claims.SetUserInfo(userInfo)
} }
if code != "" { if code != "" {
codeHash, err := oidc.ClaimHash(code, signingKey.SignatureAlgorithm()) codeHash, err := oidc.ClaimHash(code, signingKey.SignatureAlgorithm())
if err != nil { if err != nil {
return "", err return "", err
} }
claims.SetCodeHash(codeHash) claims.CodeHash = codeHash
} }
signer, err := SignerFromKey(signingKey) signer, err := SignerFromKey(signingKey)
if err != nil { if err != nil {

View file

@ -280,9 +280,9 @@ func GetTokenIDAndSubjectFromToken(
) (tokenIDOrToken, subject string, claims map[string]interface{}, ok bool) { ) (tokenIDOrToken, subject string, claims map[string]interface{}, ok bool) {
switch tokenType { switch tokenType {
case oidc.AccessTokenType: case oidc.AccessTokenType:
var accessTokenClaims oidc.AccessTokenClaims var accessTokenClaims *oidc.AccessTokenClaims
tokenIDOrToken, subject, accessTokenClaims, ok = getTokenIDAndClaims(ctx, exchanger, token) tokenIDOrToken, subject, accessTokenClaims, ok = getTokenIDAndClaims(ctx, exchanger, token)
claims = accessTokenClaims.GetClaims() claims = accessTokenClaims.Claims
case oidc.RefreshTokenType: case oidc.RefreshTokenType:
refreshTokenRequest, err := exchanger.Storage().TokenRequestByRefreshToken(ctx, token) refreshTokenRequest, err := exchanger.Storage().TokenRequestByRefreshToken(ctx, token)
if err != nil { if err != nil {
@ -291,12 +291,12 @@ func GetTokenIDAndSubjectFromToken(
tokenIDOrToken, subject, ok = token, refreshTokenRequest.GetSubject(), true tokenIDOrToken, subject, ok = token, refreshTokenRequest.GetSubject(), true
case oidc.IDTokenType: case oidc.IDTokenType:
idTokenClaims, err := VerifyIDTokenHint(ctx, token, exchanger.IDTokenHintVerifier(ctx)) idTokenClaims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, token, exchanger.IDTokenHintVerifier(ctx))
if err != nil { if err != nil {
break break
} }
tokenIDOrToken, subject, claims, ok = token, idTokenClaims.GetSubject(), idTokenClaims.GetClaims(), true tokenIDOrToken, subject, claims, ok = token, idTokenClaims.Subject, idTokenClaims.Claims, true
} }
if !ok { if !ok {
@ -380,7 +380,7 @@ func CreateTokenExchangeResponse(
}, nil }, nil
} }
func getTokenIDAndClaims(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, oidc.AccessTokenClaims, bool) { func getTokenIDAndClaims(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, *oidc.AccessTokenClaims, bool) {
tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
if err == nil { if err == nil {
splitToken := strings.Split(tokenIDSubject, ":") splitToken := strings.Split(tokenIDSubject, ":")
@ -390,10 +390,10 @@ func getTokenIDAndClaims(ctx context.Context, userinfoProvider UserinfoProvider,
return splitToken[0], splitToken[1], nil, true return splitToken[0], splitToken[1], nil, true
} }
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx)) accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
if err != nil { if err != nil {
return "", "", nil, false return "", "", nil, false
} }
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), accessTokenClaims, true return accessTokenClaims.JWTID, accessTokenClaims.Subject, accessTokenClaims, true
} }

View file

@ -28,7 +28,7 @@ func introspectionHandler(introspector Introspector) func(http.ResponseWriter, *
} }
func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspector) { func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspector) {
response := oidc.NewIntrospectionResponse() response := new(oidc.IntrospectionResponse)
token, clientID, err := ParseTokenIntrospectionRequest(r, introspector) token, clientID, err := ParseTokenIntrospectionRequest(r, introspector)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized) http.Error(w, err.Error(), http.StatusUnauthorized)
@ -44,7 +44,7 @@ func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspecto
httphelper.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
return return
} }
response.SetActive(true) response.Active = true
httphelper.MarshalJSON(w, response) httphelper.MarshalJSON(w, response)
} }

View file

@ -151,9 +151,9 @@ func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider Use
} }
return splitToken[0], splitToken[1], true return splitToken[0], splitToken[1], true
} }
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx)) accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
if err != nil { if err != nil {
return "", "", false return "", "", false
} }
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
} }

View file

@ -34,7 +34,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
http.Error(w, "access token invalid", http.StatusUnauthorized) http.Error(w, "access token invalid", http.StatusUnauthorized)
return return
} }
info := oidc.NewUserInfo() info := new(oidc.UserInfo)
err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin")) err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin"))
if err != nil { if err != nil {
httphelper.MarshalJSONWithStatus(w, err, http.StatusForbidden) httphelper.MarshalJSONWithStatus(w, err, http.StatusForbidden)
@ -81,9 +81,9 @@ func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider
} }
return splitToken[0], splitToken[1], true return splitToken[0], splitToken[1], true
} }
accessTokenClaims, err := VerifyAccessToken(ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx)) accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
if err != nil { if err != nil {
return "", "", false return "", "", false
} }
return accessTokenClaims.GetTokenID(), accessTokenClaims.GetSubject(), true return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
} }

View file

@ -68,28 +68,28 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok
} }
// VerifyAccessToken validates the access token (issuer, signature and expiration) // VerifyAccessToken validates the access token (issuer, signature and expiration)
func VerifyAccessToken(ctx context.Context, token string, v AccessTokenVerifier) (oidc.AccessTokenClaims, error) { func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) {
claims := oidc.EmptyAccessTokenClaims() var nilClaims C
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
if err != nil { if err != nil {
return nil, err return nilClaims, err
} }
payload, err := oidc.ParseToken(decrypted, claims) payload, err := oidc.ParseToken(decrypted, &claims)
if err != nil { if err != nil {
return nil, err return nilClaims, err
} }
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
return nil, err return nilClaims, err
} }
return claims, nil return claims, nil

View file

@ -74,40 +74,40 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi
// VerifyIDTokenHint validates the id token according to // VerifyIDTokenHint validates the id token according to
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func VerifyIDTokenHint(ctx context.Context, token string, v IDTokenHintVerifier) (oidc.IDTokenClaims, error) { func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) {
claims := oidc.EmptyIDTokenClaims() var nilClaims C
decrypted, err := oidc.DecryptToken(token) decrypted, err := oidc.DecryptToken(token)
if err != nil { if err != nil {
return nil, err return nilClaims, err
} }
payload, err := oidc.ParseToken(decrypted, claims) payload, err := oidc.ParseToken(decrypted, &claims)
if err != nil { if err != nil {
return nil, err return nilClaims, err
} }
if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { if err = oidc.CheckExpiration(claims, v.Offset()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil {
return nil, err return nilClaims, err
} }
if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil {
return nil, err return nilClaims, err
} }
return claims, nil return claims, nil
} }