diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 79ff704..1efdcf8 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,6 +9,16 @@ updates: commit-message: prefix: chore include: scope +- package-ecosystem: gomod + target-branch: "2.12.x" + directory: "/" + schedule: + interval: daily + time: '04:00' + open-pull-requests-limit: 10 + commit-message: + prefix: chore + include: scope - package-ecosystem: "github-actions" directory: "/" schedule: diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index a8106ae..27fa244 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -29,7 +29,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 # Override language selection by uncommenting this and choosing your languages with: languages: go @@ -37,7 +37,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 # â„šī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -51,4 +51,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/issue.yml b/.github/workflows/issue.yml index 362443d..62fd01d 100644 --- a/.github/workflows/issue.yml +++ b/.github/workflows/issue.yml @@ -4,15 +4,40 @@ on: issues: types: - opened + pull_request_target: + types: + - opened jobs: add-to-project: - name: Add issue to project + name: Add issue and community pr to project runs-on: ubuntu-latest steps: - - uses: actions/add-to-project@v0.5.0 + - name: add issue + uses: actions/add-to-project@v0.5.0 + if: ${{ github.event_name == 'issues' }} with: # You can target a repository in a different organization # to the issue project-url: https://github.com/orgs/zitadel/projects/2 github-token: ${{ secrets.ADD_TO_PROJECT_PAT }} + - uses: tspascoal/get-user-teams-membership@v3 + id: checkUserMember + if: github.actor != 'dependabot[bot]' + with: + username: ${{ github.actor }} + GITHUB_TOKEN: ${{ secrets.ADD_TO_PROJECT_PAT }} + - name: add pr + uses: actions/add-to-project@v0.5.0 + if: ${{ github.event_name == 'pull_request_target' && github.actor != 'dependabot[bot]' && !contains(steps.checkUserMember.outputs.teams, 'engineers')}} + with: + # You can target a repository in a different organization + # to the issue + project-url: https://github.com/orgs/zitadel/projects/2 + github-token: ${{ secrets.ADD_TO_PROJECT_PAT }} + - uses: actions-ecosystem/action-add-labels@v1.1.3 + if: ${{ github.event_name == 'pull_request_target' && github.actor != 'dependabot[bot]' && !contains(steps.checkUserMember.outputs.teams, 'staff')}} + with: + github_token: ${{ secrets.ADD_TO_PROJECT_PAT }} + labels: | + os-contribution diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 644b23f..6f92575 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -2,6 +2,7 @@ name: Release on: push: branches: + - "2.11.x" - main - next tags-ignore: @@ -22,11 +23,11 @@ jobs: steps: - uses: actions/checkout@v4 - name: Setup go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/... - - uses: codecov/codecov-action@v3.1.4 + - uses: codecov/codecov-action@v4.0.1 with: file: ./profile.cov name: codecov-go diff --git a/.releaserc.js b/.releaserc.js index e8eea8e..c87b1d1 100644 --- a/.releaserc.js +++ b/.releaserc.js @@ -1,5 +1,6 @@ module.exports = { branches: [ + {name: "2.11.x"}, {name: "main"}, {name: "next", prerelease: true}, ], diff --git a/README.md b/README.md index f9ec7ce..7f1a610 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid | Code Flow | yes | yes | OpenID Connect Core 1.0, [Section 3.1][1] | | Implicit Flow | no[^1] | yes | OpenID Connect Core 1.0, [Section 3.2][2] | | Hybrid Flow | no | not yet | OpenID Connect Core 1.0, [Section 3.3][3] | -| Client Credentials | not yet | yes | OpenID Connect Core 1.0, [Section 9][4] | +| Client Credentials | yes | yes | OpenID Connect Core 1.0, [Section 9][4] | | Refresh Token | yes | yes | OpenID Connect Core 1.0, [Section 12][5] | | Discovery | yes | yes | OpenID Connect [Discovery][6] 1.0 | | JWT Profile | yes | yes | [RFC 7523][7] | diff --git a/SECURITY.md b/SECURITY.md index d682630..a32b842 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,6 +1,6 @@ # Security Policy -At ZITADEL we are extremely grateful for security aware people that disclose vulnerabilities to us and the open source community. All reports will be investigated by our team. +Please refer to the security policy [on zitadel/zitadel](https://github.com/zitadel/zitadel/blob/main/SECURITY.md) which is applicable for all open source repositories of our organization. ## Supported Versions @@ -9,43 +9,12 @@ We currently support the following version of the OIDC framework: | Version | Supported | Branch | Details | | -------- | ------------------ | ----------- | ------------------------------------ | | 0.x.x | :x: | | not maintained | -| <1.13 | :x: | | not maintained | -| 1.13.x | :lock: :warning: | [1.13.x][1] | security only, [community effort][2] | -| 2.x.x | :heavy_check_mark: | [main][3] | supported | -| 3.0.0-xx | :white_check_mark: | [next][4] | [developement branch][5] | +| <2.11 | :x: | | not maintained | +| 2.11.x | :lock: :warning: | [2.11.x][1] | security only, [community effort][2] | +| 3.x.x | :heavy_check_mark: | [main][3] | supported | +| 4.0.0-xx | :white_check_mark: | [next][4] | [development branch] | -[1]: https://github.com/zitadel/oidc/tree/1.13.x -[2]: https://github.com/zitadel/oidc/discussions/378 +[1]: https://github.com/zitadel/oidc/tree/2.11.x +[2]: https://github.com/zitadel/oidc/discussions/458 [3]: https://github.com/zitadel/oidc/tree/main [4]: https://github.com/zitadel/oidc/tree/next -[5]: https://github.com/zitadel/oidc/milestone/2 - -## Reporting a vulnerability - -To file a incident, please disclose by email to security@zitadel.com with the security details. - -At the moment GPG encryption is no yet supported, however you may sign your message at will. - -### When should I report a vulnerability - -* You think you discovered a ... - * ... potential security vulnerability in the SDK - * ... vulnerability in another project that this SDK bases on -* For projects with their own vulnerability reporting and disclosure process, please report it directly there - -### When should I NOT report a vulnerability - -* You need help applying security related updates -* Your issue is not security related - -## Security Vulnerability Response - -TBD - -## Public Disclosure - -All accepted and mitigated vulnerabilities will be published on the [Github Security Page](https://github.com/zitadel/oidc/security/advisories) - -### Timing - -We think it is crucial to publish advisories `ASAP` as mitigations are ready. But due to the unknown nature of the disclosures the time frame can range from 7 to 90 days. diff --git a/example/client/device/device.go b/example/client/device/device.go index bea6134..78ed2c8 100644 --- a/example/client/device/device.go +++ b/example/client/device/device.go @@ -1,3 +1,37 @@ +// Command device is an example Oauth2 Device Authorization Grant app. +// It creates a new Device Authorization request on the Issuer and then polls for tokens. +// The user is then prompted to visit a URL and enter the user code. +// Or, the complete URL can be used instead to omit manual entry. +// In practice then can be a "magic link" in the form or a QR. +// +// The following environment variables are used for configuration: +// +// ISSUER: URL to the OP, required. +// CLIENT_ID: ID of the application, required. +// CLIENT_SECRET: Secret to authenticate the app using basic auth. Only required if the OP expects this type of authentication. +// KEY_PATH: Path to a private key file, used to for JWT authentication of the App. Only required if the OP expects this type of authentication. +// SCOPES: Scopes of the Authentication Request. Optional. +// +// Basic usage: +// +// cd example/client/device +// export ISSUER="http://localhost:9000" CLIENT_ID="246048465824634593@demo" +// +// Get an Access Token: +// +// SCOPES="email profile" go run . +// +// Get an Access Token and ID Token: +// +// SCOPES="email profile openid" go run . +// +// Get an Access Token and Refresh Token +// +// SCOPES="email profile offline_access" go run . +// +// Get Access, Refresh and ID Tokens: +// +// SCOPES="email profile offline_access openid" go run . package main import ( @@ -57,5 +91,5 @@ func main() { if err != nil { logrus.Fatal(err) } - logrus.Infof("successfully obtained token: %v", token) + logrus.Infof("successfully obtained token: %#v", token) } diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 74018da..baa2662 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -80,7 +80,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer handler := http.Handler(provider) if wrapServer { - handler = op.NewLegacyServer(provider, *op.DefaultEndpoints) + handler = op.RegisterLegacyServer(op.NewLegacyServer(provider, *op.DefaultEndpoints)) } // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go index 1a04f4f..b556828 100644 --- a/example/server/storage/storage.go +++ b/example/server/storage/storage.go @@ -90,6 +90,10 @@ func (s *publicKey) Key() any { } func NewStorage(userStore UserStore) *Storage { + return NewStorageWithClients(userStore, clients) +} + +func NewStorageWithClients(userStore UserStore, clients map[string]*Client) *Storage { key, _ := rsa.GenerateKey(rand.Reader, 2048) return &Storage{ authRequests: make(map[string]*AuthRequest), diff --git a/go.mod b/go.mod index d3245eb..d1c5f2b 100644 --- a/go.mod +++ b/go.mod @@ -3,38 +3,39 @@ module github.com/zitadel/oidc/v3 go 1.19 require ( - github.com/go-chi/chi/v5 v5.0.10 - github.com/go-jose/go-jose/v3 v3.0.0 + github.com/bmatcuk/doublestar/v4 v4.6.1 + github.com/go-chi/chi/v5 v5.0.12 + github.com/go-jose/go-jose/v3 v3.0.2 github.com/golang/mock v1.6.0 github.com/google/go-github/v31 v31.0.0 - github.com/google/uuid v1.3.1 - github.com/gorilla/securecookie v1.1.1 + github.com/google/uuid v1.6.0 + github.com/gorilla/securecookie v1.1.2 github.com/jeremija/gosubmit v0.2.7 github.com/muhlemmer/gu v0.3.1 github.com/muhlemmer/httpforwarded v0.1.0 github.com/rs/cors v1.10.1 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 - github.com/zitadel/logging v0.4.0 + github.com/zitadel/logging v0.5.0 github.com/zitadel/schema v1.3.0 - go.opentelemetry.io/otel v1.19.0 - go.opentelemetry.io/otel/trace v1.19.0 + go.opentelemetry.io/otel v1.24.0 + go.opentelemetry.io/otel/trace v1.24.0 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 - golang.org/x/oauth2 v0.13.0 - golang.org/x/text v0.13.0 + golang.org/x/oauth2 v0.17.0 + golang.org/x/text v0.14.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-logr/logr v1.2.4 // indirect + github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - go.opentelemetry.io/otel/metric v1.19.0 // indirect - golang.org/x/crypto v0.14.0 // indirect - golang.org/x/net v0.17.0 // indirect - golang.org/x/sys v0.13.0 // indirect + go.opentelemetry.io/otel/metric v1.24.0 // indirect + golang.org/x/crypto v0.19.0 // indirect + golang.org/x/net v0.21.0 // indirect + golang.org/x/sys v0.17.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.31.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index c57f8da..f84f80e 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,15 @@ +github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I= +github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-chi/chi/v5 v5.0.10 h1:rLz5avzKpjqxrYwXNfmjkrYYXOyLJd37pz53UFHC6vk= -github.com/go-chi/chi/v5 v5.0.10/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= -github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo= -github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= +github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= +github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-jose/go-jose/v3 v3.0.2 h1:2Edjn8Nrb44UvTdp84KU0bBPs1cO7noRCybtS3eJEUQ= +github.com/go-jose/go-jose/v3 v3.0.2/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= -github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -17,19 +19,20 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-github/v31 v31.0.0 h1:JJUxlP9lFK+ziXKimTCprajMApV1ecWD4NB6CCb0plo= github.com/google/go-github/v31 v31.0.0/go.mod h1:NQPZol8/1sMoWYGN2yaALIBytu17gAWfhbweiEed3pM= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= -github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= -github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc= github.com/jeremija/gosubmit v0.2.7/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= @@ -45,59 +48,81 @@ github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/zitadel/logging v0.4.0 h1:lRAIFgaRoJpLNbsL7jtIYHcMDoEJP9QZB4GqMfl4xaA= -github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnIrRelUFc= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zitadel/logging v0.5.0 h1:Kunouvqse/efXy4UDvFw5s3vP+Z4AlHo3y8wF7stXHA= +github.com/zitadel/logging v0.5.0/go.mod h1:IzP5fzwFhzzyxHkSmfF8dsyqFsQRJLLcQmwhIBzlGsE= github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= -go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs= -go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY= -go.opentelemetry.io/otel/metric v1.19.0 h1:aTzpGtV0ar9wlV4Sna9sdJyII5jTVJEvKETPiOKwvpE= -go.opentelemetry.io/otel/metric v1.19.0/go.mod h1:L5rUsV9kM1IxCj1MmSdS+JQAcVm319EUrDVLrt7jqt8= -go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg= -go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= +go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= +go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= +go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= +go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= +go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= +go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.13.0 h1:jDDenyj+WgFtmV3zYVoi8aE2BwtXFLWOA67ZfNWftiY= -golang.org/x/oauth2 v0.13.0/go.mod h1:/JMhi4ZRXAf4HG9LiNmxvk+45+96RUlVThiH8FzNBn0= +golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ= +golang.org/x/oauth2 v0.17.0/go.mod h1:OzPDGQiuQMguemayvdylqddI7qcD9lnSDb+1FiwQ5HA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index ec4d57b..ce77f5e 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -21,6 +21,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/slog" + "golang.org/x/oauth2" "github.com/zitadel/oidc/v3/example/server/exampleop" "github.com/zitadel/oidc/v3/example/server/storage" @@ -217,6 +218,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, targetURL, []string{"openid", "email", "profile", "offline_access"}, rp.WithPKCE(cookieHandler), + rp.WithAuthStyle(oauth2.AuthStyleInHeader), rp.WithVerifierOpts( rp.WithIssuedAtOffset(5*time.Second), rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"), @@ -323,6 +325,31 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, return provider, tokens } +func TestClientCredentials(t *testing.T) { + targetURL := "http://local-site" + exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) + var dh deferredHandler + opServer := httptest.NewServer(&dh) + defer opServer.Close() + t.Logf("auth server at %s", opServer.URL) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, true) + + provider, err := rp.NewRelyingPartyOIDC( + CTX, + opServer.URL, + "sid1", + "verysecret", + targetURL, + []string{"openid"}, + ) + require.NoError(t, err, "new rp") + + token, err := rp.ClientCredentials(CTX, provider, nil) + require.NoError(t, err, "ClientCredentials call") + require.NotNil(t, token) + assert.NotEmpty(t, token.AccessToken) +} + func TestErrorFromPromptNone(t *testing.T) { jar, err := cookiejar.New(nil) require.NoError(t, err, "create cookie jar") diff --git a/pkg/client/rp/errors.go b/pkg/client/rp/errors.go new file mode 100644 index 0000000..b95420b --- /dev/null +++ b/pkg/client/rp/errors.go @@ -0,0 +1,5 @@ +package rp + +import "errors" + +var ErrRelyingPartyNotSupportRevokeCaller = errors.New("RelyingParty does not support RevokeCaller") diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index c6ae2db..d4bc13c 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -4,16 +4,16 @@ import ( "context" "encoding/base64" "errors" - "fmt" "net/http" "net/url" "time" - jose "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3" "github.com/google/uuid" "github.com/zitadel/logging" "golang.org/x/exp/slog" "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" "github.com/zitadel/oidc/v3/pkg/client" httphelper "github.com/zitadel/oidc/v3/pkg/http" @@ -66,19 +66,28 @@ type RelyingParty interface { // IDTokenVerifier returns the verifier used for oidc id_token verification IDTokenVerifier() *IDTokenVerifier - // ErrorHandler returns the handler used for callback errors + // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) // Logger from the context, or a fallback if set. Logger(context.Context) (logger *slog.Logger, ok bool) } +type HasUnauthorizedHandler interface { + // UnauthorizedHandler returns the handler used for unauthorized errors + UnauthorizedHandler() func(w http.ResponseWriter, r *http.Request, desc string, state string) +} + type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) +type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string) var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) { http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError) } +var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter, r *http.Request, desc string, state string) { + http.Error(w, desc, http.StatusUnauthorized) +} type relyingParty struct { issuer string @@ -91,11 +100,14 @@ type relyingParty struct { httpClient *http.Client cookieHandler *httphelper.CookieHandler - errorHandler func(http.ResponseWriter, *http.Request, string, string, string) - idTokenVerifier *IDTokenVerifier - verifierOpts []VerifierOption - signer jose.Signer - logger *slog.Logger + oauthAuthStyle oauth2.AuthStyle + + errorHandler func(http.ResponseWriter, *http.Request, string, string, string) + unauthorizedHandler func(http.ResponseWriter, *http.Request, string, string) + idTokenVerifier *IDTokenVerifier + verifierOpts []VerifierOption + signer jose.Signer + logger *slog.Logger } func (rp *relyingParty) OAuthConfig() *oauth2.Config { @@ -156,6 +168,13 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, return rp.errorHandler } +func (rp *relyingParty) UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) { + if rp.unauthorizedHandler == nil { + rp.unauthorizedHandler = DefaultUnauthorizedHandler + } + return rp.unauthorizedHandler +} + func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) { logger, ok = logging.FromContext(ctx) if ok { @@ -169,9 +188,11 @@ func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok boo // it will use the AuthURL and TokenURL set in config func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingParty, error) { rp := &relyingParty{ - oauthConfig: config, - httpClient: httphelper.DefaultHTTPClient, - oauth2Only: true, + oauthConfig: config, + httpClient: httphelper.DefaultHTTPClient, + oauth2Only: true, + unauthorizedHandler: DefaultUnauthorizedHandler, + oauthAuthStyle: oauth2.AuthStyleAutoDetect, } for _, optFunc := range options { @@ -180,9 +201,12 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart } } + rp.oauthConfig.Endpoint.AuthStyle = rp.oauthAuthStyle + // avoid races by calling these early - _ = rp.IDTokenVerifier() // sets idTokenVerifier - _ = rp.ErrorHandler() // sets errorHandler + _ = rp.IDTokenVerifier() // sets idTokenVerifier + _ = rp.ErrorHandler() // sets errorHandler + _ = rp.UnauthorizedHandler() // sets unauthorizedHandler return rp, nil } @@ -199,8 +223,9 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re RedirectURL: redirectURI, Scopes: scopes, }, - httpClient: httphelper.DefaultHTTPClient, - oauth2Only: false, + httpClient: httphelper.DefaultHTTPClient, + oauth2Only: false, + oauthAuthStyle: oauth2.AuthStyleAutoDetect, } for _, optFunc := range options { @@ -217,9 +242,13 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re rp.oauthConfig.Endpoint = endpoints.Endpoint rp.endpoints = endpoints + rp.oauthConfig.Endpoint.AuthStyle = rp.oauthAuthStyle + rp.endpoints.Endpoint.AuthStyle = rp.oauthAuthStyle + // avoid races by calling these early - _ = rp.IDTokenVerifier() // sets idTokenVerifier - _ = rp.ErrorHandler() // sets errorHandler + _ = rp.IDTokenVerifier() // sets idTokenVerifier + _ = rp.ErrorHandler() // sets errorHandler + _ = rp.UnauthorizedHandler() // sets unauthorizedHandler return rp, nil } @@ -268,6 +297,20 @@ func WithErrorHandler(errorHandler ErrorHandler) Option { } } +func WithUnauthorizedHandler(unauthorizedHandler UnauthorizedHandler) Option { + return func(rp *relyingParty) error { + rp.unauthorizedHandler = unauthorizedHandler + return nil + } +} + +func WithAuthStyle(oauthAuthStyle oauth2.AuthStyle) Option { + return func(rp *relyingParty) error { + rp.oauthAuthStyle = oauthAuthStyle + return nil + } +} + func WithVerifierOpts(opts ...VerifierOption) Option { return func(rp *relyingParty) error { rp.verifierOpts = opts @@ -355,13 +398,13 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam state := stateFn() if err := trySetStateCookie(w, state, rp); err != nil { - http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp) return } if rp.IsPKCE() { codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp) if err != nil { - http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp) return } opts = append(opts, WithCodeChallenge(codeChallenge)) @@ -416,17 +459,39 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP return verifyTokenResponse[C](ctx, token, rp) } +// ClientCredentials requests an access token using the `client_credentials` grant, +// as defined in [RFC 6749, section 4.4]. +// +// As there is no user associated to the request an ID Token can never be returned. +// Client Credentials are undefined in OpenID Connect and is a pure OAuth2 grant. +// Furthermore the server SHOULD NOT return a refresh token. +// +// [RFC 6749, section 4.4]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 +func ClientCredentials(ctx context.Context, rp RelyingParty, endpointParams url.Values) (token *oauth2.Token, err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "ClientCredentials") + ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) + config := clientcredentials.Config{ + ClientID: rp.OAuthConfig().ClientID, + ClientSecret: rp.OAuthConfig().ClientSecret, + TokenURL: rp.OAuthConfig().Endpoint.TokenURL, + Scopes: rp.OAuthConfig().Scopes, + EndpointParams: endpointParams, + AuthStyle: rp.OAuthConfig().Endpoint.AuthStyle, + } + return config.Token(ctx) +} + type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) // CodeExchangeHandler extends the `CodeExchange` method with a http handler // including cookie handling for secure `state` transfer // and optional PKCE code verifier checking. -// Custom paramaters can optionally be set to the token URL. +// Custom parameters can optionally be set to the token URL. func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { state, err := tryReadStateCookie(w, r, rp) if err != nil { - http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp) return } params := r.URL.Query() @@ -442,7 +507,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R if rp.IsPKCE() { codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode) if err != nil { - http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to get code verifier: "+err.Error(), state, rp) return } codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier)) @@ -451,14 +516,14 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R if rp.Signer() != nil { assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer()) if err != nil { - http.Error(w, "failed to build assertion: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to build assertion: "+err.Error(), state, rp) return } codeOpts = append(codeOpts, WithClientAssertionJWT(assertion)) } tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...) if err != nil { - http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp) return } callback(w, r, tokens, state, rp) @@ -478,7 +543,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { - http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp) return } f(w, r, tokens, state, rp, info) @@ -545,9 +610,8 @@ type Endpoints struct { func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { return Endpoints{ Endpoint: oauth2.Endpoint{ - AuthURL: discoveryConfig.AuthorizationEndpoint, - AuthStyle: oauth2.AuthStyleAutoDetect, - TokenURL: discoveryConfig.TokenEndpoint, + AuthURL: discoveryConfig.AuthorizationEndpoint, + TokenURL: discoveryConfig.TokenEndpoint, }, IntrospectURL: discoveryConfig.IntrospectionEndpoint, UserinfoURL: discoveryConfig.UserinfoEndpoint, @@ -703,5 +767,13 @@ func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHi if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" { return client.CallRevokeEndpoint(ctx, request, nil, rc) } - return fmt.Errorf("RelyingParty does not support RevokeCaller") + return ErrRelyingPartyNotSupportRevokeCaller +} + +func unauthorizedError(w http.ResponseWriter, r *http.Request, desc string, state string, rp RelyingParty) { + if rp, ok := rp.(HasUnauthorizedHandler); ok { + rp.UnauthorizedHandler()(w, r, desc, state) + return + } + http.Error(w, desc, http.StatusUnauthorized) } diff --git a/pkg/crypto/key.go b/pkg/crypto/key.go index d75d1ab..79e2046 100644 --- a/pkg/crypto/key.go +++ b/pkg/crypto/key.go @@ -4,14 +4,19 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" ) -func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) { - block, _ := pem.Decode(priv) - b := block.Bytes - key, err := x509.ParsePKCS1PrivateKey(b) +func BytesToPrivateKey(b []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(b) + if block == nil { + return nil, errors.New("PEM decode failed") + } + + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return nil, err } + return key, nil } diff --git a/pkg/crypto/key_test.go b/pkg/crypto/key_test.go new file mode 100644 index 0000000..23ebdc0 --- /dev/null +++ b/pkg/crypto/key_test.go @@ -0,0 +1,62 @@ +package crypto_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/zitadel/oidc/v3/pkg/crypto" +) + +func TestBytesToPrivateKey(tt *testing.T) { + tt.Run("PEMDecodeError", func(t *testing.T) { + _, err := crypto.BytesToPrivateKey([]byte("The non-PEM sequence")) + assert.EqualError(t, err, "PEM decode failed") + }) + + tt.Run("InvalidKeyFormat", func(t *testing.T) { + _, err := crypto.BytesToPrivateKey([]byte(`-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCfaDB7pK/fmP/I +7IusSK8lTCBnPZghqIbVLt2QHYAMoEF1CaF4F4rxo2vl1Mt8gwsq4T3osQFZMvnL +YHb7KNyUoJgTjLxJQADv2u4Q3U38heAzK5Tp4ry4MCnuyJIqAPK1GiruwEq4zQrx ++WzVix8otO37SuW9tzklqlNGMiAYBL0TBKHvS5XMbjP1idBMB8erMz29w/TVQnEB +Kj0vCdZjrbVPKygptt5kcSrL5f4xCZwU+ufz7cp0GLwpRMJ+shG9YJJFBxb0itPF +sy51vAyEtdBC7jgAU96ZVeQ06nryDq1D2EpoVMElqNyL46Jo3lnKbGquGKzXzQYU +BN32/scDAgMBAAECggEBAJE/mo3PLgILo2YtQ8ekIxNVHmF0Gl7w9IrjvTdH6hmX +HI3MTLjkmtI7GmG9V/0IWvCjdInGX3grnrjWGRQZ04QKIQgPQLFuBGyJjEsJm7nx +MqztlS7YTyV1nX/aenSTkJO8WEpcJLnm+4YoxCaAMdAhrIdBY71OamALpv1bRysa +FaiCGcemT2yqZn0GqIS8O26Tz5zIqrTN2G1eSmgh7DG+7FoddMz35cute8R10xUG +hF5YU+6fcXiRQ/Kh7nlxelPGqdZFPMk7LpVHzkQKwdJ+N0P23lPDIfNsvpG1n0OP +3g5km7gHSrSU2yZ3eFl6DB9x1IFNS9BaQQuSxYJtKwECgYEA1C8jjzpXZDLvlYsV +2jlMzkrbsIrX2dzblVrNsPs2jRbjYU8mg2DUDO6lOhtxHfqZG6sO+gmWi/zvoy9l +yolGbXe1Jqx66p9fznIcecSwar8+ACa356Wk74Nt1PlBOfCMqaJnYLOLaFJa29Vy +u5ClZVzKd5AVXl7yFVd4XfLv/WECgYEAwFMMtFoasdF92c0d31rZ1uoPOtFz6xq6 +uQggdm5zzkhnfwUAGqppS/u1CHcJ7T/74++jLbFTsaohGr4jEzWSGvJpomEUChy3 +r25YofMclUhJ5pCEStsLtqiCR1Am6LlI8HMdBEP1QDgEC5q8bQW4+UHuew1E1zxz +osZOhe09WuMCgYEA0G9aFCnwjUqIFjQiDFP7gi8BLqTFs4uE3Wvs4W11whV42i+B +ms90nxuTjchFT3jMDOT1+mOO0wdudLRr3xEI8SIF/u6ydGaJG+j21huEXehtxIJE +aDdNFcfbDbqo+3y1ATK7MMBPMvSrsoY0hdJq127WqasNgr3sO1DIuima3SECgYEA +nkM5TyhekzlbIOHD1UsDu/D7+2DkzPE/+oePfyXBMl0unb3VqhvVbmuBO6gJiSx/ +8b//PdiQkMD5YPJaFrKcuoQFHVRZk0CyfzCEyzAts0K7XXpLAvZiGztriZeRjSz7 +srJnjF0H8oKmAY6hw+1Tm/n/b08p+RyL48TgVSE2vhUCgYA3BWpkD4PlCcn/FZsq +OrLFyFXI6jIaxskFtsRW1IxxIlAdZmxfB26P/2gx6VjLdxJI/RRPkJyEN2dP7CbR +BDjb565dy1O9D6+UrY70Iuwjz+OcALRBBGTaiF2pLn6IhSzNI2sy/tXX8q8dBlg9 +OFCrqT/emes3KytTPfa5NZtYeQ== +-----END PRIVATE KEY-----`)) + assert.EqualError(t, err, "x509: failed to parse private key (use ParsePKCS8PrivateKey instead for this key format)") + }) + + tt.Run("Ok", func(t *testing.T) { + key, err := crypto.BytesToPrivateKey([]byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu +KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm +o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k +TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7 +9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy +v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs +/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00 +-----END RSA PRIVATE KEY-----`)) + assert.NoError(t, err) + assert.NotNil(t, key) + }) +} diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index d8372b8..0e7152c 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -3,6 +3,7 @@ package oidc import ( "database/sql/driver" "encoding/json" + "errors" "fmt" "reflect" "strings" @@ -76,8 +77,23 @@ func (l *Locale) MarshalJSON() ([]byte, error) { return json.Marshal(tag) } +// UnmarshalJSON implements json.Unmarshaler. +// When [language.ValueError] is encountered, the containing tag will be set +// to an empty value (language "und") and no error will be returned. +// This state can be checked with the `l.Tag().IsRoot()` method. func (l *Locale) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &l.tag) + err := json.Unmarshal(data, &l.tag) + if err == nil { + return nil + } + + // catch "well-formed but unknown" errors + var target language.ValueError + if errors.As(err, &target) { + l.tag = language.Tag{} + return nil + } + return err } type Locales []language.Tag diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index af4f113..df93a73 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -208,20 +208,46 @@ func TestLocale_MarshalJSON(t *testing.T) { } func TestLocale_UnmarshalJSON(t *testing.T) { - type a struct { + type dst struct { Locale *Locale `json:"locale,omitempty"` } - want := a{ - Locale: NewLocale(language.Afrikaans), + tests := []struct { + name string + input string + want dst + wantErr bool + }{ + { + name: "afrikaans, ok", + input: `{"locale": "af"}`, + want: dst{ + Locale: NewLocale(language.Afrikaans), + }, + }, + { + name: "gb, ignored", + input: `{"locale": "gb"}`, + want: dst{ + Locale: &Locale{}, + }, + }, + { + name: "bad form, error", + input: `{"locale": "g!!!!!"}`, + wantErr: true, + }, } - const input = `{"locale": "af"}` - var got a - - require.NoError(t, - json.Unmarshal([]byte(input), &got), - ) - assert.Equal(t, want, got) + for _, tt := range tests { + var got dst + err := json.Unmarshal([]byte(tt.input), &got) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } } func TestParseLocales(t *testing.T) { diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index 42fbb20..fe28857 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -57,7 +57,7 @@ var ( ErrNonceInvalid = errors.New("nonce does not match") ErrAcrInvalid = errors.New("acr is invalid") ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing") - ErrAuthTimeToOld = errors.New("auth time of token is to old") + ErrAuthTimeToOld = errors.New("auth time of token is too old") ErrAtHash = errors.New("at_hash does not correspond to access token") ) diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 7ef06a8..7058ebc 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -7,10 +7,10 @@ import ( "net" "net/http" "net/url" - "path" "strings" "time" + "github.com/bmatcuk/doublestar/v4" httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" str "github.com/zitadel/oidc/v3/pkg/strings" @@ -138,20 +138,20 @@ func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage } if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { - return oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("missing or wrong client id in request") } if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { - return oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("missing or wrong response type in request") } if requestObject.Issuer != requestObject.ClientID { - return oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("missing or wrong issuer in request") } if !str.Contains(requestObject.Audience, issuer) { - return oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("issuer missing in audience") } keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer} if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { - return err + return oidc.ErrInvalidRequest().WithParent(err).WithDescription(err.Error()) } CopyRequestObjectToAuthRequest(authReq, requestObject) return nil @@ -283,7 +283,7 @@ func checkURIAgainstRedirects(client Client, uri string) error { } if globClient, ok := client.(HasRedirectGlobs); ok { for _, uriGlob := range globClient.RedirectURIGlobs() { - isMatch, err := path.Match(uriGlob, uri) + isMatch, err := doublestar.Match(uriGlob, uri) if err != nil { return oidc.ErrServerError().WithParent(err) } @@ -391,9 +391,9 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie return "", nil } claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, idTokenHint, verifier) - if err != nil { + if err != nil && !errors.As(err, &IDTokenHintExpiredError{}) { return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " + - "If you have any questions, you may contact the administrator of the application.") + "If you have any questions, you may contact the administrator of the application.").WithParent(err) } return claims.GetSubject(), nil } diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index db70fd7..18880f0 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -583,6 +583,60 @@ func TestValidateAuthReqRedirectURI(t *testing.T) { }, false, }, + { + "code flow dev mode has redirect globs regular ok", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://registered.com/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs wildcard ok", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://registered.com/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs double star ok", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs double star ok", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs IPv6 ok", + args{ + "http://[::1]:80/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://\\[::1\\]:80/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs bad pattern", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/\\"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/op/client.go b/pkg/op/client.go index 04ef3c7..0574afa 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -63,6 +63,7 @@ type Client interface { // such as DevMode for the client being enabled. // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest type HasRedirectGlobs interface { + Client RedirectURIGlobs() []string PostLogoutRedirectURIGlobs() []string } diff --git a/pkg/op/config.go b/pkg/op/config.go index c383480..9fec7cc 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -54,7 +54,24 @@ type Configuration interface { type IssuerFromRequest func(r *http.Request) string func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) { - return issuerFromForwardedOrHost(path, false) + return issuerFromForwardedOrHost(path, new(issuerConfig)) +} + +type IssuerFromOption func(c *issuerConfig) + +// WithIssuerFromCustomHeaders can be used to customize the header names used. +// The same rules apply where the first successful host is returned. +func WithIssuerFromCustomHeaders(headers ...string) IssuerFromOption { + return func(c *issuerConfig) { + for i, h := range headers { + headers[i] = http.CanonicalHeaderKey(h) + } + c.headers = headers + } +} + +type issuerConfig struct { + headers []string } // IssuerFromForwardedOrHost tries to establish the Issuer based @@ -64,11 +81,18 @@ func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) { // If the Forwarded header is not present, no host field is found, // or there is a parser error the Request Host will be used as a fallback. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Forwarded -func IssuerFromForwardedOrHost(path string) func(bool) (IssuerFromRequest, error) { - return issuerFromForwardedOrHost(path, true) +func IssuerFromForwardedOrHost(path string, opts ...IssuerFromOption) func(bool) (IssuerFromRequest, error) { + c := &issuerConfig{ + headers: []string{http.CanonicalHeaderKey("forwarded")}, + } + for _, opt := range opts { + opt(c) + } + + return issuerFromForwardedOrHost(path, c) } -func issuerFromForwardedOrHost(path string, parseForwarded bool) func(bool) (IssuerFromRequest, error) { +func issuerFromForwardedOrHost(path string, c *issuerConfig) func(bool) (IssuerFromRequest, error) { return func(allowInsecure bool) (IssuerFromRequest, error) { issuerPath, err := url.Parse(path) if err != nil { @@ -78,26 +102,26 @@ func issuerFromForwardedOrHost(path string, parseForwarded bool) func(bool) (Iss return nil, err } return func(r *http.Request) string { - if parseForwarded { - if host, ok := hostFromForwarded(r); ok { - return dynamicIssuer(host, path, allowInsecure) - } + if host, ok := hostFromForwarded(r, c.headers); ok { + return dynamicIssuer(host, path, allowInsecure) } return dynamicIssuer(r.Host, path, allowInsecure) }, nil } } -func hostFromForwarded(r *http.Request) (host string, ok bool) { - fwd, err := httpforwarded.ParseFromRequest(r) - if err != nil { - log.Printf("Err: issuer from forwarded header: %v", err) // TODO change to slog on next branch - return "", false +func hostFromForwarded(r *http.Request, headers []string) (host string, ok bool) { + for _, header := range headers { + hosts, err := httpforwarded.ParseParameter("host", r.Header[header]) + if err != nil { + log.Printf("Err: issuer from forwarded header: %v", err) // TODO change to slog on next branch + continue + } + if len(hosts) > 0 { + return hosts[0], true + } } - if fwd == nil || len(fwd["host"]) == 0 { - return "", false - } - return fwd["host"][0], true + return "", false } func StaticIssuer(issuer string) func(bool) (IssuerFromRequest, error) { diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go index dcafc3a..d739348 100644 --- a/pkg/op/config_test.go +++ b/pkg/op/config_test.go @@ -1,6 +1,7 @@ package op import ( + "net/http" "net/http/httptest" "net/url" "testing" @@ -264,9 +265,10 @@ func TestIssuerFromHost(t *testing.T) { func TestIssuerFromForwardedOrHost(t *testing.T) { type args struct { - path string - target string - forwarded []string + path string + opts []IssuerFromOption + target string + header map[string][]string } type res struct { issuer string @@ -279,9 +281,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { { "header parse error", args{ - path: "/custom/", - target: "https://issuer.com", - forwarded: []string{"~~~"}, + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{"Forwarded": {"~~~~"}}, }, res{ issuer: "https://issuer.com/custom/", @@ -303,9 +305,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { args{ path: "/custom/", target: "https://issuer.com", - forwarded: []string{ + header: map[string][]string{"Forwarded": { `by=identifier;for=identifier;proto=https`, - }, + }}, }, res{ issuer: "https://issuer.com/custom/", @@ -316,9 +318,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { args{ path: "/custom/", target: "https://issuer.com", - forwarded: []string{ + header: map[string][]string{"Forwarded": { `by=identifier;for=identifier;host=first.com;proto=https`, - }, + }}, }, res{ issuer: "https://first.com/custom/", @@ -329,9 +331,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { args{ path: "/custom/", target: "https://issuer.com", - forwarded: []string{ + header: map[string][]string{"Forwarded": { `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`, - }, + }}, }, res{ issuer: "https://first.com/custom/", @@ -342,23 +344,45 @@ func TestIssuerFromForwardedOrHost(t *testing.T) { args{ path: "/custom/", target: "https://issuer.com", - forwarded: []string{ + header: map[string][]string{"Forwarded": { `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`, `by=identifier;for=identifier;host=third.com;proto=https`, - }, + }}, }, res{ issuer: "https://first.com/custom/", }, }, + { + "custom header first", + args{ + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{ + "Forwarded": { + `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`, + `by=identifier;for=identifier;host=third.com;proto=https`, + }, + "X-Custom-Forwarded": { + `by=identifier;for=identifier;host=custom.com;proto=https,host=custom2.com`, + }, + }, + opts: []IssuerFromOption{ + WithIssuerFromCustomHeaders("x-custom-forwarded"), + }, + }, + res{ + issuer: "https://custom.com/custom/", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - issuer, err := IssuerFromForwardedOrHost(tt.args.path)(false) + issuer, err := IssuerFromForwardedOrHost(tt.args.path, tt.args.opts...)(false) require.NoError(t, err) req := httptest.NewRequest("", tt.args.target, nil) - if tt.args.forwarded != nil { - req.Header["Forwarded"] = tt.args.forwarded + for k, v := range tt.args.header { + req.Header[http.CanonicalHeaderKey(k)] = v } assert.Equal(t, tt.res.issuer, issuer(req)) }) diff --git a/pkg/op/device.go b/pkg/op/device.go index 813c3f5..1b86d04 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -14,6 +14,7 @@ import ( httphelper "github.com/zitadel/oidc/v3/pkg/http" "github.com/zitadel/oidc/v3/pkg/oidc" + strs "github.com/zitadel/oidc/v3/pkg/strings" ) type DeviceAuthorizationConfig struct { @@ -185,24 +186,6 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) { return buf.String(), nil } -type deviceAccessTokenRequest struct { - subject string - audience []string - scopes []string -} - -func (r *deviceAccessTokenRequest) GetSubject() string { - return r.subject -} - -func (r *deviceAccessTokenRequest) GetAudience() []string { - return r.audience -} - -func (r *deviceAccessTokenRequest) GetScopes() []string { - return r.scopes -} - func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { ctx, span := tracer.Start(r.Context(), "DeviceAccessToken") defer span.End() @@ -229,7 +212,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang if err != nil { return err } - state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) + tokenRequest, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) if err != nil { return err } @@ -243,11 +226,6 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang WithDescription("confidential client requires authentication") } - tokenRequest := &deviceAccessTokenRequest{ - subject: state.Subject, - audience: []string{clientID}, - scopes: state.Scopes, - } resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client) if err != nil { return err @@ -265,6 +243,50 @@ func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc. return req, nil } +// DeviceAuthorizationState describes the current state of +// the device authorization flow. +// It implements the [IDTokenRequest] interface. +type DeviceAuthorizationState struct { + ClientID string + Audience []string + Scopes []string + Expires time.Time // The time after we consider the authorization request timed-out + Done bool // The user authenticated and approved the authorization request + Denied bool // The user authenticated and denied the authorization request + + // The following fields are populated after Done == true + Subject string + AMR []string + AuthTime time.Time +} + +func (r *DeviceAuthorizationState) GetAMR() []string { + return r.AMR +} + +func (r *DeviceAuthorizationState) GetAudience() []string { + if !strs.Contains(r.Audience, r.ClientID) { + r.Audience = append(r.Audience, r.ClientID) + } + return r.Audience +} + +func (r *DeviceAuthorizationState) GetAuthTime() time.Time { + return r.AuthTime +} + +func (r *DeviceAuthorizationState) GetClientID() string { + return r.ClientID +} + +func (r *DeviceAuthorizationState) GetScopes() []string { + return r.Scopes +} + +func (r *DeviceAuthorizationState) GetSubject() string { + return r.Subject +} + func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) { storage, err := assertDeviceStorage(exchanger.Storage()) if err != nil { @@ -291,15 +313,32 @@ func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode str } func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) { + /* TODO(v4): + Change the TokenRequest argument type to *DeviceAuthorizationState. + Breaking change that can not be done for v3. + */ + ctx, span := tracer.Start(ctx, "CreateDeviceTokenResponse") + defer span.End() + accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "") if err != nil { return nil, err } - return &oidc.AccessTokenResponse{ + response := &oidc.AccessTokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: oidc.BearerToken, ExpiresIn: uint64(validity.Seconds()), - }, nil + } + + // TODO(v4): remove type assertion + if idTokenRequest, ok := tokenRequest.(IDTokenRequest); ok && strs.Contains(tokenRequest.GetScopes(), oidc.ScopeOpenID) { + response.IDToken, err = CreateIDToken(ctx, IssuerFromContext(ctx), idTokenRequest, client.IDTokenLifetime(), accessToken, "", creator.Storage(), client) + if err != nil { + return nil, err + } + } + + return response, nil } diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index f5452f9..570b943 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -453,3 +453,96 @@ func TestCheckDeviceAuthorizationState(t *testing.T) { }) } } + +func TestCreateDeviceTokenResponse(t *testing.T) { + tests := []struct { + name string + tokenRequest op.TokenRequest + wantAccessToken bool + wantRefreshToken bool + wantIDToken bool + wantErr bool + }{ + { + name: "access token", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "id1", + AMR: []string{"password"}, + AuthTime: time.Now(), + }, + wantAccessToken: true, + }, + { + name: "access and refresh tokens", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "id1", + AMR: []string{"password"}, + AuthTime: time.Now(), + Scopes: []string{oidc.ScopeOfflineAccess}, + }, + wantAccessToken: true, + wantRefreshToken: true, + }, + { + name: "access and id token", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "id1", + AMR: []string{"password"}, + AuthTime: time.Now(), + Scopes: []string{oidc.ScopeOpenID}, + }, + wantAccessToken: true, + wantIDToken: true, + }, + { + name: "access, refresh and id token", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "id1", + AMR: []string{"password"}, + AuthTime: time.Now(), + Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID}, + }, + wantAccessToken: true, + wantRefreshToken: true, + wantIDToken: true, + }, + { + name: "id token creation error", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "foobar", + AMR: []string{"password"}, + AuthTime: time.Now(), + Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := testProvider.Storage().GetClientByClientID(context.Background(), "native") + require.NoError(t, err) + + got, err := op.CreateDeviceTokenResponse(context.Background(), tt.tokenRequest, testProvider, client) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.InDelta(t, 300, got.ExpiresIn, 2) + if tt.wantAccessToken { + assert.NotEmpty(t, got.AccessToken, "access token") + } + if tt.wantRefreshToken { + assert.NotEmpty(t, got.RefreshToken, "refresh token") + } + if tt.wantIDToken { + assert.NotEmpty(t, got.IDToken, "id token") + } + }) + } +} diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 8251261..6af1674 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -213,32 +213,12 @@ func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod { } func SupportedClaims(c Configuration) []string { - return []string{ // TODO: config - "sub", - "aud", - "exp", - "iat", - "iss", - "auth_time", - "nonce", - "acr", - "amr", - "c_hash", - "at_hash", - "act", - "scopes", - "client_id", - "azp", - "preferred_username", - "name", - "family_name", - "given_name", - "locale", - "email", - "email_verified", - "phone_number", - "phone_number_verified", + provider, ok := c.(*Provider) + if ok && provider.config.SupportedClaims != nil { + return provider.config.SupportedClaims } + + return DefaultSupportedClaims } func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod { diff --git a/pkg/op/error.go b/pkg/op/error.go index 0cac14b..e4580f6 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -157,13 +157,29 @@ func (e StatusError) Is(err error) bool { e.statusCode == target.statusCode } -// WriteError asserts for a StatusError containing an [oidc.Error]. -// If no StatusError is found, the status code will default to [http.StatusBadRequest]. -// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError]. +// WriteError asserts for a [StatusError] containing an [oidc.Error]. +// If no `StatusError` is found, the status code will default to [http.StatusBadRequest]. +// If no `oidc.Error` was found in the parent, the error type defaults to [oidc.ServerError]. +// When there was no `StatusError` and the `oidc.Error` is of type `oidc.ServerError`, +// the status code will be set to [http.StatusInternalServerError] func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { - statusError := AsStatusError(err, http.StatusBadRequest) - e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()) - - logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) - httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode) + var statusError StatusError + if errors.As(err, &statusError) { + writeError(w, r, + oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()), + statusError.statusCode, logger, + ) + return + } + statusCode := http.StatusBadRequest + e := oidc.DefaultToServerError(err, err.Error()) + if e.ErrorType == oidc.ServerError { + statusCode = http.StatusInternalServerError + } + writeError(w, r, e, statusCode, logger) +} + +func writeError(w http.ResponseWriter, r *http.Request, err *oidc.Error, statusCode int, logger *slog.Logger) { + logger.Log(r.Context(), err.LogLevel(), "request error", "oidc_error", err, "status_code", statusCode) + httphelper.MarshalJSONWithStatus(w, err, statusCode) } diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go index 689ee5a..50a9cbf 100644 --- a/pkg/op/error_test.go +++ b/pkg/op/error_test.go @@ -579,7 +579,7 @@ func TestWriteError(t *testing.T) { { name: "not a status or oidc error", err: io.ErrClosedPipe, - wantStatus: http.StatusBadRequest, + wantStatus: http.StatusInternalServerError, wantBody: `{ "error":"server_error", "error_description":"io: read/write on closed pipe" @@ -592,6 +592,7 @@ func TestWriteError(t *testing.T) { "parent":"io: read/write on closed pipe", "type":"server_error" }, + "status_code":500, "time":"not" }`, }, @@ -611,6 +612,7 @@ func TestWriteError(t *testing.T) { "parent":"io: read/write on closed pipe", "type":"server_error" }, + "status_code":500, "time":"not" }`, }, @@ -629,6 +631,7 @@ func TestWriteError(t *testing.T) { "description":"oops", "type":"invalid_request" }, + "status_code":400, "time":"not" }`, }, @@ -650,6 +653,7 @@ func TestWriteError(t *testing.T) { "description":"oops", "type":"unauthorized_client" }, + "status_code":401, "time":"not" }`, }, diff --git a/pkg/op/mock/generate.go b/pkg/op/mock/generate.go index 590356c..e5cab3e 100644 --- a/pkg/op/mock/generate.go +++ b/pkg/op/mock/generate.go @@ -4,6 +4,7 @@ package mock //go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v3/pkg/op Storage //go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v3/pkg/op Authorizer //go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v3/pkg/op Client +//go:generate mockgen -package mock -destination ./glob.mock.go github.com/zitadel/oidc/v3/pkg/op HasRedirectGlobs //go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v3/pkg/op Configuration //go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v3/pkg/op DiscoverStorage //go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v3/pkg/op SigningKey,Key diff --git a/pkg/op/mock/glob.go b/pkg/op/mock/glob.go new file mode 100644 index 0000000..cade476 --- /dev/null +++ b/pkg/op/mock/glob.go @@ -0,0 +1,24 @@ +package mock + +import ( + "testing" + + gomock "github.com/golang/mock/gomock" + "github.com/zitadel/oidc/v3/pkg/oidc" + op "github.com/zitadel/oidc/v3/pkg/op" +) + +func NewHasRedirectGlobs(t *testing.T) op.HasRedirectGlobs { + return NewMockHasRedirectGlobs(gomock.NewController(t)) +} + +func NewHasRedirectGlobsWithConfig(t *testing.T, uri []string, appType op.ApplicationType, responseTypes []oidc.ResponseType, devMode bool) op.HasRedirectGlobs { + c := NewHasRedirectGlobs(t) + m := c.(*MockHasRedirectGlobs) + m.EXPECT().RedirectURIs().AnyTimes().Return(uri) + m.EXPECT().RedirectURIGlobs().AnyTimes().Return(uri) + m.EXPECT().ApplicationType().AnyTimes().Return(appType) + m.EXPECT().ResponseTypes().AnyTimes().Return(responseTypes) + m.EXPECT().DevMode().AnyTimes().Return(devMode) + return c +} diff --git a/pkg/op/mock/glob.mock.go b/pkg/op/mock/glob.mock.go new file mode 100644 index 0000000..cf9996e --- /dev/null +++ b/pkg/op/mock/glob.mock.go @@ -0,0 +1,289 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: HasRedirectGlobs) + +// Package mock is a generated GoMock package. +package mock + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + oidc "github.com/zitadel/oidc/v3/pkg/oidc" + op "github.com/zitadel/oidc/v3/pkg/op" +) + +// MockHasRedirectGlobs is a mock of HasRedirectGlobs interface. +type MockHasRedirectGlobs struct { + ctrl *gomock.Controller + recorder *MockHasRedirectGlobsMockRecorder +} + +// MockHasRedirectGlobsMockRecorder is the mock recorder for MockHasRedirectGlobs. +type MockHasRedirectGlobsMockRecorder struct { + mock *MockHasRedirectGlobs +} + +// NewMockHasRedirectGlobs creates a new mock instance. +func NewMockHasRedirectGlobs(ctrl *gomock.Controller) *MockHasRedirectGlobs { + mock := &MockHasRedirectGlobs{ctrl: ctrl} + mock.recorder = &MockHasRedirectGlobsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHasRedirectGlobs) EXPECT() *MockHasRedirectGlobsMockRecorder { + return m.recorder +} + +// AccessTokenType mocks base method. +func (m *MockHasRedirectGlobs) AccessTokenType() op.AccessTokenType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AccessTokenType") + ret0, _ := ret[0].(op.AccessTokenType) + return ret0 +} + +// AccessTokenType indicates an expected call of AccessTokenType. +func (mr *MockHasRedirectGlobsMockRecorder) AccessTokenType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockHasRedirectGlobs)(nil).AccessTokenType)) +} + +// ApplicationType mocks base method. +func (m *MockHasRedirectGlobs) ApplicationType() op.ApplicationType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplicationType") + ret0, _ := ret[0].(op.ApplicationType) + return ret0 +} + +// ApplicationType indicates an expected call of ApplicationType. +func (mr *MockHasRedirectGlobsMockRecorder) ApplicationType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ApplicationType)) +} + +// AuthMethod mocks base method. +func (m *MockHasRedirectGlobs) AuthMethod() oidc.AuthMethod { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AuthMethod") + ret0, _ := ret[0].(oidc.AuthMethod) + return ret0 +} + +// AuthMethod indicates an expected call of AuthMethod. +func (mr *MockHasRedirectGlobsMockRecorder) AuthMethod() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethod", reflect.TypeOf((*MockHasRedirectGlobs)(nil).AuthMethod)) +} + +// ClockSkew mocks base method. +func (m *MockHasRedirectGlobs) ClockSkew() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClockSkew") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// ClockSkew indicates an expected call of ClockSkew. +func (mr *MockHasRedirectGlobsMockRecorder) ClockSkew() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClockSkew", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ClockSkew)) +} + +// DevMode mocks base method. +func (m *MockHasRedirectGlobs) DevMode() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DevMode") + ret0, _ := ret[0].(bool) + return ret0 +} + +// DevMode indicates an expected call of DevMode. +func (mr *MockHasRedirectGlobsMockRecorder) DevMode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DevMode", reflect.TypeOf((*MockHasRedirectGlobs)(nil).DevMode)) +} + +// GetID mocks base method. +func (m *MockHasRedirectGlobs) GetID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetID indicates an expected call of GetID. +func (mr *MockHasRedirectGlobsMockRecorder) GetID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockHasRedirectGlobs)(nil).GetID)) +} + +// GrantTypes mocks base method. +func (m *MockHasRedirectGlobs) GrantTypes() []oidc.GrantType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GrantTypes") + ret0, _ := ret[0].([]oidc.GrantType) + return ret0 +} + +// GrantTypes indicates an expected call of GrantTypes. +func (mr *MockHasRedirectGlobsMockRecorder) GrantTypes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).GrantTypes)) +} + +// IDTokenLifetime mocks base method. +func (m *MockHasRedirectGlobs) IDTokenLifetime() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IDTokenLifetime") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// IDTokenLifetime indicates an expected call of IDTokenLifetime. +func (mr *MockHasRedirectGlobsMockRecorder) IDTokenLifetime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IDTokenLifetime)) +} + +// IDTokenUserinfoClaimsAssertion mocks base method. +func (m *MockHasRedirectGlobs) IDTokenUserinfoClaimsAssertion() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IDTokenUserinfoClaimsAssertion") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IDTokenUserinfoClaimsAssertion indicates an expected call of IDTokenUserinfoClaimsAssertion. +func (mr *MockHasRedirectGlobsMockRecorder) IDTokenUserinfoClaimsAssertion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenUserinfoClaimsAssertion", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IDTokenUserinfoClaimsAssertion)) +} + +// IsScopeAllowed mocks base method. +func (m *MockHasRedirectGlobs) IsScopeAllowed(arg0 string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsScopeAllowed", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsScopeAllowed indicates an expected call of IsScopeAllowed. +func (mr *MockHasRedirectGlobsMockRecorder) IsScopeAllowed(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsScopeAllowed", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IsScopeAllowed), arg0) +} + +// LoginURL mocks base method. +func (m *MockHasRedirectGlobs) LoginURL(arg0 string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoginURL", arg0) + ret0, _ := ret[0].(string) + return ret0 +} + +// LoginURL indicates an expected call of LoginURL. +func (mr *MockHasRedirectGlobsMockRecorder) LoginURL(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockHasRedirectGlobs)(nil).LoginURL), arg0) +} + +// PostLogoutRedirectURIGlobs mocks base method. +func (m *MockHasRedirectGlobs) PostLogoutRedirectURIGlobs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PostLogoutRedirectURIGlobs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// PostLogoutRedirectURIGlobs indicates an expected call of PostLogoutRedirectURIGlobs. +func (mr *MockHasRedirectGlobsMockRecorder) PostLogoutRedirectURIGlobs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIGlobs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).PostLogoutRedirectURIGlobs)) +} + +// PostLogoutRedirectURIs mocks base method. +func (m *MockHasRedirectGlobs) PostLogoutRedirectURIs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PostLogoutRedirectURIs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// PostLogoutRedirectURIs indicates an expected call of PostLogoutRedirectURIs. +func (mr *MockHasRedirectGlobsMockRecorder) PostLogoutRedirectURIs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).PostLogoutRedirectURIs)) +} + +// RedirectURIGlobs mocks base method. +func (m *MockHasRedirectGlobs) RedirectURIGlobs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RedirectURIGlobs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// RedirectURIGlobs indicates an expected call of RedirectURIGlobs. +func (mr *MockHasRedirectGlobsMockRecorder) RedirectURIGlobs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIGlobs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RedirectURIGlobs)) +} + +// RedirectURIs mocks base method. +func (m *MockHasRedirectGlobs) RedirectURIs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RedirectURIs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// RedirectURIs indicates an expected call of RedirectURIs. +func (mr *MockHasRedirectGlobsMockRecorder) RedirectURIs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RedirectURIs)) +} + +// ResponseTypes mocks base method. +func (m *MockHasRedirectGlobs) ResponseTypes() []oidc.ResponseType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResponseTypes") + ret0, _ := ret[0].([]oidc.ResponseType) + return ret0 +} + +// ResponseTypes indicates an expected call of ResponseTypes. +func (mr *MockHasRedirectGlobsMockRecorder) ResponseTypes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseTypes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ResponseTypes)) +} + +// RestrictAdditionalAccessTokenScopes mocks base method. +func (m *MockHasRedirectGlobs) RestrictAdditionalAccessTokenScopes() func([]string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RestrictAdditionalAccessTokenScopes") + ret0, _ := ret[0].(func([]string) []string) + return ret0 +} + +// RestrictAdditionalAccessTokenScopes indicates an expected call of RestrictAdditionalAccessTokenScopes. +func (mr *MockHasRedirectGlobsMockRecorder) RestrictAdditionalAccessTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalAccessTokenScopes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RestrictAdditionalAccessTokenScopes)) +} + +// RestrictAdditionalIdTokenScopes mocks base method. +func (m *MockHasRedirectGlobs) RestrictAdditionalIdTokenScopes() func([]string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RestrictAdditionalIdTokenScopes") + ret0, _ := ret[0].(func([]string) []string) + return ret0 +} + +// RestrictAdditionalIdTokenScopes indicates an expected call of RestrictAdditionalIdTokenScopes. +func (mr *MockHasRedirectGlobsMockRecorder) RestrictAdditionalIdTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalIdTokenScopes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RestrictAdditionalIdTokenScopes)) +} diff --git a/pkg/op/op.go b/pkg/op/op.go index 2bd130b..14c5356 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -45,6 +45,33 @@ var ( DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), } + DefaultSupportedClaims = []string{ + "sub", + "aud", + "exp", + "iat", + "iss", + "auth_time", + "nonce", + "acr", + "amr", + "c_hash", + "at_hash", + "act", + "scopes", + "client_id", + "azp", + "preferred_username", + "name", + "family_name", + "given_name", + "locale", + "email", + "email_verified", + "phone_number", + "phone_number_verified", + } + defaultCORSOptions = cors.Options{ AllowCredentials: true, AllowedHeaders: []string{ @@ -97,9 +124,19 @@ type OpenIDProvider interface { type HttpInterceptor func(http.Handler) http.Handler +type corsOptioner interface { + CORSOptions() *cors.Options +} + func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { router := chi.NewRouter() - router.Use(cors.New(defaultCORSOptions).Handler) + if co, ok := o.(corsOptioner); ok { + if opts := co.CORSOptions(); opts != nil { + router.Use(cors.New(*opts).Handler) + } + } else { + router.Use(cors.New(defaultCORSOptions).Handler) + } router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) @@ -136,6 +173,7 @@ type Config struct { GrantTypeRefreshToken bool RequestObjectSupported bool SupportedUILocales []language.Tag + SupportedClaims []string DeviceAuthorization DeviceAuthorizationConfig } @@ -173,28 +211,62 @@ type Endpoints struct { // Successful logins should mark the request as authorized and redirect back to to // op.AuthCallbackURL(provider) which is probably /callback. On the redirect back // to the AuthCallbackURL, the request id should be passed as the "id" parameter. +// +// Deprecated: use [NewProvider] with an issuer function direct. func NewOpenIDProvider(issuer string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) { - return newProvider(config, storage, StaticIssuer(issuer), opOpts...) + return NewProvider(config, storage, StaticIssuer(issuer), opOpts...) } // NewForwardedOpenIDProvider tries to establishes the issuer from the request Host. +// +// Deprecated: use [NewProvider] with an issuer function direct. func NewDynamicOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) { - return newProvider(config, storage, IssuerFromHost(path), opOpts...) + return NewProvider(config, storage, IssuerFromHost(path), opOpts...) } // NewForwardedOpenIDProvider tries to establish the Issuer from a Forwarded request header, if it is set. // See [IssuerFromForwardedOrHost] for details. +// +// Deprecated: use [NewProvider] with an issuer function direct. func NewForwardedOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) { - return newProvider(config, storage, IssuerFromForwardedOrHost(path), opOpts...) + return NewProvider(config, storage, IssuerFromForwardedOrHost(path), opOpts...) } -func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) { +// NewProvider creates a provider with a router on it's embedded http.Handler. +// Issuer is a function that must return the issuer on every request. +// Typically [StaticIssuer], [IssuerFromHost] or [IssuerFromForwardedOrHost] can be used. +// +// The router handles a suite of endpoints (some paths can be overridden): +// +// /healthz +// /ready +// /.well-known/openid-configuration +// /oauth/token +// /oauth/introspect +// /callback +// /authorize +// /userinfo +// /revoke +// /end_session +// /keys +// /device_authorization +// +// This does not include login. Login is handled with a redirect that includes the +// request ID. The redirect for logins is specified per-client by Client.LoginURL(). +// Successful logins should mark the request as authorized and redirect back to to +// op.AuthCallbackURL(provider) which is probably /callback. On the redirect back +// to the AuthCallbackURL, the request id should be passed as the "id" parameter. +func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) { + keySet := &OpenIDKeySet{storage} o := &Provider{ - config: config, - storage: storage, - endpoints: DefaultEndpoints, - timer: make(<-chan time.Time), - logger: slog.Default(), + config: config, + storage: storage, + accessTokenKeySet: keySet, + idTokenHinKeySet: keySet, + endpoints: DefaultEndpoints, + timer: make(<-chan time.Time), + corsOpts: &defaultCORSOptions, + logger: slog.Default(), } for _, optFunc := range opOpts { @@ -207,19 +279,11 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR if err != nil { return nil, err } - o.Handler = CreateRouter(o, o.interceptors...) - o.decoder = schema.NewDecoder() o.decoder.IgnoreUnknownKeys(true) - o.encoder = oidc.NewEncoder() - o.crypto = NewAESCrypto(config.CryptoKey) - - // Avoid potential race conditions by calling these early - _ = o.openIDKeySet() // sets keySet - return o, nil } @@ -230,7 +294,8 @@ type Provider struct { insecure bool endpoints *Endpoints storage Storage - keySet *openIDKeySet + accessTokenKeySet oidc.KeySet + idTokenHinKeySet oidc.KeySet crypto Crypto decoder *schema.Decoder encoder *schema.Encoder @@ -238,6 +303,7 @@ type Provider struct { timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt + corsOpts *cors.Options logger *slog.Logger } @@ -365,7 +431,7 @@ func (o *Provider) Encoder() httphelper.Encoder { } func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier { - return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...) + return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.idTokenHinKeySet, o.idTokenHintVerifierOpts...) } func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier { @@ -373,14 +439,7 @@ func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier { } func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier { - return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...) -} - -func (o *Provider) openIDKeySet() oidc.KeySet { - if o.keySet == nil { - o.keySet = &openIDKeySet{o.Storage()} - } - return o.keySet + return NewAccessTokenVerifier(IssuerFromContext(ctx), o.accessTokenKeySet, o.accessTokenVerifierOpts...) } func (o *Provider) Crypto() Crypto { @@ -397,6 +456,10 @@ func (o *Provider) Probes() []ProbesFn { } } +func (o *Provider) CORSOptions() *cors.Options { + return o.corsOpts +} + func (o *Provider) Logger() *slog.Logger { return o.logger } @@ -406,13 +469,13 @@ func (o *Provider) HttpHandler() http.Handler { return o } -type openIDKeySet struct { +type OpenIDKeySet struct { Storage } // VerifySignature implements the oidc.KeySet interface // providing an implementation for the keys stored in the OP Storage interface -func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { +func (o *OpenIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { keySet, err := o.Storage.KeySet(ctx) if err != nil { return nil, fmt.Errorf("error fetching keys: %w", err) @@ -543,6 +606,15 @@ func WithHttpInterceptors(interceptors ...HttpInterceptor) Option { } } +// WithAccessTokenKeySet allows passing a KeySet with public keys for Access Token verification. +// The default KeySet uses the [Storage] interface +func WithAccessTokenKeySet(keySet oidc.KeySet) Option { + return func(o *Provider) error { + o.accessTokenKeySet = keySet + return nil + } +} + func WithAccessTokenVerifierOpts(opts ...AccessTokenVerifierOpt) Option { return func(o *Provider) error { o.accessTokenVerifierOpts = opts @@ -550,6 +622,15 @@ func WithAccessTokenVerifierOpts(opts ...AccessTokenVerifierOpt) Option { } } +// WithIDTokenHintKeySet allows passing a KeySet with public keys for ID Token Hint verification. +// The default KeySet uses the [Storage] interface. +func WithIDTokenHintKeySet(keySet oidc.KeySet) Option { + return func(o *Provider) error { + o.idTokenHinKeySet = keySet + return nil + } +} + func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { return func(o *Provider) error { o.idTokenHintVerifierOpts = opts @@ -557,6 +638,13 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +func WithCORSOptions(opts *cors.Options) Option { + return func(o *Provider) error { + o.corsOpts = opts + return nil + } +} + // WithLogger lets a logger other than slog.Default(). // // EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20 @@ -573,6 +661,6 @@ func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handle for i := len(interceptors) - 1; i >= 0; i-- { handler = interceptors[i](handler) } - return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler)) + return issuerInterceptor.Handler(handler) } } diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index 062fcfe..b2a758c 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -30,6 +30,7 @@ var ( AuthMethodPrivateKeyJWT: true, GrantTypeRefreshToken: true, RequestObjectSupported: true, + SupportedClaims: op.DefaultSupportedClaims, SupportedUILocales: []language.Tag{language.English}, DeviceAuthorization: op.DeviceAuthorizationConfig{ Lifetime: 5 * time.Minute, @@ -57,8 +58,12 @@ func init() { } func newTestProvider(config *op.Config) op.OpenIDProvider { - provider, err := op.NewOpenIDProvider(testIssuer, config, - storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(), + storage := storage.NewStorage(storage.NewUserStore(testIssuer)) + keySet := &op.OpenIDKeySet{storage} + provider, err := op.NewOpenIDProvider(testIssuer, config, storage, + op.WithAllowInsecure(), + op.WithAccessTokenKeySet(keySet), + op.WithIDTokenHintKeySet(keySet), ) if err != nil { panic(err) diff --git a/pkg/op/server.go b/pkg/op/server.go index a9cdcf5..829618c 100644 --- a/pkg/op/server.go +++ b/pkg/op/server.go @@ -127,7 +127,7 @@ type Server interface { // Introspect handles the OAuth 2.0 Token Introspection endpoint. // https://datatracker.ietf.org/doc/html/rfc7662 // The recommended Response Data type is [oidc.IntrospectionResponse]. - Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) + Introspect(context.Context, *Request[IntrospectionRequest]) (*Response, error) // UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User. // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo @@ -329,7 +329,7 @@ func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oid return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) } -func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { +func (UnimplementedServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) { return nil, unimplementedError(r) } diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go index 96ee7a5..2220e44 100644 --- a/pkg/op/server_http.go +++ b/pkg/op/server_http.go @@ -25,9 +25,11 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) decoder.IgnoreUnknownKeys(true) ws := &webServer{ + router: chi.NewRouter(), server: server, endpoints: endpoints, decoder: decoder, + corsOpts: &defaultCORSOptions, logger: slog.Default(), } @@ -36,6 +38,10 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) } ws.createRouter() + ws.handler = ws.router + if ws.corsOpts != nil { + ws.handler = cors.New(*ws.corsOpts).Handler(ws.router) + } return ws } @@ -45,7 +51,14 @@ type ServerOption func(s *webServer) // the Server's router. func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { return func(s *webServer) { - s.middleware = m + s.router.Use(m...) + } +} + +// WithSetRouter allows customization or the Server's router. +func WithSetRouter(set func(chi.Router)) ServerOption { + return func(s *webServer) { + set(s.router) } } @@ -57,6 +70,13 @@ func WithDecoder(decoder httphelper.Decoder) ServerOption { } } +// WithServerCORSOptions sets the CORS policy for the Server's router. +func WithServerCORSOptions(opts *cors.Options) ServerOption { + return func(s *webServer) { + s.corsOpts = opts + } +} + // WithFallbackLogger overrides the fallback logger, which // is used when no logger was found in the context. // Defaults to [slog.Default]. @@ -67,12 +87,17 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption { } type webServer struct { - http.Handler - server Server - middleware []func(http.Handler) http.Handler - endpoints Endpoints - decoder httphelper.Decoder - logger *slog.Logger + server Server + router *chi.Mux + handler http.Handler + endpoints Endpoints + decoder httphelper.Decoder + corsOpts *cors.Options + logger *slog.Logger +} + +func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.handler.ServeHTTP(w, r) } func (s *webServer) getLogger(ctx context.Context) *slog.Logger { @@ -83,27 +108,23 @@ func (s *webServer) getLogger(ctx context.Context) *slog.Logger { } func (s *webServer) createRouter() { - router := chi.NewRouter() - router.Use(cors.New(defaultCORSOptions).Handler) - router.Use(s.middleware...) - router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) - router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) - router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) + s.router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) + s.router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) + s.router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) - s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler) - s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) - s.endpointRoute(router, s.endpoints.Token, s.tokensHandler) - s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler)) - s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler) - s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler)) - s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler) - s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) - s.Handler = router + s.endpointRoute(s.endpoints.Authorization, s.authorizeHandler) + s.endpointRoute(s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) + s.endpointRoute(s.endpoints.Token, s.tokensHandler) + s.endpointRoute(s.endpoints.Introspection, s.introspectionHandler) + s.endpointRoute(s.endpoints.Userinfo, s.userInfoHandler) + s.endpointRoute(s.endpoints.Revocation, s.withClient(s.revocationHandler)) + s.endpointRoute(s.endpoints.EndSession, s.endSessionHandler) + s.endpointRoute(s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) } -func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) { +func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) { if e != nil { - router.HandleFunc(e.Relative(), hf) + s.router.HandleFunc(e.Relative(), hf) s.logger.Info("registered route", "endpoint", e.Relative()) } } @@ -128,7 +149,21 @@ func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { } func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { - if err = r.ParseForm(); err != nil { + cc, err := s.parseClientCredentials(r) + if err != nil { + return nil, err + } + return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + Data: cc, + }) +} + +func (s *webServer) parseClientCredentials(r *http.Request) (_ *ClientCredentials, err error) { + if err := r.ParseForm(); err != nil { return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) } cc := new(ClientCredentials) @@ -152,13 +187,7 @@ func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion { return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType) } - return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{ - Method: r.Method, - URL: r.URL, - Header: r.Header, - Form: r.Form, - Data: cc, - }) + return cc, nil } func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) { @@ -370,8 +399,13 @@ func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, c resp.writeOut(w) } -func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) { - if client.AuthMethod() == oidc.AuthMethodNone { +func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request) { + cc, err := s.parseClientCredentials(r) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if cc.ClientSecret == "" && cc.ClientAssertion == "" { WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) return } @@ -384,7 +418,7 @@ func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) return } - resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client)) + resp, err := s.server.Introspect(r.Context(), newRequest(r, &IntrospectionRequest{cc, request})) if err != nil { WriteError(w, r, err, s.getLogger(r.Context())) return diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go index c7767d2..c50e989 100644 --- a/pkg/op/server_http_routes_test.go +++ b/pkg/op/server_http_routes_test.go @@ -32,7 +32,7 @@ func jwtProfile() (string, error) { } func TestServerRoutes(t *testing.T) { - server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints) + server := op.RegisterLegacyServer(op.NewLegacyServer(testProvider, *op.DefaultEndpoints)) storage := testProvider.Storage().(routesTestStorage) ctx := op.ContextWithIssuer(context.Background(), testIssuer) diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go index 86fe7ed..6cb268f 100644 --- a/pkg/op/server_http_test.go +++ b/pkg/op/server_http_test.go @@ -365,14 +365,14 @@ func Test_webServer_authorizeHandler(t *testing.T) { }, }, { - name: "authorize error", + name: "server error", fields: fields{ server: &requestVerifier{}, decoder: testDecoder, }, r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), want: webServerResult{ - wantStatus: http.StatusBadRequest, + wantStatus: http.StatusInternalServerError, wantBody: `{"error":"server_error"}`, }, }, @@ -1001,14 +1001,12 @@ func Test_webServer_introspectionHandler(t *testing.T) { tests := []struct { name string decoder httphelper.Decoder - client Client r *http.Request want webServerResult }{ { name: "decoder error", decoder: schema.NewDecoder(), - client: newClient(clientTypeUserAgent), r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), want: webServerResult{ wantStatus: http.StatusBadRequest, @@ -1018,8 +1016,7 @@ func Test_webServer_introspectionHandler(t *testing.T) { { name: "public client", decoder: testDecoder, - client: newClient(clientTypeNative), - r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123")), want: webServerResult{ wantStatus: http.StatusBadRequest, wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, @@ -1028,8 +1025,7 @@ func Test_webServer_introspectionHandler(t *testing.T) { { name: "token missing", decoder: testDecoder, - client: newClient(clientTypeWeb), - r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET")), want: webServerResult{ wantStatus: http.StatusBadRequest, wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, @@ -1038,8 +1034,7 @@ func Test_webServer_introspectionHandler(t *testing.T) { { name: "unimplemented Introspect called", decoder: testDecoder, - client: newClient(clientTypeWeb), - r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET&token=xxx")), want: webServerResult{ wantStatus: UnimplementedStatusCode, wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, @@ -1053,7 +1048,7 @@ func Test_webServer_introspectionHandler(t *testing.T) { decoder: tt.decoder, logger: slog.Default(), } - runWebServerClientTest(t, s.introspectionHandler, tt.r, tt.client, tt.want) + runWebServerTest(t, s.introspectionHandler, tt.r, tt.want) }) } } @@ -1242,7 +1237,7 @@ func Test_webServer_simpleHandler(t *testing.T) { }, r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))), want: webServerResult{ - wantStatus: http.StatusBadRequest, + wantStatus: http.StatusInternalServerError, wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, }, }, diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go index 5907e28..f99d15d 100644 --- a/pkg/op/server_legacy.go +++ b/pkg/op/server_legacy.go @@ -10,37 +10,77 @@ import ( "github.com/zitadel/oidc/v3/pkg/oidc" ) -// LegacyServer is an implementation of [Server[] that -// simply wraps a [OpenIDProvider]. +// ExtendedLegacyServer allows embedding [LegacyServer] in a struct, +// so that its methods can be individually overridden. +// +// EXPERIMENTAL: may change until v4 +type ExtendedLegacyServer interface { + Server + Provider() OpenIDProvider + Endpoints() Endpoints + AuthCallbackURL() func(context.Context, string) string +} + +// RegisterLegacyServer registers a [LegacyServer] or an extension thereof. +// It takes care of registering the IssuerFromRequest middleware +// and Authorization Callback Routes. +// Neither are part of the bare [Server] interface. +// +// EXPERIMENTAL: may change until v4 +func RegisterLegacyServer(s ExtendedLegacyServer, options ...ServerOption) http.Handler { + provider := s.Provider() + options = append(options, + WithHTTPMiddleware(intercept(provider.IssuerFromRequest)), + WithSetRouter(func(r chi.Router) { + r.HandleFunc(s.Endpoints().Authorization.Relative()+authCallbackPathSuffix, authorizeCallbackHandler(provider)) + }), + ) + return RegisterServer(s, s.Endpoints(), options...) +} + +// LegacyServer is an implementation of [Server] that +// simply wraps an [OpenIDProvider]. // It can be used to transition from the former Provider/Storage // interfaces to the new Server interface. +// +// EXPERIMENTAL: may change until v4 type LegacyServer struct { UnimplementedServer provider OpenIDProvider endpoints Endpoints } -// NewLegacyServer wraps provider in a `Server` and returns a handler which is -// the Server's router. +// NewLegacyServer wraps provider in a `Server` implementation // // Only non-nil endpoints will be registered on the router. // Nil endpoints are disabled. // -// The passed endpoints is also set to the provider, -// to be consistent with the discovery config. +// The passed endpoints is also used for the discovery config, +// and endpoints already set to the provider are ignored. // Any `With*Endpoint()` option used on the provider is // therefore ineffective. -func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler { - server := RegisterServer(&LegacyServer{ +// +// EXPERIMENTAL: may change until v4 +func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) *LegacyServer { + return &LegacyServer{ provider: provider, endpoints: endpoints, - }, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest))) + } +} - router := chi.NewRouter() - router.Mount("/", server) - router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider)) +func (s *LegacyServer) Provider() OpenIDProvider { + return s.provider +} - return router +func (s *LegacyServer) Endpoints() Endpoints { + return s.endpoints +} + +// AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login +func (s *LegacyServer) AuthCallbackURL() func(context.Context, string) string { + return func(ctx context.Context, requestID string) string { + return s.endpoints.Authorization.Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID + } } func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { @@ -51,7 +91,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon for _, probe := range s.provider.Probes() { // shouldn't we run probes in Go routines? if err := probe(ctx); err != nil { - return nil, NewStatusError(err, http.StatusInternalServerError) + return nil, AsStatusError(err, http.StatusInternalServerError) } } return NewResponse(Status{Status: "ok"}), nil @@ -66,7 +106,7 @@ func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Re func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { keys, err := s.provider.Storage().KeySet(ctx) if err != nil { - return nil, NewStatusError(err, http.StatusInternalServerError) + return nil, AsStatusError(err, http.StatusInternalServerError) } return NewResponse(jsonWebKeySet(keys)), nil } @@ -87,7 +127,7 @@ func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.Au } } if r.Data.ClientID == "" { - return nil, ErrAuthReqMissingClientID + return nil, oidc.ErrInvalidRequest().WithParent(ErrAuthReqMissingClientID).WithDescription(ErrAuthReqMissingClientID.Error()) } client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) if err != nil { @@ -115,7 +155,7 @@ func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.Auth func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider) if err != nil { - return nil, NewStatusError(err, http.StatusInternalServerError) + return nil, AsStatusError(err, http.StatusInternalServerError) } return NewResponse(response), nil } @@ -165,11 +205,14 @@ func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.A if err != nil { return nil, err } - if r.Client.AuthMethod() == oidc.AuthMethodNone { + if r.Client.AuthMethod() == oidc.AuthMethodNone || r.Data.CodeVerifier != "" { if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil { return nil, err } } + if r.Data.RedirectURI != authReq.GetRedirectURI() { + return nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond") + } resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "") if err != nil { return nil, err @@ -205,7 +248,7 @@ func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfil } tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx)) if err != nil { - return nil, err + return nil, oidc.ErrInvalidRequest().WithParent(err).WithDescription("assertion invalid") } tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope) @@ -251,7 +294,7 @@ func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientR } func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { - if !s.provider.GrantTypeClientCredentialsSupported() { + if !s.provider.GrantTypeDeviceCodeSupported() { return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) } // use a limited context timeout shorter as the default @@ -259,15 +302,10 @@ func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.De ctx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() - state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider) + tokenRequest, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider) if err != nil { return nil, err } - tokenRequest := &deviceAccessTokenRequest{ - subject: state.Subject, - audience: []string{r.Client.GetID()}, - scopes: state.Scopes, - } resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client) if err != nil { return nil, err @@ -275,13 +313,30 @@ func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.De return NewResponse(resp), nil } -func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) { +func (s *LegacyServer) authenticateResourceClient(ctx context.Context, cc *ClientCredentials) (string, error) { + if cc.ClientAssertion != "" { + if jp, ok := s.provider.(ClientJWTProfile); ok { + return ClientJWTAuth(ctx, oidc.ClientAssertionParams{ClientAssertion: cc.ClientAssertion}, jp) + } + return "", oidc.ErrInvalidClient().WithDescription("client_assertion not supported") + } + if err := s.provider.Storage().AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil { + return "", oidc.ErrUnauthorizedClient().WithParent(err) + } + return cc.ClientID, nil +} + +func (s *LegacyServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) { + clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials) + if err != nil { + return nil, err + } response := new(oidc.IntrospectionResponse) tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token) if !ok { return NewResponse(response), nil } - err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID()) + err = s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID) if err != nil { return NewResponse(response), nil } @@ -336,9 +391,14 @@ func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessio if err != nil { return nil, err } - err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID) + redirect := session.RedirectURI + if fromRequest, ok := s.provider.Storage().(CanTerminateSessionFromRequest); ok { + redirect, err = fromRequest.TerminateSessionFromRequest(ctx, session) + } else { + err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID) + } if err != nil { return nil, err } - return NewRedirect(session.RedirectURI), nil + return NewRedirect(redirect), nil } diff --git a/pkg/op/session.go b/pkg/op/session.go index c33627f..c933659 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -2,6 +2,7 @@ package op import ( "context" + "errors" "net/http" "net/url" "path" @@ -68,7 +69,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, } if req.IdTokenHint != "" { claims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx)) - if err != nil { + if err != nil && !errors.As(err, &IDTokenHintExpiredError{}) { return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err) } session.UserID = claims.GetSubject() diff --git a/pkg/op/storage.go b/pkg/op/storage.go index d083a31..a1a00ed 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -168,15 +168,6 @@ type EndSessionRequest struct { var ErrDuplicateUserCode = errors.New("user code already exists") -type DeviceAuthorizationState struct { - ClientID string - Scopes []string - Expires time.Time - Done bool - Subject string - Denied bool -} - type DeviceAuthorizationStorage interface { // StoreDeviceAuthorizationRequest stores a new device authorization request in the database. // User code will be used by the user to complete the login flow and must be unique. diff --git a/pkg/op/token.go b/pkg/op/token.go index 63a01a6..83889f0 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -84,6 +84,8 @@ func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool return req.GetRequestedTokenType() == oidc.RefreshTokenType case RefreshTokenRequest: return true + case *DeviceAuthorizationState: + return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && ValidateGrantType(client, oidc.GrantTypeRefreshToken) default: return false } diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index 21b79c3..9c45ef8 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -65,3 +65,8 @@ func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) return req.Token, clientID, nil } + +type IntrospectionRequest struct { + *ClientCredentials + *oidc.IntrospectionRequest +} diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go index 6143252..b5ec72e 100644 --- a/pkg/op/verifier_id_token_hint.go +++ b/pkg/op/verifier_id_token_hint.go @@ -2,6 +2,7 @@ package op import ( "context" + "errors" "github.com/zitadel/oidc/v3/pkg/oidc" ) @@ -27,8 +28,23 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi return verifier } +type IDTokenHintExpiredError struct { + error +} + +func (e IDTokenHintExpiredError) Unwrap() error { + return e.error +} + +func (e IDTokenHintExpiredError) Is(err error) bool { + return errors.Is(err, e.error) +} + // VerifyIDTokenHint validates the id token according to -// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation +// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation. +// In case of an expired token both the Claims and first encountered expiry related error +// is returned of type [IDTokenHintExpiredError]. In that case the caller can choose to still +// trust the token for cases like logout, as signature and other verifications succeeded. func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) { var nilClaims C @@ -49,20 +65,20 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTo return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset); err != nil { - return nilClaims, err - } - - if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { - return nilClaims, err - } - if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil { return nilClaims, err } + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { + return claims, IDTokenHintExpiredError{err} + } + + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { + return claims, IDTokenHintExpiredError{err} + } + if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { - return nilClaims, err + return claims, IDTokenHintExpiredError{err} } return claims, nil } diff --git a/pkg/op/verifier_id_token_hint_test.go b/pkg/op/verifier_id_token_hint_test.go index e514a76..597e291 100644 --- a/pkg/op/verifier_id_token_hint_test.go +++ b/pkg/op/verifier_id_token_hint_test.go @@ -2,6 +2,7 @@ package op import ( "context" + "errors" "testing" "time" @@ -57,6 +58,13 @@ func TestNewIDTokenHintVerifier(t *testing.T) { } } +func Test_IDTokenHintExpiredError(t *testing.T) { + var err error = IDTokenHintExpiredError{oidc.ErrExpired} + assert.True(t, errors.Unwrap(err) == oidc.ErrExpired) + assert.ErrorIs(t, err, oidc.ErrExpired) + assert.ErrorAs(t, err, &IDTokenHintExpiredError{}) +} + func TestVerifyIDTokenHint(t *testing.T) { verifier := &IDTokenHintVerifier{ Issuer: tu.ValidIssuer, @@ -71,21 +79,23 @@ func TestVerifyIDTokenHint(t *testing.T) { tests := []struct { name string tokenClaims func() (string, *oidc.IDTokenClaims) - wantErr bool + wantClaims bool + wantErr error }{ { name: "success", tokenClaims: tu.ValidIDToken, + wantClaims: true, }, { name: "parse err", tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil }, - wantErr: true, + wantErr: oidc.ErrParse, }, { name: "invalid signature", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil }, - wantErr: true, + wantErr: oidc.ErrSignatureUnsupportedAlg, }, { name: "wrong issuer", @@ -96,29 +106,7 @@ func TestVerifyIDTokenHint(t *testing.T) { tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", ) }, - wantErr: true, - }, - { - name: "expired", - tokenClaims: func() (string, *oidc.IDTokenClaims) { - return tu.NewIDToken( - tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, - tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce, - tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", - ) - }, - wantErr: true, - }, - { - name: "wrong IAT", - tokenClaims: func() (string, *oidc.IDTokenClaims) { - return tu.NewIDToken( - tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, - tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, - tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "", - ) - }, - wantErr: true, + wantErr: oidc.ErrIssuerInvalid, }, { name: "wrong acr", @@ -129,7 +117,31 @@ func TestVerifyIDTokenHint(t *testing.T) { "else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", ) }, - wantErr: true, + wantErr: oidc.ErrAcrInvalid, + }, + { + name: "expired", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return tu.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantClaims: true, + wantErr: IDTokenHintExpiredError{oidc.ErrExpired}, + }, + { + name: "IAT too old", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return tu.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, time.Hour, "", + ) + }, + wantClaims: true, + wantErr: IDTokenHintExpiredError{oidc.ErrIatToOld}, }, { name: "expired auth", @@ -140,7 +152,8 @@ func TestVerifyIDTokenHint(t *testing.T) { tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", ) }, - wantErr: true, + wantClaims: true, + wantErr: IDTokenHintExpiredError{oidc.ErrAuthTimeToOld}, }, } for _, tt := range tests { @@ -148,14 +161,12 @@ func TestVerifyIDTokenHint(t *testing.T) { token, want := tt.tokenClaims() got, err := VerifyIDTokenHint[*oidc.IDTokenClaims](context.Background(), token, verifier) - if tt.wantErr { - assert.Error(t, err) - assert.Nil(t, got) + require.ErrorIs(t, err, tt.wantErr) + if tt.wantClaims { + assert.Equal(t, got, want, "claims") return } - require.NoError(t, err) - require.NotNil(t, got) - assert.Equal(t, got, want) + assert.Nil(t, got, "claims") }) } } diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index 3b13665..38b8ee4 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -17,11 +17,21 @@ import ( type JWTProfileVerifier struct { oidc.Verifier Storage JWTProfileKeyStorage + keySet oidc.KeySet CheckSubject func(request *oidc.JWTTokenRequest) error } // NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication) func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier { + return newJWTProfileVerifier(storage, nil, issuer, maxAgeIAT, offset, opts...) +} + +// NewJWTProfileVerifierKeySet creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication) +func NewJWTProfileVerifierKeySet(keySet oidc.KeySet, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier { + return newJWTProfileVerifier(nil, keySet, issuer, maxAgeIAT, offset, opts...) +} + +func newJWTProfileVerifier(storage JWTProfileKeyStorage, keySet oidc.KeySet, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier { j := &JWTProfileVerifier{ Verifier: oidc.Verifier{ Issuer: issuer, @@ -29,6 +39,7 @@ func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIA Offset: offset, }, Storage: storage, + keySet: keySet, CheckSubject: SubjectIsIssuer, } @@ -78,7 +89,10 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVeri return nil, err } - keySet := &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer} + keySet := v.keySet + if keySet == nil { + keySet = &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer} + } if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil { return nil, err }