diff --git a/.forgejo.bak/ISSUE_TEMPLATE/bug_report.yaml b/.forgejo.bak/ISSUE_TEMPLATE/bug_report.yaml new file mode 100644 index 0000000..d024341 --- /dev/null +++ b/.forgejo.bak/ISSUE_TEMPLATE/bug_report.yaml @@ -0,0 +1,57 @@ +name: Bug Report +description: "Create a bug report to help us improve ZITADEL. Click [here](https://github.com/zitadel/zitadel/blob/main/CONTRIBUTING.md#product-management) to see how we process your issue." +title: "[Bug]: " +labels: ["bug"] +type: Bug +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! + - type: checkboxes + id: preflight + attributes: + label: Preflight Checklist + options: + - label: + I could not find a solution in the documentation, the existing issues or discussions + required: true + - label: + I have joined the [ZITADEL chat](https://zitadel.com/chat) + - type: input + id: version + attributes: + label: Version + description: Which version of the OIDC library are you using. + - type: textarea + id: impact + attributes: + label: Describe the problem caused by this bug + description: A clear and concise description of the problem you have and what the bug is. + validations: + required: true + - type: textarea + id: reproduce + attributes: + label: To reproduce + description: Steps to reproduce the behaviour + placeholder: | + Steps to reproduce the behavior: + validations: + required: true + - type: textarea + id: screenshots + attributes: + label: Screenshots + description: If applicable, add screenshots to help explain your problem. + - type: textarea + id: expected + attributes: + label: Expected behavior + description: A clear and concise description of what you expected to happen. + placeholder: As a [type of user], I want [some goal] so that [some reason]. + - type: textarea + id: additional + attributes: + label: Additional Context + description: Please add any other infos that could be useful. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.forgejo.bak/ISSUE_TEMPLATE/config.yml similarity index 100% rename from .github/ISSUE_TEMPLATE/config.yml rename to .forgejo.bak/ISSUE_TEMPLATE/config.yml diff --git a/.forgejo.bak/ISSUE_TEMPLATE/docs.yaml b/.forgejo.bak/ISSUE_TEMPLATE/docs.yaml new file mode 100644 index 0000000..d3f82b9 --- /dev/null +++ b/.forgejo.bak/ISSUE_TEMPLATE/docs.yaml @@ -0,0 +1,31 @@ +name: 📄 Documentation +description: Create an issue for missing or wrong documentation. +labels: ["docs"] +type: task +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this issue. + - type: checkboxes + id: preflight + attributes: + label: Preflight Checklist + options: + - label: + I could not find a solution in the existing issues, docs, nor discussions + required: true + - label: + I have joined the [ZITADEL chat](https://zitadel.com/chat) + - type: textarea + id: docs + attributes: + label: Describe the docs your are missing or that are wrong + placeholder: As a [type of user], I want [some goal] so that [some reason]. + validations: + required: true + - type: textarea + id: additional + attributes: + label: Additional Context + description: Please add any other infos that could be useful. diff --git a/.forgejo.bak/ISSUE_TEMPLATE/enhancement.yaml b/.forgejo.bak/ISSUE_TEMPLATE/enhancement.yaml new file mode 100644 index 0000000..ef2103e --- /dev/null +++ b/.forgejo.bak/ISSUE_TEMPLATE/enhancement.yaml @@ -0,0 +1,55 @@ +name: đŸ› ī¸ Improvement +description: "Create an new issue for an improvment in ZITADEL" +labels: ["enhancement"] +type: enhancement +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this proposal / feature reqeust + - type: checkboxes + id: preflight + attributes: + label: Preflight Checklist + options: + - label: + I could not find a solution in the existing issues, docs, nor discussions + required: true + - label: + I have joined the [ZITADEL chat](https://zitadel.com/chat) + - type: textarea + id: problem + attributes: + label: Describe your problem + description: Please describe your problem this improvement is supposed to solve. + placeholder: Describe the problem you have + validations: + required: true + - type: textarea + id: solution + attributes: + label: Describe your ideal solution + description: Which solution do you propose? + placeholder: As a [type of user], I want [some goal] so that [some reason]. + validations: + required: true + - type: input + id: version + attributes: + label: Version + description: Which version of the OIDC Library are you using. + - type: dropdown + id: environment + attributes: + label: Environment + description: How do you use ZITADEL? + options: + - ZITADEL Cloud + - Self-hosted + validations: + required: true + - type: textarea + id: additional + attributes: + label: Additional Context + description: Please add any other infos that could be useful. diff --git a/.github/dependabot.yml b/.forgejo.bak/dependabot.yml similarity index 58% rename from .github/dependabot.yml rename to .forgejo.bak/dependabot.yml index 79ff704..1efdcf8 100644 --- a/.github/dependabot.yml +++ b/.forgejo.bak/dependabot.yml @@ -9,6 +9,16 @@ updates: commit-message: prefix: chore include: scope +- package-ecosystem: gomod + target-branch: "2.12.x" + directory: "/" + schedule: + interval: daily + time: '04:00' + open-pull-requests-limit: 10 + commit-message: + prefix: chore + include: scope - package-ecosystem: "github-actions" directory: "/" schedule: diff --git a/.forgejo.bak/pull_request_template.md b/.forgejo.bak/pull_request_template.md new file mode 100644 index 0000000..6c4ae58 --- /dev/null +++ b/.forgejo.bak/pull_request_template.md @@ -0,0 +1,16 @@ +### Definition of Ready + +- [ ] I am happy with the code +- [ ] Short description of the feature/issue is added in the pr description +- [ ] PR is linked to the corresponding user story +- [ ] Acceptance criteria are met +- [ ] All open todos and follow ups are defined in a new ticket and justified +- [ ] Deviations from the acceptance criteria and design are agreed with the PO and documented. +- [ ] No debug or dead code +- [ ] My code has no repetitions +- [ ] Critical parts are tested automatically +- [ ] Where possible E2E tests are implemented +- [ ] Documentation/examples are up-to-date +- [ ] All non-functional requirements are met +- [ ] Functionality of the acceptance criteria is checked manually on the dev system. + diff --git a/.github/workflows/codeql-analysis.yml b/.forgejo.bak/workflows/codeql-analysis.yml similarity index 90% rename from .github/workflows/codeql-analysis.yml rename to .forgejo.bak/workflows/codeql-analysis.yml index d2bae79..27fa244 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.forgejo.bak/workflows/codeql-analysis.yml @@ -16,7 +16,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: # We must fetch at least the immediate parents so that if this is # a pull request then we can checkout the head. @@ -29,7 +29,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 # Override language selection by uncommenting this and choosing your languages with: languages: go @@ -37,7 +37,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 # â„šī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -51,4 +51,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.forgejo.bak/workflows/issue.yml b/.forgejo.bak/workflows/issue.yml new file mode 100644 index 0000000..480c339 --- /dev/null +++ b/.forgejo.bak/workflows/issue.yml @@ -0,0 +1,43 @@ +name: Add new issues to product management project + +on: + issues: + types: + - opened + pull_request_target: + types: + - opened + +jobs: + add-to-project: + name: Add issue and community pr to project + runs-on: ubuntu-latest + steps: + - name: add issue + uses: actions/add-to-project@v1.0.2 + if: ${{ github.event_name == 'issues' }} + with: + # You can target a repository in a different organization + # to the issue + project-url: https://github.com/orgs/zitadel/projects/2 + github-token: ${{ secrets.ADD_TO_PROJECT_PAT }} + - uses: tspascoal/get-user-teams-membership@v3 + id: checkUserMember + if: github.actor != 'dependabot[bot]' + with: + username: ${{ github.actor }} + GITHUB_TOKEN: ${{ secrets.ADD_TO_PROJECT_PAT }} + - name: add pr + uses: actions/add-to-project@v1.0.2 + if: ${{ github.event_name == 'pull_request_target' && github.actor != 'dependabot[bot]' && !contains(steps.checkUserMember.outputs.teams, 'engineers')}} + with: + # You can target a repository in a different organization + # to the issue + project-url: https://github.com/orgs/zitadel/projects/2 + github-token: ${{ secrets.ADD_TO_PROJECT_PAT }} + - uses: actions-ecosystem/action-add-labels@v1.1.3 + if: ${{ github.event_name == 'pull_request_target' && github.actor != 'dependabot[bot]' && !contains(steps.checkUserMember.outputs.teams, 'staff')}} + with: + github_token: ${{ secrets.ADD_TO_PROJECT_PAT }} + labels: | + os-contribution diff --git a/.github/workflows/release.yml b/.forgejo.bak/workflows/release.yml similarity index 74% rename from .github/workflows/release.yml rename to .forgejo.bak/workflows/release.yml index 78c0f79..00063e4 100644 --- a/.github/workflows/release.yml +++ b/.forgejo.bak/workflows/release.yml @@ -2,6 +2,7 @@ name: Release on: push: branches: + - "2.11.x" - main - next tags-ignore: @@ -13,33 +14,34 @@ on: jobs: test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: + fail-fast: false matrix: - go: ['1.18', '1.19', '1.20'] + go: ['1.23', '1.24'] name: Go ${{ matrix.go }} test steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/... - - uses: codecov/codecov-action@v3.1.1 + - uses: codecov/codecov-action@v5.4.3 with: file: ./profile.cov name: codecov-go release: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 needs: [test] if: ${{ github.event_name == 'workflow_dispatch' || github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next' }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - name: Source checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Semantic Release - uses: cycjimmy/semantic-release-action@v3 + uses: cycjimmy/semantic-release-action@v4 with: dry_run: false semantic_version: 18.0.1 diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 49ccc49..0000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,38 +0,0 @@ ---- -name: 🐛 Bug report -about: Create a report to help us improve -title: '' -labels: bug -assignees: '' - ---- - -**Describe the bug** -A clear and concise description of what the bug is. - -**To Reproduce** -Steps to reproduce the behavior: -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -**Expected behavior** -A clear and concise description of what you expected to happen. - -**Screenshots** -If applicable, add screenshots to help explain your problem. - -**Desktop (please complete the following information):** -- OS: [e.g. iOS] -- Browser [e.g. chrome, safari] -- Version [e.g. 22] - -**Smartphone (please complete the following information):** -- Device: [e.g. iPhone6] -- OS: [e.g. iOS8.1] -- Browser [e.g. stock browser, safari] -- Version [e.g. 22] - -**Additional context** -Add any other context about the problem here. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index 118d30e..0000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,20 +0,0 @@ ---- -name: 🚀 Feature request -about: Suggest an idea for this project -title: '' -labels: enhancement -assignees: '' - ---- - -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -**Describe the solution you'd like** -A clear and concise description of what you want to happen. - -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - -**Additional context** -Add any other context or screenshots about the feature request here. diff --git a/.github/workflows/issue.yml b/.github/workflows/issue.yml deleted file mode 100644 index 8671820..0000000 --- a/.github/workflows/issue.yml +++ /dev/null @@ -1,18 +0,0 @@ -name: Add new issues to product management project - -on: - issues: - types: - - opened - -jobs: - add-to-project: - name: Add issue to project - runs-on: ubuntu-latest - steps: - - uses: actions/add-to-project@v0.4.1 - with: - # You can target a repository in a different organization - # to the issue - project-url: https://github.com/orgs/zitadel/projects/2 - github-token: ${{ secrets.ADD_TO_PROJECT_PAT }} diff --git a/.releaserc.js b/.releaserc.js index e8eea8e..c87b1d1 100644 --- a/.releaserc.js +++ b/.releaserc.js @@ -1,5 +1,6 @@ module.exports = { branches: [ + {name: "2.11.x"}, {name: "main"}, {name: "next", prerelease: true}, ], diff --git a/README.md b/README.md index f369a5c..bc346f5 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,13 @@ [![semantic-release](https://img.shields.io/badge/%20%20%F0%9F%93%A6%F0%9F%9A%80-semantic--release-e10079.svg)](https://github.com/semantic-release/semantic-release) [![Release](https://github.com/zitadel/oidc/workflows/Release/badge.svg)](https://github.com/zitadel/oidc/actions) -[![GoDoc](https://godoc.org/github.com/zitadel/oidc?status.png)](https://pkg.go.dev/github.com/zitadel/oidc) +[![Go Reference](https://pkg.go.dev/badge/github.com/zitadel/oidc/v3.svg)](https://pkg.go.dev/github.com/zitadel/oidc/v3) [![license](https://badgen.net/github/license/zitadel/oidc/)](https://github.com/zitadel/oidc/blob/master/LICENSE) [![release](https://badgen.net/github/release/zitadel/oidc/stable)](https://github.com/zitadel/oidc/releases) -[![Go Report Card](https://goreportcard.com/badge/github.com/zitadel/oidc)](https://goreportcard.com/report/github.com/zitadel/oidc) +[![Go Report Card](https://goreportcard.com/badge/github.com/zitadel/oidc/v3)](https://goreportcard.com/report/github.com/zitadel/oidc/v3) [![codecov](https://codecov.io/gh/zitadel/oidc/branch/main/graph/badge.svg)](https://codecov.io/gh/zitadel/oidc) -![openid_certified](https://cloud.githubusercontent.com/assets/1454075/7611268/4d19de32-f97b-11e4-895b-31b2455a7ca6.png) +[![openid_certified](https://cloud.githubusercontent.com/assets/1454075/7611268/4d19de32-f97b-11e4-895b-31b2455a7ca6.png)](https://openid.net/certification/) ## What Is It @@ -21,9 +21,10 @@ Whenever possible we tried to reuse / extend existing packages like `OAuth2 for ## Basic Overview The most important packages of the library: +
 /pkg
-    /client            clients using the OP for retrieving, exchanging and verifying tokens       
+    /client            clients using the OP for retrieving, exchanging and verifying tokens
         /rp            definition and implementation of an OIDC Relying Party (client)
         /rs            definition and implementation of an OAuth Resource Server (API)
     /op                definition and implementation of an OIDC OpenID Provider (server)
@@ -37,6 +38,10 @@ The most important packages of the library:
     /server            examples of an OpenID Provider implementations (including dynamic) with some very basic login UI
 
+### Semver + +This package uses [semver](https://semver.org/) for [releases](https://github.com/zitadel/oidc/releases). Major releases ship breaking changes. Starting with the `v2` to `v3` increment we provide an [upgrade guide](UPGRADING.md) to ease migration to a newer version. + ## How To Use It Check the `/example` folder where example code for different scenarios is located. @@ -44,54 +49,90 @@ Check the `/example` folder where example code for different scenarios is locate ```bash # start oidc op server # oidc discovery http://localhost:9998/.well-known/openid-configuration -go run github.com/zitadel/oidc/v2/example/server +go run github.com/zitadel/oidc/v3/example/server # start oidc web client (in a new terminal) -CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app +CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app ``` - open http://localhost:9999/login in your browser -- you will be redirected to op server and the login UI +- you will be redirected to op server and the login UI - login with user `test-user@localhost` and password `verysecure` - the OP will redirect you to the client app, which displays the user info for the dynamic issuer, just start it with: + ```bash -go run github.com/zitadel/oidc/v2/example/server/dynamic -``` +go run github.com/zitadel/oidc/v3/example/server/dynamic +``` + the oidc web client above will still work, but if you add `oidc.local` (pointing to 127.0.0.1) in your hosts file you can also start it with: + ```bash -CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v2/example/client/app +CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app ``` > Note: Usernames are suffixed with the hostname (`test-user@localhost` or `test-user@oidc.local`) +### Server configuration + +Example server allows extra configuration using environment variables and could be used for end to +end testing of your services. + +| Name | Format | Description | +| ------------ | -------------------------------- | ------------------------------------- | +| PORT | Number between 1 and 65535 | OIDC listen port | +| REDIRECT_URI | Comma-separated URIs | List of allowed redirect URIs | +| USERS_FILE | Path to json in local filesystem | Users with their data and credentials | + +Here is json equivalent for one of the default users + +```json +{ + "id2": { + "ID": "id2", + "Username": "test-user2", + "Password": "verysecure", + "FirstName": "Test", + "LastName": "User2", + "Email": "test-user2@zitadel.ch", + "EmailVerified": true, + "Phone": "", + "PhoneVerified": false, + "PreferredLanguage": "DE", + "IsAdmin": false + } +} +``` + ## Features -| | Relying party | OpenID Provider | Specification | -| -------------------- | ------------- | --------------- | ----------------------------------------- | -| Code Flow | yes | yes | OpenID Connect Core 1.0, [Section 3.1][1] | -| Implicit Flow | no[^1] | yes | OpenID Connect Core 1.0, [Section 3.2][2] | -| Hybrid Flow | no | not yet | OpenID Connect Core 1.0, [Section 3.3][3] | -| Client Credentials | not yet | yes | OpenID Connect Core 1.0, [Section 9][4] | -| Refresh Token | yes | yes | OpenID Connect Core 1.0, [Section 12][5] | -| Discovery | yes | yes | OpenID Connect [Discovery][6] 1.0 | -| JWT Profile | yes | yes | [RFC 7523][7] | -| PKCE | yes | yes | [RFC 7636][8] | -| Token Exchange | yes | yes | [RFC 8693][9] | -| Device Authorization | yes | yes | [RFC 8628][10] | -| mTLS | not yet | not yet | [RFC 8705][11] | +| | Relying party | OpenID Provider | Specification | +| -------------------- | ------------- | --------------- | -------------------------------------------- | +| Code Flow | yes | yes | OpenID Connect Core 1.0, [Section 3.1][1] | +| Implicit Flow | no[^1] | yes | OpenID Connect Core 1.0, [Section 3.2][2] | +| Hybrid Flow | no | not yet | OpenID Connect Core 1.0, [Section 3.3][3] | +| Client Credentials | yes | yes | OpenID Connect Core 1.0, [Section 9][4] | +| Refresh Token | yes | yes | OpenID Connect Core 1.0, [Section 12][5] | +| Discovery | yes | yes | OpenID Connect [Discovery][6] 1.0 | +| JWT Profile | yes | yes | [RFC 7523][7] | +| PKCE | yes | yes | [RFC 7636][8] | +| Token Exchange | yes | yes | [RFC 8693][9] | +| Device Authorization | yes | yes | [RFC 8628][10] | +| mTLS | not yet | not yet | [RFC 8705][11] | +| Back-Channel Logout | not yet | yes | OpenID Connect [Back-Channel Logout][12] 1.0 | -[1]: "3.1. Authentication using the Authorization Code Flow" -[2]: "3.2. Authentication using the Implicit Flow" -[3]: "3.3. Authentication using the Hybrid Flow" -[4]: "9. Client Authentication" -[5]: "12. Using Refresh Tokens" -[6]: "OpenID Connect Discovery 1.0 incorporating errata set 1" -[7]: "JSON Web Token (JWT) Profile for OAuth 2.0 Client Authentication and Authorization Grants" -[8]: "Proof Key for Code Exchange by OAuth Public Clients" -[9]: "OAuth 2.0 Token Exchange" -[10]: "OAuth 2.0 Device Authorization Grant" -[11]: "OAuth 2.0 Mutual-TLS Client Authentication and Certificate-Bound Access Tokens" +[1]: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth "3.1. Authentication using the Authorization Code Flow" +[2]: https://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth "3.2. Authentication using the Implicit Flow" +[3]: https://openid.net/specs/openid-connect-core-1_0.html#HybridFlowAuth "3.3. Authentication using the Hybrid Flow" +[4]: https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication "9. Client Authentication" +[5]: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens "12. Using Refresh Tokens" +[6]: https://openid.net/specs/openid-connect-discovery-1_0.html "OpenID Connect Discovery 1.0 incorporating errata set 1" +[7]: https://www.rfc-editor.org/rfc/rfc7523.html "JSON Web Token (JWT) Profile for OAuth 2.0 Client Authentication and Authorization Grants" +[8]: https://www.rfc-editor.org/rfc/rfc7636.html "Proof Key for Code Exchange by OAuth Public Clients" +[9]: https://www.rfc-editor.org/rfc/rfc8693.html "OAuth 2.0 Token Exchange" +[10]: https://www.rfc-editor.org/rfc/rfc8628.html "OAuth 2.0 Device Authorization Grant" +[11]: https://www.rfc-editor.org/rfc/rfc8705.html "OAuth 2.0 Mutual-TLS Client Authentication and Certificate-Bound Access Tokens" +[12]: https://openid.net/specs/openid-connect-backchannel-1_0.html "OpenID Connect Back-Channel Logout 1.0 incorporating errata set 1" ## Contributors @@ -110,15 +151,14 @@ For your convenience you can find the relevant guides linked below. ## Supported Go Versions -For security reasons, we only support and recommend the use of one of the latest two Go versions (:white_check_mark:). +For security reasons, we only support and recommend the use of one of the latest two Go versions (:white_check_mark:). Versions that also build are marked with :warning:. | Version | Supported | | ------- | ------------------ | -| <1.18 | :x: | -| 1.18 | :warning: | -| 1.19 | :white_check_mark: | -| 1.20 | :white_check_mark: | +| <1.23 | :x: | +| 1.23 | :white_check_mark: | +| 1.24 | :white_check_mark: | ## Why another library @@ -149,5 +189,4 @@ Unless required by applicable law or agreed to in writing, software distributed AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - [^1]: https://github.com/zitadel/oidc/issues/135#issuecomment-950563892 diff --git a/SECURITY.md b/SECURITY.md index 934426a..a32b842 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,43 +1,20 @@ # Security Policy -At ZITADEL we are extremely grateful for security aware people that disclose vulnerabilities to us and the open source community. All reports will be investigated by our team. +Please refer to the security policy [on zitadel/zitadel](https://github.com/zitadel/zitadel/blob/main/SECURITY.md) which is applicable for all open source repositories of our organization. ## Supported Versions -After the initial Release the following version support will apply +We currently support the following version of the OIDC framework: -| Version | Supported | -| ------- | ------------------ | -| 0.x.x | :x: | -| 1.x.x | :white_check_mark: | -| 2.x.x | :white_check_mark: (not released) | +| Version | Supported | Branch | Details | +| -------- | ------------------ | ----------- | ------------------------------------ | +| 0.x.x | :x: | | not maintained | +| <2.11 | :x: | | not maintained | +| 2.11.x | :lock: :warning: | [2.11.x][1] | security only, [community effort][2] | +| 3.x.x | :heavy_check_mark: | [main][3] | supported | +| 4.0.0-xx | :white_check_mark: | [next][4] | [development branch] | -## Reporting a vulnerability - -To file a incident, please disclose by email to security@zitadel.com with the security details. - -At the moment GPG encryption is no yet supported, however you may sign your message at will. - -### When should I report a vulnerability - -* You think you discovered a ... - * ... potential security vulnerability in the SDK - * ... vulnerability in another project that this SDK bases on -* For projects with their own vulnerability reporting and disclosure process, please report it directly there - -### When should I NOT report a vulnerability - -* You need help applying security related updates -* Your issue is not security related - -## Security Vulnerability Response - -TBD - -## Public Disclosure - -All accepted and mitigated vulnerabilities will be published on the [Github Security Page](https://github.com/zitadel/oidc/security/advisories) - -### Timing - -We think it is crucial to publish advisories `ASAP` as mitigations are ready. But due to the unknown nature of the disclosures the time frame can range from 7 to 90 days. +[1]: https://github.com/zitadel/oidc/tree/2.11.x +[2]: https://github.com/zitadel/oidc/discussions/458 +[3]: https://github.com/zitadel/oidc/tree/main +[4]: https://github.com/zitadel/oidc/tree/next diff --git a/UPGRADING.md b/UPGRADING.md new file mode 100644 index 0000000..6b5a41d --- /dev/null +++ b/UPGRADING.md @@ -0,0 +1,370 @@ +# Upgrading + +All commands are executed from the root of the project that imports oidc packages. +`sed` commands are created with **GNU sed** in mind and might need alternate syntax +on non-GNU systems, such as MacOS. +Alternatively, GNU sed can be installed on such systems. (`coreutils` package?). + +## V2 to V3 + +**TL;DR** at the [bottom](#full-script) of this chapter is a full `sed` script +containing all automatic steps at once. + + +As first steps we will: +1. Download the latest v3 module; +2. Replace imports in all Go files; +3. Tidy the module file; + +```bash +go get -u github.com/zitadel/oidc/v3 +find . -type f -name '*.go' | xargs sed -i \ + -e 's/github\.com\/zitadel\/oidc\/v2/github.com\/zitadel\/oidc\/v3/g' +go mod tidy +``` + +### global + +#### go-jose package + +`gopkg.in/square/go-jose.v2` import has been changed to `github.com/go-jose/go-jose/v3`. +That means that the imported types are also changed and imports need to be adapted. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/gopkg.in\/square\/go-jose\.v2/github.com\/go-jose\/go-jose\/v3/g' +go mod tidy +``` + +### op + +```go +import "github.com/zitadel/oidc/v3/pkg/op" +``` + +#### Logger + +This version of OIDC adds logging to the framework. For this we use the new Go standard library `log/slog`. (Until v3.12.0 we used `x/exp/slog`). +Mostly OIDC will use error level logs where it's returning an error through a HTTP handler. OIDC errors that are user facing don't carry much context, also for security reasons. With logging we are now able to print the error context, so that developers can more easily find the source of their issues. Previously we just discarded such context. + +Most users of the OP package with the storage interface will not experience breaking changes. However if you use `RequestError()` directly in your code, you now need to give it a `Logger` as final argument. + +The `OpenIDProvider` and sub-interfaces like `Authorizer` and `Exchanger` got a `Logger()` method to return the configured logger. This logger is in turn used by `AuthRequestError()`. You configure the logger with the `WithLogger()` for the `Provider`. By default the `slog.Default()` is used. + +We also provide a new optional interface: [`LogAuthRequest`](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/op#LogAuthRequest). If an `AuthRequest` implements this interface, it is completely passed into the logger after an error. Its `LogValue()` will be used by `slog` to print desired fields. This allows omitting sensitive fields you wish not no print. If the interface is not implemented, no `AuthRequest` details will ever be printed. + +#### Server interface + +We've added a new [`Server`](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/op#Server) interface. This interface is experimental and subject to change. See [issue 440](https://github.com/zitadel/oidc/issues/440) for the motivation and discussion around this new interface. +Usage of the new interface is not required, but may be used for advanced scenarios when working with the `Storage` interface isn't the optimal solution for your app (like we experienced in [Zitadel](https://github.com/zitadel/zitadel)). + +#### AuthRequestError + +`AuthRequestError` now takes the complete `Authorizer` as final argument, instead of only the encoder. +This is to facilitate the use of the `Logger` as described above. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/\bAuthRequestError(w, r, authReq, err, authorizer.Encoder())/AuthRequestError(w, r, authReq, err, authorizer)/g' +``` + +Note: the sed regex might not find all uses if the local variables of the passed arguments use different names. + +#### AccessTokenVerifier + +`AccessTokenVerifier` interface has become a struct type. `NewAccessTokenVerifier` now returns a pointer to `AccessTokenVerifier`. +Variable and struct fields declarations need to be changed from `op.AccessTokenVerifier` to `*op.AccessTokenVerifier`. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/\bop\.AccessTokenVerifier\b/*op.AccessTokenVerifier/g' +``` + +#### JWTProfileVerifier + +`JWTProfileVerifier` interface has become a struct type. `NewJWTProfileVerifier` now returns a pointer to `JWTProfileVerifier`. +Variable and struct fields declarations need to be changed from `op.JWTProfileVerifier` to `*op.JWTProfileVerifier`. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/\bop\.JWTProfileVerifier\b/*op.JWTProfileVerifier/g' +``` + +#### IDTokenHintVerifier + +`IDTokenHintVerifier` interface has become a struct type. `NewIDTokenHintVerifier` now returns a pointer to `IDTokenHintVerifier`. +Variable and struct fields declarations need to be changed from `op.IDTokenHintVerifier` to `*op.IDTokenHintVerifier`. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/\bop\.IDTokenHintVerifier\b/*op.IDTokenHintVerifier/g' +``` + +#### ParseRequestObject + +`ParseRequestObject` no longer returns `*oidc.AuthRequest` as it already operates on the pointer for the passed `authReq` argument. As such the argument and the return value were the same pointer. Callers can just use the original `*oidc.AuthRequest` now. + +#### Endpoint Configuration + +`Endpoint`s returned from `Configuration` interface methods are now pointers. Usually, `op.Provider` is the main implementation of the `Configuration` interface. However, if a custom implementation is used, you should be able to update it using the following: + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/AuthorizationEndpoint() Endpoint/AuthorizationEndpoint() *Endpoint/g' \ + -e 's/TokenEndpoint() Endpoint/TokenEndpoint() *Endpoint/g' \ + -e 's/IntrospectionEndpoint() Endpoint/IntrospectionEndpoint() *Endpoint/g' \ + -e 's/UserinfoEndpoint() Endpoint/UserinfoEndpoint() *Endpoint/g' \ + -e 's/RevocationEndpoint() Endpoint/RevocationEndpoint() *Endpoint/g' \ + -e 's/EndSessionEndpoint() Endpoint/EndSessionEndpoint() *Endpoint/g' \ + -e 's/KeysEndpoint() Endpoint/KeysEndpoint() *Endpoint/g' \ + -e 's/DeviceAuthorizationEndpoint() Endpoint/DeviceAuthorizationEndpoint() *Endpoint/g' +``` + +#### CreateDiscoveryConfig + +`CreateDiscoveryConfig` now takes a context as first argument. The following adds `context.TODO()` to the function: + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/op\.CreateDiscoveryConfig(/op.CreateDiscoveryConfig(context.TODO(), /g' +``` + +It now takes the issuer out of the context using the [`IssuerFromContext`](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/op#IssuerFromContext) functionality, +instead of the `config.IssuerFromRequest()` method. + +#### CreateRouter + +`CreateRouter` now returns a `chi.Router` instead of `*mux.Router`. +Usually this function is called when the Provider is constructed and not by package consumers. +However if your project does call this function directly, manual update of the code is required. + +#### DeviceAuthorizationStorage + +`DeviceAuthorizationStorage` dropped the following methods: + +- `GetDeviceAuthorizationByUserCode` +- `CompleteDeviceAuthorization` +- `DenyDeviceAuthorization` + +These methods proved not to be required from a library point of view. +Implementations of a device authorization flow may take care of these calls in a way they see fit. + +#### AuthorizeCodeChallenge + +The `AuthorizeCodeChallenge` function now only takes the `CodeVerifier` argument, instead of the complete `*oidc.AccessTokenRequest`. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/op\.AuthorizeCodeChallenge(tokenReq/op.AuthorizeCodeChallenge(tokenReq.CodeVerifier/g' +``` + +### client + +```go +import "github.com/zitadel/oidc/v3/pkg/client" +``` + +#### Context + +All client calls now take a context as first argument. The following adds `context.TODO()` to all the affected functions: + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/client\.Discover(/client.Discover(context.TODO(), /g' \ + -e 's/client\.CallTokenEndpoint(/client.CallTokenEndpoint(context.TODO(), /g' \ + -e 's/client\.CallEndSessionEndpoint(/client.CallEndSessionEndpoint(context.TODO(), /g' \ + -e 's/client\.CallRevokeEndpoint(/client.CallRevokeEndpoint(context.TODO(), /g' \ + -e 's/client\.CallTokenExchangeEndpoint(/client.CallTokenExchangeEndpoint(context.TODO(), /g' \ + -e 's/client\.CallDeviceAuthorizationEndpoint(/client.CallDeviceAuthorizationEndpoint(context.TODO(), /g' \ + -e 's/client\.JWTProfileExchange(/client.JWTProfileExchange(context.TODO(), /g' +``` + +#### keyFile type + +The `keyFile` struct type is now exported a `KeyFile` and returned by the `ConfigFromKeyFile` and `ConfigFromKeyFileData`. No changes are needed on the caller's side. + +### client/profile + +The package now defines a new interface `TokenSource` which compliments the `oauth2.TokenSource` with a `TokenCtx` method, so that a context can be explicitly added on each call. Users can migrate to the new method when they whish. + +`NewJWTProfileTokenSource` now takes a context as first argument, so do the related `NewJWTProfileTokenSourceFromKeyFile` and `NewJWTProfileTokenSourceFromKeyFileData`. The context is used for the Discovery request. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/profile\.NewJWTProfileTokenSource(/profile.NewJWTProfileTokenSource(context.TODO(), /g' \ + -e 's/profile\.NewJWTProfileTokenSourceFromKeyFileData(/profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), /g' \ + -e 's/profile\.NewJWTProfileTokenSourceFromKeyFile(/profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), /g' +``` + + +### client/rp + +```go +import "github.com/zitadel/oidc/v3/pkg/client/rs" +``` + +#### Discover + +The `Discover` function has been removed. Use `client.Discover` instead. + +#### Context + +Most `rp` functions now require a context as first argument. The following adds `context.TODO()` to the function that have no additional changes. Functions with more complex changes are documented below. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/rp\.NewRelyingPartyOIDC(/rp.NewRelyingPartyOIDC(context.TODO(), /g' \ + -e 's/rp\.EndSession(/rp.EndSession(context.TODO(), /g' \ + -e 's/rp\.RevokeToken(/rp.RevokeToken(context.TODO(), /g' \ + -e 's/rp\.DeviceAuthorization(/rp.DeviceAuthorization(context.TODO(), /g' +``` + +Remember to replace `context.TODO()` with a context that is applicable for your app, where possible. + +#### RefreshAccessToken + +1. Renamed to `RefreshTokens`; +2. A context must be passed; +3. An `*oidc.Tokens` object is now returned, which included an ID Token if it was returned by the server; +4. The function is now generic and requires a type argument for the `IDTokenClaims` implementation inside the returned `oidc.Tokens` object; + +For most use cases `*oidc.IDTokenClaims` can be used as type argument. A custom implementation of `oidc.IDClaims` can be used if type-safe access to custom claims is required. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/rp\.RefreshAccessToken(/rp.RefreshTokens[*oidc.IDTokenClaims](context.TODO(), /g' +``` + +Users that called `tokens.Extra("id_token").(string)` and a subsequent `VerifyTokens` to get the claims, no longer need to do this. The ID token is verified (when present) by `RefreshTokens` already. + + +#### Userinfo + +1. A context must be passed as first argument; +2. The function is now generic and requires a type argument for the returned user info object; + +For most use cases `*oidc.UserInfo` can be used a type argument. A [custom implementation](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/client/rp#example-Userinfo-Custom) of `rp.SubjectGetter` can be used if type-safe access to custom claims is required. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/rp\.Userinfo(/rp.Userinfo[*oidc.UserInfo](context.TODO(), /g' +``` + +#### UserinfoCallback + +`UserinfoCallback` has an additional type argument fot the `UserInfo` object. Typically the type argument can be inferred by the compiler, by the function that is passed. The actual code update cannot be done by a simple `sed` script and depends on how the caller implemented the function. + + +#### IDTokenVerifier + +`IDTokenVerifier` interface has become a struct type. `NewIDTokenVerifier` now returns a pointer to `IDTokenVerifier`. +Variable and struct fields declarations need to be changed from `rp.IDTokenVerifier` to `*rp.AccessTokenVerifier`. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/\brp\.IDTokenVerifier\b/*rp.IDTokenVerifier/g' +``` + +### client/rs + +```go +import "github.com/zitadel/oidc/v3/pkg/client/rs" +``` + +#### NewResourceServer + +The `NewResourceServerClientCredentials` and `NewResourceServerJWTProfile` constructor functions now take a context as first argument. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/rs\.NewResourceServerClientCredentials(/rs.NewResourceServerClientCredentials(context.TODO(), /g' \ + -e 's/rs\.NewResourceServerJWTProfile(/rs.NewResourceServerJWTProfile(context.TODO(), /g' +``` + +#### Introspect + +`Introspect` is now generic and requires a type argument for the returned introspection response. For most use cases `*oidc.IntrospectionResponse` can be used as type argument. Any other response type if type-safe access to [custom claims](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/client/rs#example-Introspect-Custom) is required. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/rs\.Introspect(/rs.Introspect[*oidc.IntrospectionResponse](/g' +``` + +### client/tokenexchange + +The `TokenExchanger` constructor functions `NewTokenExchanger` and `NewTokenExchangerClientCredentials` now take a context as first argument. +As well as the `ExchangeToken` function. + +```bash +find . -type f -name '*.go' | xargs sed -i \ + -e 's/tokenexchange\.NewTokenExchanger(/tokenexchange.NewTokenExchanger(context.TODO(), /g' \ + -e 's/tokenexchange\.NewTokenExchangerClientCredentials(/tokenexchange.NewTokenExchangerClientCredentials(context.TODO(), /g' \ + -e 's/tokenexchange\.ExchangeToken(/tokenexchange.ExchangeToken(context.TODO(), /g' +``` + +### oidc + +#### SpaceDelimitedArray + +The `SpaceDelimitedArray` type's `Encode()` function has been renamed to `String()` so it implements the `fmt.Stringer` interface. If the `Encode` method was called by a package consumer, it should be changed manually. + +#### Verifier + +The `Verifier` interface as been changed into a struct type. The struct type is aliased in the `op` and `rp` packages for the specific token use cases. See the relevant section above. + +### Full script + +For the courageous this is the full `sed` script which combines all the steps described above. +It should migrate most of the code in a repository to a more-or-less compilable state, +using defaults such as `context.TODO()` where possible. + +Warnings: +- Again, this is written for **GNU sed** not the posix variant. +- Assumes imports that use the package names, not aliases. +- Do this on a project with version control (eg Git), that allows you to rollback if things went wrong. +- The script has been tested on the [ZITADEL](https://github.com/zitadel/zitadel) project, but we do not use all affected symbols. Parts of the script are mere guesswork. + +```bash +go get -u github.com/zitadel/oidc/v3 +find . -type f -name '*.go' | xargs sed -i \ + -e 's/github\.com\/zitadel\/oidc\/v2/github.com\/zitadel\/oidc\/v3/g' \ + -e 's/gopkg.in\/square\/go-jose\.v2/github.com\/go-jose\/go-jose\/v3/g' \ + -e 's/\bAuthRequestError(w, r, authReq, err, authorizer.Encoder())/AuthRequestError(w, r, authReq, err, authorizer)/g' \ + -e 's/\bop\.AccessTokenVerifier\b/*op.AccessTokenVerifier/g' \ + -e 's/\bop\.JWTProfileVerifier\b/*op.JWTProfileVerifier/g' \ + -e 's/\bop\.IDTokenHintVerifier\b/*op.IDTokenHintVerifier/g' \ + -e 's/AuthorizationEndpoint() Endpoint/AuthorizationEndpoint() *Endpoint/g' \ + -e 's/TokenEndpoint() Endpoint/TokenEndpoint() *Endpoint/g' \ + -e 's/IntrospectionEndpoint() Endpoint/IntrospectionEndpoint() *Endpoint/g' \ + -e 's/UserinfoEndpoint() Endpoint/UserinfoEndpoint() *Endpoint/g' \ + -e 's/RevocationEndpoint() Endpoint/RevocationEndpoint() *Endpoint/g' \ + -e 's/EndSessionEndpoint() Endpoint/EndSessionEndpoint() *Endpoint/g' \ + -e 's/KeysEndpoint() Endpoint/KeysEndpoint() *Endpoint/g' \ + -e 's/DeviceAuthorizationEndpoint() Endpoint/DeviceAuthorizationEndpoint() *Endpoint/g' \ + -e 's/op\.CreateDiscoveryConfig(/op.CreateDiscoveryConfig(context.TODO(), /g' \ + -e 's/op\.AuthorizeCodeChallenge(tokenReq/op.AuthorizeCodeChallenge(tokenReq.CodeVerifier/g' \ + -e 's/client\.Discover(/client.Discover(context.TODO(), /g' \ + -e 's/client\.CallTokenEndpoint(/client.CallTokenEndpoint(context.TODO(), /g' \ + -e 's/client\.CallEndSessionEndpoint(/client.CallEndSessionEndpoint(context.TODO(), /g' \ + -e 's/client\.CallRevokeEndpoint(/client.CallRevokeEndpoint(context.TODO(), /g' \ + -e 's/client\.CallTokenExchangeEndpoint(/client.CallTokenExchangeEndpoint(context.TODO(), /g' \ + -e 's/client\.CallDeviceAuthorizationEndpoint(/client.CallDeviceAuthorizationEndpoint(context.TODO(), /g' \ + -e 's/client\.JWTProfileExchange(/client.JWTProfileExchange(context.TODO(), /g' \ + -e 's/profile\.NewJWTProfileTokenSource(/profile.NewJWTProfileTokenSource(context.TODO(), /g' \ + -e 's/profile\.NewJWTProfileTokenSourceFromKeyFileData(/profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), /g' \ + -e 's/profile\.NewJWTProfileTokenSourceFromKeyFile(/profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), /g' \ + -e 's/rp\.NewRelyingPartyOIDC(/rp.NewRelyingPartyOIDC(context.TODO(), /g' \ + -e 's/rp\.EndSession(/rp.EndSession(context.TODO(), /g' \ + -e 's/rp\.RevokeToken(/rp.RevokeToken(context.TODO(), /g' \ + -e 's/rp\.DeviceAuthorization(/rp.DeviceAuthorization(context.TODO(), /g' \ + -e 's/rp\.RefreshAccessToken(/rp.RefreshTokens[*oidc.IDTokenClaims](context.TODO(), /g' \ + -e 's/rp\.Userinfo(/rp.Userinfo[*oidc.UserInfo](context.TODO(), /g' \ + -e 's/\brp\.IDTokenVerifier\b/*rp.IDTokenVerifier/g' \ + -e 's/rs\.NewResourceServerClientCredentials(/rs.NewResourceServerClientCredentials(context.TODO(), /g' \ + -e 's/rs\.NewResourceServerJWTProfile(/rs.NewResourceServerJWTProfile(context.TODO(), /g' \ + -e 's/rs\.Introspect(/rs.Introspect[*oidc.IntrospectionResponse](/g' \ + -e 's/tokenexchange\.NewTokenExchanger(/tokenexchange.NewTokenExchanger(context.TODO(), /g' \ + -e 's/tokenexchange\.NewTokenExchangerClientCredentials(/tokenexchange.NewTokenExchangerClientCredentials(context.TODO(), /g' \ + -e 's/tokenexchange\.ExchangeToken(/tokenexchange.ExchangeToken(context.TODO(), /g' +go mod tidy +``` \ No newline at end of file diff --git a/example/client/api/api.go b/example/client/api/api.go index 8093b63..69f9466 100644 --- a/example/client/api/api.go +++ b/example/client/api/api.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "log" @@ -9,11 +10,11 @@ import ( "strings" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/client/rs" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rs" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) const ( @@ -27,12 +28,12 @@ func main() { port := os.Getenv("PORT") issuer := os.Getenv("ISSUER") - provider, err := rs.NewResourceServerFromKeyFile(issuer, keyPath) + provider, err := rs.NewResourceServerFromKeyFile(context.TODO(), issuer, keyPath) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } - router := mux.NewRouter() + router := chi.NewRouter() // public url accessible without any authorization // will print `OK` and current timestamp @@ -47,7 +48,7 @@ func main() { if !ok { return } - resp, err := rs.Introspect(r.Context(), provider, token) + resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token) if err != nil { http.Error(w, err.Error(), http.StatusForbidden) return @@ -68,14 +69,14 @@ func main() { if !ok { return } - resp, err := rs.Introspect(r.Context(), provider, token) + resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token) if err != nil { http.Error(w, err.Error(), http.StatusForbidden) return } - params := mux.Vars(r) - requestedClaim := params["claim"] - requestedValue := params["value"] + requestedClaim := chi.URLParam(r, "claim") + requestedValue := chi.URLParam(r, "value") + value, ok := resp.Claims[requestedClaim].(string) if !ok || value == "" || value != requestedValue { http.Error(w, "claim does not match", http.StatusForbidden) diff --git a/example/client/app/app.go b/example/client/app/app.go index 0c324d2..90b1969 100644 --- a/example/client/app/app.go +++ b/example/client/app/app.go @@ -1,19 +1,23 @@ package main import ( + "context" "encoding/json" "fmt" + "log/slog" "net/http" "os" "strings" + "sync/atomic" "time" "github.com/google/uuid" "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/zitadel/logging" ) var ( @@ -28,13 +32,31 @@ func main() { issuer := os.Getenv("ISSUER") port := os.Getenv("PORT") scopes := strings.Split(os.Getenv("SCOPES"), " ") + responseMode := os.Getenv("RESPONSE_MODE") redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath) cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) + logger := slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + ) + client := &http.Client{ + Timeout: time.Minute, + } + // enable outgoing request logging + logging.EnableHTTPClient(client, + logging.WithClientGroup("client"), + ) + options := []rp.Option{ rp.WithCookieHandler(cookieHandler), rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)), + rp.WithHTTPClient(client), + rp.WithLogger(logger), + rp.WithSigningAlgsFromDiscovery(), } if clientSecret == "" { options = append(options, rp.WithPKCE(cookieHandler)) @@ -43,7 +65,10 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, scopes, options...) + // One can add a logger to the context, + // pre-defining log attributes as required. + ctx := logging.ToContext(context.TODO(), logger) + provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } @@ -54,20 +79,37 @@ func main() { return uuid.New().String() } + urlOptions := []rp.URLParamOpt{ + rp.WithPromptURLParam("Welcome back!"), + } + + if responseMode != "" { + urlOptions = append(urlOptions, rp.WithResponseModeURLParam(oidc.ResponseMode(responseMode))) + } + // register the AuthURLHandler at your preferred path. // the AuthURLHandler creates the auth request and redirects the user to the auth server. // including state handling with secure cookie and the possibility to use PKCE. // Prompts can optionally be set to inform the server of // any messages that need to be prompted back to the user. - http.Handle("/login", rp.AuthURLHandler(state, provider, rp.WithPromptURLParam("Welcome back!"))) + http.Handle("/login", rp.AuthURLHandler( + state, + provider, + urlOptions..., + )) // for demonstration purposes the returned userinfo response is written as JSON object onto response marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + fmt.Println("access token", tokens.AccessToken) + fmt.Println("refresh token", tokens.RefreshToken) + fmt.Println("id token", tokens.IDToken) + data, err := json.Marshal(info) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } + w.Header().Set("content-type", "application/json") w.Write(data) } @@ -118,8 +160,22 @@ func main() { // // http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider)) + // simple counter for request IDs + var counter atomic.Int64 + // enable incomming request logging + mw := logging.Middleware( + logging.WithLogger(logger), + logging.WithGroup("server"), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + ) + lis := fmt.Sprintf("127.0.0.1:%s", port) - logrus.Infof("listening on http://%s/", lis) - logrus.Info("press ctrl+c to stop") - logrus.Fatal(http.ListenAndServe(lis, nil)) + logger.Info("server listening, press ctrl+c to stop", "addr", lis) + err = http.ListenAndServe(lis, mw(http.DefaultServeMux)) + if err != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) + } } diff --git a/example/client/device/device.go b/example/client/device/device.go index 284ba37..33bc570 100644 --- a/example/client/device/device.go +++ b/example/client/device/device.go @@ -1,3 +1,37 @@ +// Command device is an example Oauth2 Device Authorization Grant app. +// It creates a new Device Authorization request on the Issuer and then polls for tokens. +// The user is then prompted to visit a URL and enter the user code. +// Or, the complete URL can be used instead to omit manual entry. +// In practice then can be a "magic link" in the form or a QR. +// +// The following environment variables are used for configuration: +// +// ISSUER: URL to the OP, required. +// CLIENT_ID: ID of the application, required. +// CLIENT_SECRET: Secret to authenticate the app using basic auth. Only required if the OP expects this type of authentication. +// KEY_PATH: Path to a private key file, used to for JWT authentication of the App. Only required if the OP expects this type of authentication. +// SCOPES: Scopes of the Authentication Request. Optional. +// +// Basic usage: +// +// cd example/client/device +// export ISSUER="http://localhost:9000" CLIENT_ID="246048465824634593@demo" +// +// Get an Access Token: +// +// SCOPES="email profile" go run . +// +// Get an Access Token and ID Token: +// +// SCOPES="email profile openid" go run . +// +// Get an Access Token and Refresh Token +// +// SCOPES="email profile offline_access" go run . +// +// Get Access, Refresh and ID Tokens: +// +// SCOPES="email profile offline_access openid" go run . package main import ( @@ -11,8 +45,8 @@ import ( "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v2/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" ) var ( @@ -39,13 +73,13 @@ func main() { options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath))) } - provider, err := rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, "", scopes, options...) + provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, "", scopes, options...) if err != nil { logrus.Fatalf("error creating provider %s", err.Error()) } logrus.Info("starting device authorization flow") - resp, err := rp.DeviceAuthorization(scopes, provider) + resp, err := rp.DeviceAuthorization(ctx, scopes, provider, nil) if err != nil { logrus.Fatal(err) } @@ -57,5 +91,5 @@ func main() { if err != nil { logrus.Fatal(err) } - logrus.Infof("successfully obtained token: %v", token) + logrus.Infof("successfully obtained token: %#v", token) } diff --git a/example/client/github/github.go b/example/client/github/github.go index 9cb813c..f6c536b 100644 --- a/example/client/github/github.go +++ b/example/client/github/github.go @@ -10,10 +10,10 @@ import ( "golang.org/x/oauth2" githubOAuth "golang.org/x/oauth2/github" - "github.com/zitadel/oidc/v2/pkg/client/rp" - "github.com/zitadel/oidc/v2/pkg/client/rp/cli" - "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp/cli" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) var ( diff --git a/example/client/service/service.go b/example/client/service/service.go index 9526174..a88ab2f 100644 --- a/example/client/service/service.go +++ b/example/client/service/service.go @@ -13,7 +13,7 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/pkg/client/profile" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/profile" ) var client = http.DefaultClient @@ -25,7 +25,7 @@ func main() { scopes := strings.Split(os.Getenv("SCOPES"), " ") if keyPath != "" { - ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath, scopes) + ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), issuer, keyPath, scopes) if err != nil { logrus.Fatalf("error creating token source %s", err.Error()) } @@ -76,7 +76,7 @@ func main() { http.Error(w, err.Error(), http.StatusInternalServerError) return } - ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(issuer, key, scopes) + ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), issuer, key, scopes) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -125,7 +125,7 @@ func main() { testURL := r.Form.Get("url") var data struct { URL string - Response interface{} + Response any } if testURL != "" { data.URL = testURL @@ -149,7 +149,7 @@ func main() { logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil)) } -func callExampleEndpoint(client *http.Client, testURL string) (interface{}, error) { +func callExampleEndpoint(client *http.Client, testURL string) (any, error) { req, err := http.NewRequest("GET", testURL, nil) if err != nil { return nil, err diff --git a/example/server/config/config.go b/example/server/config/config.go new file mode 100644 index 0000000..96837d4 --- /dev/null +++ b/example/server/config/config.go @@ -0,0 +1,40 @@ +package config + +import ( + "os" + "strings" +) + +const ( + // default port for the http server to run + DefaultIssuerPort = "9998" +) + +type Config struct { + Port string + RedirectURI []string + UsersFile string +} + +// FromEnvVars loads configuration parameters from environment variables. +// If there is no such variable defined, then use default values. +func FromEnvVars(defaults *Config) *Config { + if defaults == nil { + defaults = &Config{} + } + cfg := &Config{ + Port: defaults.Port, + RedirectURI: defaults.RedirectURI, + UsersFile: defaults.UsersFile, + } + if value, ok := os.LookupEnv("PORT"); ok { + cfg.Port = value + } + if value, ok := os.LookupEnv("USERS_FILE"); ok { + cfg.UsersFile = value + } + if value, ok := os.LookupEnv("REDIRECT_URI"); ok { + cfg.RedirectURI = strings.Split(value, ",") + } + return cfg +} diff --git a/example/server/config/config_test.go b/example/server/config/config_test.go new file mode 100644 index 0000000..3b73c0b --- /dev/null +++ b/example/server/config/config_test.go @@ -0,0 +1,77 @@ +package config + +import ( + "fmt" + "os" + "testing" +) + +func TestFromEnvVars(t *testing.T) { + + for _, tc := range []struct { + name string + env map[string]string + defaults *Config + want *Config + }{ + { + name: "no vars, no default values", + env: map[string]string{}, + want: &Config{}, + }, + { + name: "no vars, only defaults", + env: map[string]string{}, + defaults: &Config{ + Port: "6666", + UsersFile: "/default/user/path", + RedirectURI: []string{"re", "direct", "uris"}, + }, + want: &Config{ + Port: "6666", + UsersFile: "/default/user/path", + RedirectURI: []string{"re", "direct", "uris"}, + }, + }, + { + name: "overriding default values", + env: map[string]string{ + "PORT": "1234", + "USERS_FILE": "/path/to/users", + "REDIRECT_URI": "http://redirect/redirect", + }, + defaults: &Config{ + Port: "6666", + UsersFile: "/default/user/path", + RedirectURI: []string{"re", "direct", "uris"}, + }, + want: &Config{ + Port: "1234", + UsersFile: "/path/to/users", + RedirectURI: []string{"http://redirect/redirect"}, + }, + }, + { + name: "multiple redirect uris", + env: map[string]string{ + "REDIRECT_URI": "http://host_1,http://host_2,http://host_3", + }, + want: &Config{ + RedirectURI: []string{ + "http://host_1", "http://host_2", "http://host_3", + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + os.Clearenv() + for k, v := range tc.env { + os.Setenv(k, v) + } + cfg := FromEnvVars(tc.defaults) + if fmt.Sprint(cfg) != fmt.Sprint(tc.want) { + t.Errorf("Expected FromEnvVars()=%q, but got %q", tc.want, cfg) + } + }) + } +} diff --git a/example/server/dynamic/login.go b/example/server/dynamic/login.go index e7c6e5f..05f0e34 100644 --- a/example/server/dynamic/login.go +++ b/example/server/dynamic/login.go @@ -6,9 +6,9 @@ import ( "html/template" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) const ( @@ -43,7 +43,7 @@ var ( type login struct { authenticate authenticate - router *mux.Router + router chi.Router callback func(context.Context, string) string } @@ -57,9 +57,9 @@ func NewLogin(authenticate authenticate, callback func(context.Context, string) } func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) { - l.router = mux.NewRouter() - l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) - l.router.Path("/username").Methods("POST").HandlerFunc(issuerInterceptor.HandlerFunc(l.checkLoginHandler)) + l.router = chi.NewRouter() + l.router.Get("/username", l.loginHandler) + l.router.With(issuerInterceptor.Handler).Post("/username", l.checkLoginHandler) } type authenticate interface { diff --git a/example/server/dynamic/op.go b/example/server/dynamic/op.go index 783c75c..2c00e41 100644 --- a/example/server/dynamic/op.go +++ b/example/server/dynamic/op.go @@ -7,11 +7,11 @@ import ( "log" "net/http" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) const ( @@ -47,7 +47,7 @@ func main() { //be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) - router := mux.NewRouter() + router := chi.NewRouter() //for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { @@ -76,7 +76,7 @@ func main() { //regardless of how many pages / steps there are in the process, the UI must be registered in the router, //so we will direct all calls to /login to the login UI - router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + router.Mount("/login/", http.StripPrefix("/login", l.router)) //we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) //is served on the correct path @@ -84,7 +84,7 @@ func main() { //if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), //then you would have to set the path prefix (/custom/path/): //router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler())) - router.PathPrefix("/").Handler(provider.HttpHandler()) + router.Mount("/", provider) server := &http.Server{ Addr: ":" + port, diff --git a/example/server/exampleop/device.go b/example/server/exampleop/device.go index ae2e8f2..99505e4 100644 --- a/example/server/exampleop/device.go +++ b/example/server/exampleop/device.go @@ -1,21 +1,34 @@ package exampleop import ( + "context" "errors" "fmt" "io" "net/http" "net/url" - "github.com/gorilla/mux" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "github.com/go-chi/chi/v5" "github.com/gorilla/securecookie" "github.com/sirupsen/logrus" - "github.com/zitadel/oidc/v2/pkg/op" ) type deviceAuthenticate interface { CheckUsernamePasswordSimple(username, password string) error op.DeviceAuthorizationStorage + + // GetDeviceAuthorizationByUserCode resturns the current state of the device authorization flow, + // identified by the user code. + GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) + + // CompleteDeviceAuthorization marks a device authorization entry as Completed, + // identified by userCode. The Subject is added to the state, so that + // GetDeviceAuthorizatonState can use it to create a new Access Token. + CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error + + // DenyDeviceAuthorization marks a device authorization entry as Denied. + DenyDeviceAuthorization(ctx context.Context, userCode string) error } type deviceLogin struct { @@ -23,14 +36,14 @@ type deviceLogin struct { cookie *securecookie.SecureCookie } -func registerDeviceAuth(storage deviceAuthenticate, router *mux.Router) { +func registerDeviceAuth(storage deviceAuthenticate, router chi.Router) { l := &deviceLogin{ storage: storage, cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil), } - router.HandleFunc("", l.userCodeHandler) - router.Path("/login").Methods(http.MethodPost).HandlerFunc(l.loginHandler) + router.HandleFunc("/", l.userCodeHandler) + router.Post("/login", l.loginHandler) router.HandleFunc("/confirm", l.confirmHandler) } diff --git a/example/server/exampleop/login.go b/example/server/exampleop/login.go index c014c9a..77a6189 100644 --- a/example/server/exampleop/login.go +++ b/example/server/exampleop/login.go @@ -5,28 +5,29 @@ import ( "fmt" "net/http" - "github.com/gorilla/mux" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "github.com/go-chi/chi/v5" ) type login struct { authenticate authenticate - router *mux.Router + router chi.Router callback func(context.Context, string) string } -func NewLogin(authenticate authenticate, callback func(context.Context, string) string) *login { +func NewLogin(authenticate authenticate, callback func(context.Context, string) string, issuerInterceptor *op.IssuerInterceptor) *login { l := &login{ authenticate: authenticate, callback: callback, } - l.createRouter() + l.createRouter(issuerInterceptor) return l } -func (l *login) createRouter() { - l.router = mux.NewRouter() - l.router.Path("/username").Methods("GET").HandlerFunc(l.loginHandler) - l.router.Path("/username").Methods("POST").HandlerFunc(l.checkLoginHandler) +func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) { + l.router = chi.NewRouter() + l.router.Get("/username", l.loginHandler) + l.router.Post("/username", issuerInterceptor.HandlerFunc(l.checkLoginHandler)) } type authenticate interface { diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go index 5604483..e12c755 100644 --- a/example/server/exampleop/op.go +++ b/example/server/exampleop/op.go @@ -3,75 +3,83 @@ package exampleop import ( "crypto/sha256" "log" + "log/slog" "net/http" + "sync/atomic" "time" - "github.com/gorilla/mux" + "github.com/go-chi/chi/v5" + "github.com/zitadel/logging" "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) const ( pathLoggedOut = "/logged-out" ) -func init() { - storage.RegisterClients( - storage.NativeClient("native"), - storage.WebClient("web", "secret"), - storage.WebClient("api", "secret"), - ) -} - type Storage interface { op.Storage authenticate deviceAuthenticate } +// simple counter for request IDs +var counter atomic.Int64 + // SetupServer creates an OIDC server with Issuer=http://localhost: // // Use one of the pre-made clients in storage/clients.go or register a new one. -func SetupServer(issuer string, storage Storage) *mux.Router { +func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool, extraOptions ...op.Option) chi.Router { // the OpenID Provider requires a 32-byte key for (token) encryption // be sure to create a proper crypto random key and manage it securely! key := sha256.Sum256([]byte("test")) - router := mux.NewRouter() + router := chi.NewRouter() + router.Use(logging.Middleware( + logging.WithLogger(logger), + logging.WithIDFunc(func() slog.Attr { + return slog.Int64("id", counter.Add(1)) + }), + )) // for simplicity, we provide a very small default page for users who have signed out router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) { - _, err := w.Write([]byte("signed out successfully")) - if err != nil { - log.Printf("error serving logged out page: %v", err) - } + w.Write([]byte("signed out successfully")) + // no need to check/log error, this will be handled by the middleware. }) // creation of the OpenIDProvider with the just created in-memory Storage - provider, err := newOP(storage, issuer, key) + provider, err := newOP(storage, issuer, key, logger, extraOptions...) if err != nil { log.Fatal(err) } - // the provider will only take care of the OpenID Protocol, so there must be some sort of UI for the login process - // for the simplicity of the example this means a simple page with username and password field - l := NewLogin(storage, op.AuthCallbackURL(provider)) + //the provider will only take care of the OpenID Protocol, so there must be some sort of UI for the login process + //for the simplicity of the example this means a simple page with username and password field + //be sure to provide an IssuerInterceptor with the IssuerFromRequest from the OP so the login can select / and pass it to the storage + l := NewLogin(storage, op.AuthCallbackURL(provider), op.NewIssuerInterceptor(provider.IssuerFromRequest)) // regardless of how many pages / steps there are in the process, the UI must be registered in the router, // so we will direct all calls to /login to the login UI - router.PathPrefix("/login/").Handler(http.StripPrefix("/login", l.router)) + router.Mount("/login/", http.StripPrefix("/login", l.router)) - router.PathPrefix("/device").Subrouter() - registerDeviceAuth(storage, router.PathPrefix("/device").Subrouter()) + router.Route("/device", func(r chi.Router) { + registerDeviceAuth(storage, r) + }) + + handler := http.Handler(provider) + if wrapServer { + handler = op.RegisterLegacyServer(op.NewLegacyServer(provider, *op.DefaultEndpoints), op.AuthorizeCallbackHandler(provider)) + } // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration) // is served on the correct path // // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/), // then you would have to set the path prefix (/custom/path/) - router.PathPrefix("/").Handler(provider.HttpHandler()) + router.Mount("/", handler) return router } @@ -79,7 +87,7 @@ func SetupServer(issuer string, storage Storage) *mux.Router { // newOP will create an OpenID Provider for localhost on a specified port with a given encryption key // and a predefined default logout uri // it will enable all options (see descriptions) -func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, error) { +func newOP(storage op.Storage, issuer string, key [32]byte, logger *slog.Logger, extraOptions ...op.Option) (op.OpenIDProvider, error) { config := &op.Config{ CryptoKey: key, @@ -107,15 +115,19 @@ func newOP(storage op.Storage, issuer string, key [32]byte) (op.OpenIDProvider, DeviceAuthorization: op.DeviceAuthorizationConfig{ Lifetime: 5 * time.Minute, PollInterval: 5 * time.Second, - UserFormURL: issuer + "device", + UserFormPath: "/device", UserCode: op.UserCodeBase20, }, } handler, err := op.NewOpenIDProvider(issuer, config, storage, - //we must explicitly allow the use of the http issuer - op.WithAllowInsecure(), - // as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth - op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), + append([]op.Option{ + //we must explicitly allow the use of the http issuer + op.WithAllowInsecure(), + // as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth + op.WithCustomAuthEndpoint(op.NewEndpoint("auth")), + // Pass our logger to the OP + op.WithLogger(logger.WithGroup("op")), + }, extraOptions...)..., ) if err != nil { return nil, err diff --git a/example/server/main.go b/example/server/main.go index a2836ea..5bdbb05 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -2,34 +2,58 @@ package main import ( "fmt" - "log" + "log/slog" "net/http" + "os" - "github.com/zitadel/oidc/v2/example/server/exampleop" - "github.com/zitadel/oidc/v2/example/server/storage" + "git.christmann.info/LARA/zitadel-oidc/v3/example/server/config" + "git.christmann.info/LARA/zitadel-oidc/v3/example/server/exampleop" + "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage" ) +func getUserStore(cfg *config.Config) (storage.UserStore, error) { + if cfg.UsersFile == "" { + return storage.NewUserStore(fmt.Sprintf("http://localhost:%s/", cfg.Port)), nil + } + return storage.StoreFromFile(cfg.UsersFile) +} + func main() { - //we will run on :9998 - port := "9998" + cfg := config.FromEnvVars(&config.Config{Port: "9998"}) + logger := slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + ) + //which gives us the issuer: http://localhost:9998/ - issuer := fmt.Sprintf("http://localhost:%s/", port) + issuer := fmt.Sprintf("http://localhost:%s/", cfg.Port) + + storage.RegisterClients( + storage.NativeClient("native", cfg.RedirectURI...), + storage.WebClient("web", "secret", cfg.RedirectURI...), + storage.WebClient("api", "secret", cfg.RedirectURI...), + ) // the OpenIDProvider interface needs a Storage interface handling various checks and state manipulations // this might be the layer for accessing your database // in this example it will be handled in-memory - storage := storage.NewStorage(storage.NewUserStore(issuer)) - - router := exampleop.SetupServer(issuer, storage) + store, err := getUserStore(cfg) + if err != nil { + logger.Error("cannot create UserStore", "error", err) + os.Exit(1) + } + storage := storage.NewStorage(store) + router := exampleop.SetupServer(issuer, storage, logger, false) server := &http.Server{ - Addr: ":" + port, + Addr: ":" + cfg.Port, Handler: router, } - log.Printf("server listening on http://localhost:%s/", port) - log.Println("press ctrl+c to stop") - err := server.ListenAndServe() - if err != nil { - log.Fatal(err) + logger.Info("server listening, press ctrl+c to stop", "addr", issuer) + if server.ListenAndServe() != http.ErrServerClosed { + logger.Error("server terminated", "error", err) + os.Exit(1) } } diff --git a/example/server/storage/client.go b/example/server/storage/client.go index b8b9960..2b836c0 100644 --- a/example/server/storage/client.go +++ b/example/server/storage/client.go @@ -3,8 +3,8 @@ package storage import ( "time" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) var ( @@ -184,8 +184,26 @@ func WebClient(id, secret string, redirectURIs ...string) *Client { applicationType: op.ApplicationTypeWeb, authMethod: oidc.AuthMethodBasic, loginURL: defaultLoginURL, + responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode, oidc.ResponseTypeIDTokenOnly, oidc.ResponseTypeIDToken}, + grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken, oidc.GrantTypeTokenExchange}, + accessTokenType: op.AccessTokenTypeBearer, + devMode: true, + idTokenUserinfoClaimsAssertion: false, + clockSkew: 0, + } +} + +// DeviceClient creates a device client with Basic authentication. +func DeviceClient(id, secret string) *Client { + return &Client{ + id: id, + secret: secret, + redirectURIs: nil, + applicationType: op.ApplicationTypeWeb, + authMethod: oidc.AuthMethodBasic, + loginURL: defaultLoginURL, responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, - grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken}, + grantTypes: []oidc.GrantType{oidc.GrantTypeDeviceCode}, accessTokenType: op.AccessTokenTypeBearer, devMode: false, idTokenUserinfoClaimsAssertion: false, diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go index f5412cf..9c7f544 100644 --- a/example/server/storage/oidc.go +++ b/example/server/storage/oidc.go @@ -1,12 +1,13 @@ package storage import ( + "log/slog" "time" "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) const ( @@ -34,6 +35,7 @@ type AuthRequest struct { UserID string Scopes []string ResponseType oidc.ResponseType + ResponseMode oidc.ResponseMode Nonce string CodeChallenge *OIDCCodeChallenge @@ -41,6 +43,19 @@ type AuthRequest struct { authTime time.Time } +// LogValue allows you to define which fields will be logged. +// Implements the [slog.LogValuer] +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.String("id", a.ID), + slog.Time("creation_date", a.CreationDate), + slog.Any("scopes", a.Scopes), + slog.String("response_type", string(a.ResponseType)), + slog.String("app_id", a.ApplicationID), + slog.String("callback_uri", a.CallbackURI), + ) +} + func (a *AuthRequest) GetID() string { return a.ID } @@ -86,7 +101,7 @@ func (a *AuthRequest) GetResponseType() oidc.ResponseType { } func (a *AuthRequest) GetResponseMode() oidc.ResponseMode { - return "" // we won't handle response mode in this example + return a.ResponseMode } func (a *AuthRequest) GetScopes() []string { @@ -106,7 +121,7 @@ func (a *AuthRequest) Done() bool { } func PromptToInternal(oidcPrompt oidc.SpaceDelimitedArray) []string { - prompts := make([]string, len(oidcPrompt)) + prompts := make([]string, 0, len(oidcPrompt)) for _, oidcPrompt := range oidcPrompt { switch oidcPrompt { case oidc.PromptNone, @@ -140,6 +155,7 @@ func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthReques UserID: userID, Scopes: authReq.Scopes, ResponseType: authReq.ResponseType, + ResponseMode: authReq.ResponseMode, Nonce: authReq.Nonce, CodeChallenge: &OIDCCodeChallenge{ Challenge: authReq.CodeChallenge, @@ -148,6 +164,15 @@ func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthReques } } +type AuthRequestWithSessionState struct { + *AuthRequest + SessionState string +} + +func (a *AuthRequestWithSessionState) GetSessionState() string { + return a.SessionState +} + type OIDCCodeChallenge struct { Challenge string Method string diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go index a4c4f46..d4315c6 100644 --- a/example/server/storage/storage.go +++ b/example/server/storage/storage.go @@ -11,11 +11,11 @@ import ( "sync" "time" + jose "github.com/go-jose/go-jose/v4" "github.com/google/uuid" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) // serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant @@ -28,8 +28,10 @@ var serviceKey1 = &rsa.PublicKey{ E: 65537, } -var _ op.Storage = &Storage{} -var _ op.ClientCredentialsStorage = &Storage{} +var ( + _ op.Storage = &Storage{} + _ op.ClientCredentialsStorage = &Storage{} +) // storage implements the op.Storage interface // typically you would implement this as a layer on top of your database @@ -59,7 +61,7 @@ func (s *signingKey) SignatureAlgorithm() jose.SignatureAlgorithm { return s.algorithm } -func (s *signingKey) Key() interface{} { +func (s *signingKey) Key() any { return s.key } @@ -83,11 +85,15 @@ func (s *publicKey) Use() string { return "sig" } -func (s *publicKey) Key() interface{} { +func (s *publicKey) Key() any { return &s.key.PublicKey } func NewStorage(userStore UserStore) *Storage { + return NewStorageWithClients(userStore, clients) +} + +func NewStorageWithClients(userStore UserStore, clients map[string]*Client) *Storage { key, _ := rsa.GenerateKey(rand.Reader, 2048) return &Storage{ authRequests: make(map[string]*AuthRequest), @@ -145,6 +151,9 @@ func (s *Storage) CheckUsernamePassword(username, password, id string) error { // in this example we'll simply check the username / password and set a boolean to true // therefore we will also just check this boolean if the request / login has been finished request.done = true + + request.authTime = time.Now() + return nil } return fmt.Errorf("username or password wrong") @@ -167,6 +176,12 @@ func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthReque s.lock.Lock() defer s.lock.Unlock() + if len(authReq.Prompt) == 1 && authReq.Prompt[0] == "none" { + // With prompt=none, there is no way for the user to log in + // so return error right away. + return nil, oidc.ErrLoginRequired() + } + // typically, you'll fill your storage / storage model with the information of the passed object request := authRequestToInternal(authReq, userID) @@ -283,15 +298,19 @@ func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.T // if we get here, the currentRefreshToken was not empty, so the call is a refresh token request // we therefore will have to check the currentRefreshToken and renew the refresh token - refreshToken, refreshTokenID, err := s.renewRefreshToken(currentRefreshToken) + + newRefreshToken = uuid.NewString() + + accessToken, err := s.accessToken(applicationID, newRefreshToken, request.GetSubject(), request.GetAudience(), request.GetScopes()) if err != nil { return "", "", time.Time{}, err } - accessToken, err := s.accessToken(applicationID, refreshTokenID, request.GetSubject(), request.GetAudience(), request.GetScopes()) - if err != nil { + + if err := s.renewRefreshToken(currentRefreshToken, newRefreshToken, accessToken.ID); err != nil { return "", "", time.Time{}, err } - return accessToken.ID, refreshToken, accessToken.Expiration, nil + + return accessToken.ID, newRefreshToken, accessToken.Expiration, nil } func (s *Storage) exchangeRefreshToken(ctx context.Context, request op.TokenExchangeRequest) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { @@ -373,14 +392,9 @@ func (s *Storage) RevokeToken(ctx context.Context, tokenIDOrToken string, userID if refreshToken.ApplicationID != clientID { return oidc.ErrInvalidClient().WithDescription("token was not issued for this client") } - // if it is a refresh token, you will have to remove the access token as well delete(s.refreshTokens, refreshToken.ID) - for _, accessToken := range s.tokens { - if accessToken.RefreshTokenID == refreshToken.ID { - delete(s.tokens, accessToken.ID) - return nil - } - } + // if it is a refresh token, you will have to remove the access token as well + delete(s.tokens, refreshToken.AccessToken) return nil } @@ -476,6 +490,9 @@ func (s *Storage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserI // return err // } //} + if token.Expiration.Before(time.Now()) { + return fmt.Errorf("token is expired") + } return s.setUserinfo(ctx, userinfo, token.Subject, token.ApplicationID, token.Scopes) } @@ -517,11 +534,11 @@ func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspection * // GetPrivateClaimsFromScopes implements the op.Storage interface // it will be called for the creation of a JWT access token to assert claims for custom scopes -func (s *Storage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) { +func (s *Storage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]any, err error) { return s.getPrivateClaimsFromScopes(ctx, userID, clientID, scopes) } -func (s *Storage) getPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) { +func (s *Storage) getPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]any, err error) { for _, scope := range scopes { switch scope { case CustomScope: @@ -582,33 +599,41 @@ func (s *Storage) createRefreshToken(accessToken *Token, amr []string, authTime Audience: accessToken.Audience, Expiration: time.Now().Add(5 * time.Hour), Scopes: accessToken.Scopes, + AccessToken: accessToken.ID, } s.refreshTokens[token.ID] = token return token.Token, nil } // renewRefreshToken checks the provided refresh_token and creates a new one based on the current -func (s *Storage) renewRefreshToken(currentRefreshToken string) (string, string, error) { +// +// [Refresh Token Rotation] is implemented. +// +// [Refresh Token Rotation]: https://www.rfc-editor.org/rfc/rfc6819#section-5.2.2.3 +func (s *Storage) renewRefreshToken(currentRefreshToken, newRefreshToken, newAccessToken string) error { s.lock.Lock() defer s.lock.Unlock() refreshToken, ok := s.refreshTokens[currentRefreshToken] if !ok { - return "", "", fmt.Errorf("invalid refresh token") + return fmt.Errorf("invalid refresh token") } - // deletes the refresh token and all access tokens which were issued based on this refresh token + // deletes the refresh token delete(s.refreshTokens, currentRefreshToken) - for _, token := range s.tokens { - if token.RefreshTokenID == currentRefreshToken { - delete(s.tokens, token.ID) - break - } + + // delete the access token which was issued based on this refresh token + delete(s.tokens, refreshToken.AccessToken) + + if refreshToken.Expiration.Before(time.Now()) { + return fmt.Errorf("expired refresh token") } + // creates a new refresh token based on the current one - token := uuid.NewString() - refreshToken.Token = token - refreshToken.ID = token - s.refreshTokens[token] = refreshToken - return token, refreshToken.ID, nil + refreshToken.Token = newRefreshToken + refreshToken.ID = newRefreshToken + refreshToken.Expiration = time.Now().Add(5 * time.Hour) + refreshToken.AccessToken = newAccessToken + s.refreshTokens[newRefreshToken] = refreshToken + return nil } // accessToken will store an access_token in-memory based on the provided information @@ -705,7 +730,7 @@ func (s *Storage) CreateTokenExchangeRequest(ctx context.Context, request op.Tok // GetPrivateClaimsFromScopesForTokenExchange implements the op.TokenExchangeStorage interface // it will be called for the creation of an exchanged JWT access token to assert claims for custom scopes // plus adding token exchange specific claims related to delegation or impersonation -func (s *Storage) GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]interface{}, err error) { +func (s *Storage) GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]any, err error) { claims, err = s.getPrivateClaimsFromScopes(ctx, "", request.GetClientID(), request.GetScopes()) if err != nil { return nil, err @@ -734,12 +759,12 @@ func (s *Storage) SetUserinfoFromTokenExchangeRequest(ctx context.Context, useri return nil } -func (s *Storage) getTokenExchangeClaims(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]interface{}) { +func (s *Storage) getTokenExchangeClaims(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]any) { for _, scope := range request.GetScopes() { switch { case strings.HasPrefix(scope, CustomScopeImpersonatePrefix) && request.GetExchangeActor() == "": // Set actor subject claim for impersonation flow - claims = appendClaim(claims, "act", map[string]interface{}{ + claims = appendClaim(claims, "act", map[string]any{ "sub": request.GetExchangeSubject(), }) } @@ -747,7 +772,7 @@ func (s *Storage) getTokenExchangeClaims(ctx context.Context, request op.TokenEx // Set actor subject claim for delegation flow // if request.GetExchangeActor() != "" { - // claims = appendClaim(claims, "act", map[string]interface{}{ + // claims = appendClaim(claims, "act", map[string]any{ // "sub": request.GetExchangeActor(), // }) // } @@ -769,16 +794,16 @@ func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Tim } // customClaim demonstrates how to return custom claims based on provided information -func customClaim(clientID string) map[string]interface{} { - return map[string]interface{}{ +func customClaim(clientID string) map[string]any { + return map[string]any{ "client": clientID, "other": "stuff", } } -func appendClaim(claims map[string]interface{}, claim string, value interface{}) map[string]interface{} { +func appendClaim(claims map[string]any, claim string, value any) map[string]any { if claims == nil { - claims = make(map[string]interface{}) + claims = make(map[string]any) } claims[claim] = value return claims diff --git a/example/server/storage/storage_dynamic.go b/example/server/storage/storage_dynamic.go index 07af903..765d29a 100644 --- a/example/server/storage/storage_dynamic.go +++ b/example/server/storage/storage_dynamic.go @@ -4,10 +4,10 @@ import ( "context" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) type multiStorage struct { @@ -239,7 +239,7 @@ func (s *multiStorage) SetIntrospectionFromToken(ctx context.Context, introspect // GetPrivateClaimsFromScopes implements the op.Storage interface // it will be called for the creation of a JWT access token to assert claims for custom scopes -func (s *multiStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) { +func (s *multiStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]any, err error) { storage, err := s.storageFromContext(ctx) if err != nil { return nil, err diff --git a/example/server/storage/token.go b/example/server/storage/token.go index ad907e3..beab38c 100644 --- a/example/server/storage/token.go +++ b/example/server/storage/token.go @@ -22,4 +22,5 @@ type RefreshToken struct { ApplicationID string Expiration time.Time Scopes []string + AccessToken string // Token.ID } diff --git a/example/server/storage/user.go b/example/server/storage/user.go index 173daef..ed8cdfa 100644 --- a/example/server/storage/user.go +++ b/example/server/storage/user.go @@ -2,6 +2,8 @@ package storage import ( "crypto/rsa" + "encoding/json" + "os" "strings" "golang.org/x/text/language" @@ -35,6 +37,18 @@ type userStore struct { users map[string]*User } +func StoreFromFile(path string) (UserStore, error) { + users := map[string]*User{} + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + if err := json.Unmarshal(data, &users); err != nil { + return nil, err + } + return userStore{users}, nil +} + func NewUserStore(issuer string) UserStore { hostname := strings.Split(strings.Split(issuer, "://")[1], ":")[0] return userStore{ diff --git a/example/server/storage/user_test.go b/example/server/storage/user_test.go new file mode 100644 index 0000000..c2e2212 --- /dev/null +++ b/example/server/storage/user_test.go @@ -0,0 +1,70 @@ +package storage + +import ( + "os" + "path" + "reflect" + "testing" + + "golang.org/x/text/language" +) + +func TestStoreFromFile(t *testing.T) { + for _, tc := range []struct { + name string + pathToFile string + content string + want UserStore + wantErr bool + }{ + { + name: "normal user file", + pathToFile: "userfile.json", + content: `{ + "id1": { + "ID": "id1", + "EmailVerified": true, + "PreferredLanguage": "DE" + } + }`, + want: userStore{map[string]*User{ + "id1": { + ID: "id1", + EmailVerified: true, + PreferredLanguage: language.German, + }, + }}, + }, + { + name: "malformed file", + pathToFile: "whatever", + content: "not a json just a text", + wantErr: true, + }, + { + name: "not existing file", + pathToFile: "what/ever/file", + wantErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + actualPath := path.Join(t.TempDir(), tc.pathToFile) + + if tc.content != "" && tc.pathToFile != "" { + if err := os.WriteFile(actualPath, []byte(tc.content), 0666); err != nil { + t.Fatalf("cannot create file with test content: %q", tc.content) + } + } + result, err := StoreFromFile(actualPath) + if err != nil && !tc.wantErr { + t.Errorf("StoreFromFile(%q) returned unexpected error %q", tc.pathToFile, err) + } else if err == nil && tc.wantErr { + t.Errorf("StoreFromFile(%q) did not return an expected error", tc.pathToFile) + } + if !tc.wantErr && !reflect.DeepEqual(tc.want, result.(userStore)) { + t.Errorf("expected StoreFromFile(%q) = %v, but got %v", + tc.pathToFile, tc.want, result) + } + }) + } +} diff --git a/go.mod b/go.mod index adb638e..a0f42c4 100644 --- a/go.mod +++ b/go.mod @@ -1,35 +1,40 @@ -module github.com/zitadel/oidc/v2 +module git.christmann.info/LARA/zitadel-oidc/v3 -go 1.18 +go 1.23.7 + +toolchain go1.24.1 require ( + github.com/bmatcuk/doublestar/v4 v4.8.1 + github.com/go-chi/chi/v5 v5.2.1 + github.com/go-jose/go-jose/v4 v4.0.5 github.com/golang/mock v1.6.0 github.com/google/go-github/v31 v31.0.0 - github.com/google/uuid v1.3.0 - github.com/gorilla/mux v1.8.0 - github.com/gorilla/schema v1.2.0 - github.com/gorilla/securecookie v1.1.1 - github.com/jeremija/gosubmit v0.2.7 + github.com/google/uuid v1.6.0 + github.com/gorilla/securecookie v1.1.2 + github.com/jeremija/gosubmit v0.2.8 github.com/muhlemmer/gu v0.3.1 - github.com/rs/cors v1.8.3 - github.com/sirupsen/logrus v1.9.0 - github.com/stretchr/testify v1.8.2 - golang.org/x/oauth2 v0.6.0 - golang.org/x/text v0.8.0 - gopkg.in/square/go-jose.v2 v2.6.0 + github.com/muhlemmer/httpforwarded v0.1.0 + github.com/rs/cors v1.11.1 + github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.10.0 + github.com/zitadel/logging v0.6.2 + github.com/zitadel/schema v1.3.1 + go.opentelemetry.io/otel v1.29.0 + golang.org/x/oauth2 v0.30.0 + golang.org/x/text v0.26.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-cmp v0.5.9 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/crypto v0.7.0 // indirect - golang.org/x/net v0.8.0 // indirect - golang.org/x/sys v0.6.0 // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/protobuf v1.29.1 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + go.opentelemetry.io/otel/metric v1.29.0 // indirect + go.opentelemetry.io/otel/trace v1.29.0 // indirect + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/net v0.38.0 // indirect + golang.org/x/sys v0.31.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4259674..4835505 100644 --- a/go.sum +++ b/go.sum @@ -1,70 +1,80 @@ +github.com/bmatcuk/doublestar/v4 v4.8.1 h1:54Bopc5c2cAvhLRAzqOGCYHYyhcDHsFF4wWIR5wKP38= +github.com/bmatcuk/doublestar/v4 v4.8.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= +github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= +github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github/v31 v31.0.0 h1:JJUxlP9lFK+ziXKimTCprajMApV1ecWD4NB6CCb0plo= github.com/google/go-github/v31 v31.0.0/go.mod h1:NQPZol8/1sMoWYGN2yaALIBytu17gAWfhbweiEed3pM= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= -github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= -github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc= -github.com/jeremija/gosubmit v0.2.7/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/jeremija/gosubmit v0.2.8 h1:mmSITBz9JxVtu8eqbN+zmmwX7Ij2RidQxhcwRVI4wqA= +github.com/jeremija/gosubmit v0.2.8/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM= github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= +github.com/muhlemmer/httpforwarded v0.1.0 h1:x4DLrzXdliq8mprgUMR0olDvHGkou5BJsK/vWUetyzY= +github.com/muhlemmer/httpforwarded v0.1.0/go.mod h1:yo9czKedo2pdZhoXe+yDkGVbU0TJ0q9oQ90BVoDEtw0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/cors v1.8.3 h1:O+qNyWn7Z+F9M0ILBHgMVPuB1xTOucVd5gtaYyXBpRo= -github.com/rs/cors v1.8.3/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/zitadel/logging v0.6.2 h1:MW2kDDR0ieQynPZ0KIZPrh9ote2WkxfBif5QoARDQcU= +github.com/zitadel/logging v0.6.2/go.mod h1:z6VWLWUkJpnNVDSLzrPSQSQyttysKZ6bCRongw0ROK4= +github.com/zitadel/schema v1.3.1 h1:QT3kwiRIRXXLVAs6gCK/u044WmUVh6IlbLXUsn6yRQU= +github.com/zitadel/schema v1.3.1/go.mod h1:071u7D2LQacy1HAN+YnMd/mx1qVE2isb0Mjeqg46xnU= +go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= +go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= +go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= +go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= +go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= +go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= -golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -73,14 +83,13 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= @@ -89,17 +98,11 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.29.1 h1:7QBf+IK2gx70Ap/hDsOmam3GE0v9HicjfEdAxE62UoM= -google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= -gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/testutil/gen/gen.go b/internal/testutil/gen/gen.go index a9f5925..3e44b7d 100644 --- a/internal/testutil/gen/gen.go +++ b/internal/testutil/gen/gen.go @@ -8,8 +8,8 @@ import ( "fmt" "os" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) var custom = map[string]any{ diff --git a/internal/testutil/token.go b/internal/testutil/token.go index 121aa0b..72d08c5 100644 --- a/internal/testutil/token.go +++ b/internal/testutil/token.go @@ -8,8 +8,9 @@ import ( "errors" "time" - "github.com/zitadel/oidc/v2/pkg/oidc" - "gopkg.in/square/go-jose.v2" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + jose "github.com/go-jose/go-jose/v4" + "github.com/muhlemmer/gu" ) // KeySet implements oidc.Keys @@ -17,7 +18,7 @@ type KeySet struct{} // VerifySignature implments op.KeySet. func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { - if ctx.Err() != nil { + if err = ctx.Err(); err != nil { return nil, err } @@ -45,6 +46,16 @@ func init() { } } +type JWTProfileKeyStorage struct{} + +func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + return gu.Ptr(WebKey.Public()), nil +} + func signEncodeTokenClaims(claims any) string { payload, err := json.Marshal(claims) if err != nil { @@ -106,6 +117,25 @@ func NewAccessToken(issuer, subject string, audience []string, expiration time.T return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil) } +func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) { + req := &oidc.JWTTokenRequest{ + Issuer: issuer, + Subject: clientID, + Audience: audience, + ExpiresAt: oidc.FromTime(expiration), + IssuedAt: oidc.FromTime(issuedAt), + } + // make sure the private claim map is set correctly + data, err := json.Marshal(req) + if err != nil { + panic(err) + } + if err = json.Unmarshal(data, req); err != nil { + panic(err) + } + return signEncodeTokenClaims(req), req +} + const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg` // These variables always result in a valid token @@ -137,6 +167,10 @@ func ValidAccessToken() (string, *oidc.AccessTokenClaims) { return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew) } +func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) { + return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration) +} + // ACRVerify is a oidc.ACRVerifier func. func ACRVerify(acr string) error { if acr != ValidACR { diff --git a/pkg/client/client.go b/pkg/client/client.go index 9eda973..2e1f536 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -2,7 +2,6 @@ package client import ( "context" - "encoding/json" "errors" "fmt" "io" @@ -11,32 +10,44 @@ import ( "strings" "time" + "github.com/go-jose/go-jose/v4" + "github.com/zitadel/logging" + "go.opentelemetry.io/otel" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/crypto" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) -var Encoder = httphelper.Encoder(oidc.NewEncoder()) +var ( + Encoder = httphelper.Encoder(oidc.NewEncoder()) + Tracer = otel.Tracer("github.com/zitadel/oidc/pkg/client") +) // Discover calls the discovery endpoint of the provided issuer and returns its configuration // It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url -func Discover(issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { +func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) { + ctx, span := Tracer.Start(ctx, "Discover") + defer span.End() + wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" { wellKnown = wellKnownUrl[0] } - req, err := http.NewRequest("GET", wellKnown, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil) if err != nil { return nil, err } discoveryConfig := new(oidc.DiscoveryConfiguration) err = httphelper.HttpRequest(httpClient, req, &discoveryConfig) if err != nil { - return nil, err + return nil, errors.Join(oidc.ErrDiscoveryFailed, err) } + if logger, ok := logging.FromContext(ctx); ok { + logger.Debug("discover", "config", discoveryConfig) + } + if discoveryConfig.Issuer != issuer { return nil, oidc.ErrIssuerInvalid } @@ -48,12 +59,15 @@ type TokenEndpointCaller interface { HttpClient() *http.Client } -func CallTokenEndpoint(request interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { - return callTokenEndpoint(request, nil, caller) +func CallTokenEndpoint(ctx context.Context, request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { + return callTokenEndpoint(ctx, request, nil, caller) } -func callTokenEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) +func callTokenEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) { + ctx, span := Tracer.Start(ctx, "callTokenEndpoint") + defer span.End() + + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -61,12 +75,18 @@ func callTokenEndpoint(request interface{}, authFn interface{}, caller TokenEndp if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil { return nil, err } - return &oauth2.Token{ + token := &oauth2.Token{ AccessToken: tokenRes.AccessToken, TokenType: tokenRes.TokenType, RefreshToken: tokenRes.RefreshToken, Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second), - }, nil + } + if tokenRes.IDToken != "" { + token = token.WithExtra(map[string]any{ + "id_token": tokenRes.IDToken, + }) + } + return token, nil } type EndSessionCaller interface { @@ -74,8 +94,16 @@ type EndSessionCaller interface { HttpClient() *http.Client } -func CallEndSessionEndpoint(request interface{}, authFn interface{}, caller EndSessionCaller) (*url.URL, error) { - req, err := httphelper.FormRequest(caller.GetEndSessionEndpoint(), request, Encoder, authFn) +func CallEndSessionEndpoint(ctx context.Context, request any, authFn any, caller EndSessionCaller) (*url.URL, error) { + ctx, span := Tracer.Start(ctx, "CallEndSessionEndpoint") + defer span.End() + + endpoint := caller.GetEndSessionEndpoint() + if endpoint == "" { + return nil, fmt.Errorf("end session %w", ErrEndpointNotSet) + } + + req, err := httphelper.FormRequest(ctx, endpoint, request, Encoder, authFn) if err != nil { return nil, err } @@ -117,8 +145,16 @@ type RevokeRequest struct { ClientSecret string `schema:"client_secret"` } -func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCaller) error { - req, err := httphelper.FormRequest(caller.GetRevokeEndpoint(), request, Encoder, authFn) +func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller RevokeCaller) error { + ctx, span := Tracer.Start(ctx, "CallRevokeEndpoint") + defer span.End() + + endpoint := caller.GetRevokeEndpoint() + if endpoint == "" { + return fmt.Errorf("revoke %w", ErrEndpointNotSet) + } + + req, err := httphelper.FormRequest(ctx, endpoint, request, Encoder, authFn) if err != nil { return err } @@ -145,8 +181,11 @@ func CallRevokeEndpoint(request interface{}, authFn interface{}, caller RevokeCa return nil } -func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, authFn) +func CallTokenExchangeEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) { + ctx, span := Tracer.Start(ctx, "CallTokenExchangeEndpoint") + defer span.End() + + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn) if err != nil { return nil, err } @@ -158,12 +197,12 @@ func CallTokenExchangeEndpoint(request interface{}, authFn interface{}, caller T } func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) { - privateKey, err := crypto.BytesToPrivateKey(key) + privateKey, algorithm, err := crypto.BytesToPrivateKey(key) if err != nil { return nil, err } signingKey := jose.SigningKey{ - Algorithm: jose.RS256, + Algorithm: algorithm, Key: &jose.JSONWebKey{Key: privateKey, KeyID: keyID}, } return jose.NewSigner(signingKey, &jose.SignerOptions{}) @@ -186,8 +225,16 @@ type DeviceAuthorizationCaller interface { HttpClient() *http.Client } -func CallDeviceAuthorizationEndpoint(request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller) (*oidc.DeviceAuthorizationResponse, error) { - req, err := httphelper.FormRequest(caller.GetDeviceAuthorizationEndpoint(), request, Encoder, nil) +func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + ctx, span := Tracer.Start(ctx, "CallDeviceAuthorizationEndpoint") + defer span.End() + + endpoint := caller.GetDeviceAuthorizationEndpoint() + if endpoint == "" { + return nil, fmt.Errorf("device authorization %w", ErrEndpointNotSet) + } + + req, err := httphelper.FormRequest(ctx, endpoint, request, Encoder, authFn) if err != nil { return nil, err } @@ -208,7 +255,10 @@ type DeviceAccessTokenRequest struct { } func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { - req, err := httphelper.FormRequest(caller.TokenEndpoint(), request, Encoder, nil) + ctx, span := Tracer.Start(ctx, "CallDeviceAccessTokenEndpoint") + defer span.End() + + req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil) if err != nil { return nil, err } @@ -216,28 +266,17 @@ func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTok req.SetBasicAuth(request.ClientID, request.ClientSecret) } - httpResp, err := caller.HttpClient().Do(req) - if err != nil { + resp := new(oidc.AccessTokenResponse) + if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil { return nil, err } - defer httpResp.Body.Close() - - resp := new(struct { - *oidc.AccessTokenResponse - *oidc.Error - }) - if err = json.NewDecoder(httpResp.Body).Decode(resp); err != nil { - return nil, err - } - - if httpResp.StatusCode == http.StatusOK { - return resp.AccessTokenResponse, nil - } - - return nil, resp.Error + return resp, nil } func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) { + ctx, span := Tracer.Start(ctx, "PollDeviceAccessTokenEndpoint") + defer span.End() + for { timer := time.After(interval) select { diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go new file mode 100644 index 0000000..9e21e8e --- /dev/null +++ b/pkg/client/client_test.go @@ -0,0 +1,59 @@ +package client + +import ( + "context" + "net/http" + "testing" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDiscover(t *testing.T) { + type wantFields struct { + UILocalesSupported bool + } + + type args struct { + issuer string + wellKnownUrl []string + } + tests := []struct { + name string + args args + wantFields *wantFields + wantErr error + }{ + { + name: "spotify", // https://github.com/zitadel/oidc/issues/406 + args: args{ + issuer: "https://accounts.spotify.com", + }, + wantFields: &wantFields{ + UILocalesSupported: true, + }, + wantErr: nil, + }, + { + name: "discovery failed", + args: args{ + issuer: "https://example.com", + }, + wantErr: oidc.ErrDiscoveryFailed, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Discover(context.Background(), tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...) + require.ErrorIs(t, err, tt.wantErr) + if tt.wantFields == nil { + return + } + assert.Equal(t, tt.args.issuer, got.Issuer) + if tt.wantFields.UILocalesSupported { + assert.NotEmpty(t, got.UILocalesSupported) + } + }) + } +} diff --git a/pkg/client/errors.go b/pkg/client/errors.go new file mode 100644 index 0000000..47210e5 --- /dev/null +++ b/pkg/client/errors.go @@ -0,0 +1,5 @@ +package client + +import "errors" + +var ErrEndpointNotSet = errors.New("endpoint not set") diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index e19a720..86a9ab7 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -2,32 +2,65 @@ package client_test import ( "bytes" + "context" + "fmt" "io" - "io/ioutil" + "log/slog" "math/rand" "net/http" "net/http/cookiejar" "net/http/httptest" "net/url" "os" + "os/signal" "strconv" + "syscall" "testing" "time" "github.com/jeremija/gosubmit" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/example/server/exampleop" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/client/rp" - "github.com/zitadel/oidc/v2/pkg/client/rs" - "github.com/zitadel/oidc/v2/pkg/client/tokenexchange" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/example/server/exampleop" + "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rs" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/tokenexchange" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) +var Logger = slog.New( + slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), +) + +var CTX context.Context + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT) + defer cancel() + CTX, cancel = context.WithTimeout(ctx, time.Minute) + defer cancel() + return m.Run() + }()) +} + func TestRelyingPartySession(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testRelyingPartySession(t, wrapServer) + }) + } +} + +func testRelyingPartySession(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -35,17 +68,17 @@ func TestRelyingPartySession(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) t.Log("------- run authorization code flow ------") - provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") t.Log("------- refresh tokens ------") - newTokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "") + newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") require.NoError(t, err, "refresh token") assert.NotNil(t, newTokens, "access token") t.Logf("new access token %s", newTokens.AccessToken) @@ -53,10 +86,13 @@ func TestRelyingPartySession(t *testing.T) { t.Logf("new token type %s", newTokens.TokenType) t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) require.NotEmpty(t, newTokens.AccessToken, "new accessToken") + assert.NotEmpty(t, newTokens.IDToken, "new idToken") + assert.NotNil(t, newTokens.IDTokenClaims) + assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject) t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -65,17 +101,111 @@ func TestRelyingPartySession(t *testing.T) { } t.Log("------ attempt refresh again (should fail) ------") - t.Log("trying original refresh token", refreshToken) - _, err = rp.RefreshAccessToken(provider, refreshToken, "", "") + t.Log("trying original refresh token", tokens.RefreshToken) + _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with original") if newTokens.RefreshToken != "" { t.Log("trying replacement refresh token", newTokens.RefreshToken) - _, err = rp.RefreshAccessToken(provider, newTokens.RefreshToken, "", "") + _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with replacement") } } +func TestRelyingPartyWithSigningAlgsFromDiscovery(t *testing.T) { + targetURL := "http://local-site" + localURL, err := url.Parse(targetURL + "/login?requestID=1234") + require.NoError(t, err, "local url") + + t.Log("------- start example OP ------") + seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) + clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) + clientSecret := "secret" + client := storage.WebClient(clientID, clientSecret, targetURL) + storage.RegisterClients(client) + exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) + var dh deferredHandler + opServer := httptest.NewServer(&dh) + defer opServer.Close() + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, true) + + t.Log("------- create RP ------") + provider, err := rp.NewRelyingPartyOIDC( + CTX, + opServer.URL, + clientID, + clientSecret, + targetURL, + []string{"openid"}, + rp.WithSigningAlgsFromDiscovery(), + ) + require.NoError(t, err, "new rp") + + t.Log("------- run authorization code flow ------") + jar, err := cookiejar.New(nil) + require.NoError(t, err, "create cookie jar") + httpClient := &http.Client{ + Timeout: time.Second * 5, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + Jar: jar, + } + state := "state-" + strconv.FormatInt(seed.Int63(), 25) + capturedW := httptest.NewRecorder() + get := httptest.NewRequest("GET", localURL.String(), nil) + rp.AuthURLHandler(func() string { return state }, provider, + rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"), + rp.WithURLParam("custom", "param"), + )(capturedW, get) + defer func() { + if t.Failed() { + t.Log("response body (redirect from RP to OP)", capturedW.Body.String()) + } + }() + resp := capturedW.Result() + startAuthURL, err := resp.Location() + require.NoError(t, err, "get redirect") + loginPageURL := getRedirect(t, "get redirect to login page", httpClient, startAuthURL) + form := getForm(t, "get login form", httpClient, loginPageURL) + defer func() { + if t.Failed() { + t.Logf("login form (unfilled): %s", string(form)) + } + }() + postLoginRedirectURL := fillForm(t, "fill login form", httpClient, form, loginPageURL, + gosubmit.Set("username", "test-user@local-site"), + gosubmit.Set("password", "verysecure"), + ) + codeBearingURL := getRedirect(t, "get redirect with code", httpClient, postLoginRedirectURL) + capturedW = httptest.NewRecorder() + get = httptest.NewRequest("GET", codeBearingURL.String(), nil) + var idToken string + redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + idToken = newTokens.IDToken + http.Redirect(w, r, targetURL, http.StatusFound) + } + rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get) + defer func() { + if t.Failed() { + t.Log("token exchange response body", capturedW.Body.String()) + require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code") + } + }() + + t.Log("------- verify id token ------") + _, err = rp.VerifyIDToken[*oidc.IDTokenClaims](CTX, idToken, provider.IDTokenVerifier()) + require.NoError(t, err, "verify id token") +} + func TestResourceServerTokenExchange(t *testing.T) { + for _, wrapServer := range []bool{false, true} { + t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) { + testResourceServerTokenExchange(t, wrapServer) + }) + } +} + +func testResourceServerTokenExchange(t *testing.T, wrapServer bool) { t.Log("------- start example OP ------") targetURL := "http://local-site" exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) @@ -83,23 +213,24 @@ func TestResourceServerTokenExchange(t *testing.T) { opServer := httptest.NewServer(&dh) defer opServer.Close() t.Logf("auth server at %s", opServer.URL) - dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer) seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) clientSecret := "secret" t.Log("------- run authorization code flow ------") - provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) - resourceServer, err := rs.NewResourceServerClientCredentials(opServer.URL, clientID, clientSecret) + resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret) require.NoError(t, err, "new resource server") t.Log("------- exchage refresh tokens (impersonation) ------") tokenExchangeResponse, err := tokenexchange.ExchangeToken( + CTX, resourceServer, - refreshToken, + tokens.RefreshToken, oidc.RefreshTokenType, "", "", @@ -117,7 +248,7 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -128,8 +259,9 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------- attempt exchage again (should fail) ------") tokenExchangeResponse, err = tokenexchange.ExchangeToken( + CTX, resourceServer, - refreshToken, + tokens.RefreshToken, oidc.RefreshTokenType, "", "", @@ -141,10 +273,9 @@ func TestResourceServerTokenExchange(t *testing.T) { require.Error(t, err, "refresh token") assert.Contains(t, err.Error(), "subject_token is invalid") require.Nil(t, tokenExchangeResponse, "token exchange response") - } -func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) { +func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) { targetURL := "http://local-site" localURL, err := url.Parse(targetURL + "/login?requestID=1234") require.NoError(t, err, "local url") @@ -166,12 +297,14 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, key := []byte("test1234test1234") cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) provider, err = rp.NewRelyingPartyOIDC( + CTX, opServer.URL, clientID, clientSecret, targetURL, []string{"openid", "email", "profile", "offline_access"}, rp.WithPKCE(cookieHandler), + rp.WithAuthStyle(oauth2.AuthStyleInHeader), rp.WithVerifierOpts( rp.WithIssuedAtOffset(5*time.Second), rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"), @@ -240,7 +373,8 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, } var email string - redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + tokens = newTokens require.NotNil(t, tokens, "tokens") require.NotNil(t, info, "info") t.Log("access token", tokens.AccessToken) @@ -248,9 +382,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, t.Log("id token", tokens.IDToken) t.Log("email", info.Email) - accessToken = tokens.AccessToken - refreshToken = tokens.RefreshToken - idToken = tokens.IDToken email = info.Email http.Redirect(w, r, targetURL, 302) } @@ -272,12 +403,124 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, require.NoError(t, err, "get fully-authorizied redirect location") require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location") - require.NotEmpty(t, idToken, "id token") - assert.NotEmpty(t, refreshToken, "refresh token") - assert.NotEmpty(t, accessToken, "access token") + require.NotEmpty(t, tokens.IDToken, "id token") + assert.NotEmpty(t, tokens.RefreshToken, "refresh token") + assert.NotEmpty(t, tokens.AccessToken, "access token") assert.NotEmpty(t, email, "email") - return provider, accessToken, refreshToken, idToken + return provider, tokens +} + +func TestClientCredentials(t *testing.T) { + targetURL := "http://local-site" + exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) + var dh deferredHandler + opServer := httptest.NewServer(&dh) + defer opServer.Close() + t.Logf("auth server at %s", opServer.URL) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, true) + + provider, err := rp.NewRelyingPartyOIDC( + CTX, + opServer.URL, + "sid1", + "verysecret", + targetURL, + []string{"openid"}, + ) + require.NoError(t, err, "new rp") + + token, err := rp.ClientCredentials(CTX, provider, nil) + require.NoError(t, err, "ClientCredentials call") + require.NotNil(t, token) + assert.NotEmpty(t, token.AccessToken) +} + +func TestErrorFromPromptNone(t *testing.T) { + jar, err := cookiejar.New(nil) + require.NoError(t, err, "create cookie jar") + httpClient := &http.Client{ + Timeout: time.Second * 5, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + Jar: jar, + } + + t.Log("------- start example OP ------") + targetURL := "http://local-site" + exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL)) + var dh deferredHandler + opServer := httptest.NewServer(&dh) + defer opServer.Close() + t.Logf("auth server at %s", opServer.URL) + dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, false, op.WithHttpInterceptors( + func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("request to %s", r.URL) + next.ServeHTTP(w, r) + }) + }, + )) + seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano())) + clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) + clientSecret := "secret" + client := storage.WebClient(clientID, clientSecret, targetURL) + storage.RegisterClients(client) + + t.Log("------- create RP ------") + key := []byte("test1234test1234") + cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure()) + provider, err := rp.NewRelyingPartyOIDC( + CTX, + opServer.URL, + clientID, + clientSecret, + targetURL, + []string{"openid", "email", "profile", "offline_access"}, + rp.WithPKCE(cookieHandler), + rp.WithVerifierOpts( + rp.WithIssuedAtOffset(5*time.Second), + rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"), + ), + ) + require.NoError(t, err, "new rp") + + t.Log("------- start auth flow with prompt=none ------- ") + state := "state-32892" + capturedW := httptest.NewRecorder() + localURL, err := url.Parse(targetURL + "/login") + require.NoError(t, err) + + get := httptest.NewRequest("GET", localURL.String(), nil) + rp.AuthURLHandler(func() string { return state }, provider, + rp.WithPromptURLParam("none"), + rp.WithResponseModeURLParam(oidc.ResponseModeFragment), + )(capturedW, get) + + defer func() { + if t.Failed() { + t.Log("response body (redirect from RP to OP)", capturedW.Body.String()) + } + }() + require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code") + require.Less(t, capturedW.Code, 400, "captured response code") + + //nolint:bodyclose + resp := capturedW.Result() + jar.SetCookies(localURL, resp.Cookies()) + + startAuthURL, err := resp.Location() + require.NoError(t, err, "get redirect") + assert.NotEmpty(t, startAuthURL, "login url") + t.Log("Starting auth at", startAuthURL) + + t.Log("------- get redirect from OP ------") + loginPageURL := getRedirect(t, "get redirect to login page", httpClient, startAuthURL) + t.Log("login page URL", loginPageURL) + + require.Contains(t, loginPageURL.String(), `error=login_required`, "prompt=none should error") + require.Contains(t, loginPageURL.String(), `local-site#error=`, "response_mode=fragment means '#' instead of '?'") } type deferredHandler struct { @@ -325,7 +568,7 @@ func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) [ func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL { // TODO: switch to io.NopCloser when go1.15 support is dropped - req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( + req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)..., ) if req.URL.Scheme == "" { diff --git a/pkg/client/jwt_profile.go b/pkg/client/jwt_profile.go index 1686de6..98a54fd 100644 --- a/pkg/client/jwt_profile.go +++ b/pkg/client/jwt_profile.go @@ -1,17 +1,18 @@ package client import ( + "context" "net/url" "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) // JWTProfileExchange handles the oauth2 jwt profile exchange -func JWTProfileExchange(jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) { - return CallTokenEndpoint(jwtProfileGrantRequest, caller) +func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) { + return CallTokenEndpoint(ctx, jwtProfileGrantRequest, caller) } func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption { diff --git a/pkg/client/key.go b/pkg/client/key.go index 740c6d3..7f38311 100644 --- a/pkg/client/key.go +++ b/pkg/client/key.go @@ -2,7 +2,7 @@ package client import ( "encoding/json" - "io/ioutil" + "os" ) const ( @@ -10,7 +10,7 @@ const ( applicationKey = "application" ) -type keyFile struct { +type KeyFile struct { Type string `json:"type"` // serviceaccount or application KeyID string `json:"keyId"` Key string `json:"key"` @@ -23,16 +23,16 @@ type keyFile struct { ClientID string `json:"clientId"` } -func ConfigFromKeyFile(path string) (*keyFile, error) { - data, err := ioutil.ReadFile(path) +func ConfigFromKeyFile(path string) (*KeyFile, error) { + data, err := os.ReadFile(path) if err != nil { return nil, err } return ConfigFromKeyFileData(data) } -func ConfigFromKeyFileData(data []byte) (*keyFile, error) { - var f keyFile +func ConfigFromKeyFileData(data []byte) (*KeyFile, error) { + var f KeyFile if err := json.Unmarshal(data, &f); err != nil { return nil, err } diff --git a/pkg/client/profile/jwt_profile.go b/pkg/client/profile/jwt_profile.go index a934f7d..fb351f0 100644 --- a/pkg/client/profile/jwt_profile.go +++ b/pkg/client/profile/jwt_profile.go @@ -1,19 +1,25 @@ package profile import ( + "context" "net/http" "time" + jose "github.com/go-jose/go-jose/v4" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/client" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) +type TokenSource interface { + oauth2.TokenSource + TokenCtx(context.Context) (*oauth2.Token, error) +} + // jwtProfileTokenSource implement the oauth2.TokenSource // it will request a token using the OAuth2 JWT Profile Grant -// therefore sending an `assertion` by singing a JWT with the provided private key +// therefore sending an `assertion` by signing a JWT with the provided private key type jwtProfileTokenSource struct { clientID string audience []string @@ -23,23 +29,38 @@ type jwtProfileTokenSource struct { tokenEndpoint string } -func NewJWTProfileTokenSourceFromKeyFile(issuer, keyPath string, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { - keyData, err := client.ConfigFromKeyFile(keyPath) +// NewJWTProfileTokenSourceFromKeyFile returns an implementation of TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key from jsonFile. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSourceFromKeyFile(ctx context.Context, issuer, jsonFile string, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { + keyData, err := client.ConfigFromKeyFile(jsonFile) if err != nil { return nil, err } - return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) + return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) } -func NewJWTProfileTokenSourceFromKeyFileData(issuer string, data []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { - keyData, err := client.ConfigFromKeyFileData(data) +// NewJWTProfileTokenSourceFromKeyFileData returns an implementation of oauth2.TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key in jsonData. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSourceFromKeyFileData(ctx context.Context, issuer string, jsonData []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { + keyData, err := client.ConfigFromKeyFileData(jsonData) if err != nil { return nil, err } - return NewJWTProfileTokenSource(issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) + return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...) } -func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (oauth2.TokenSource, error) { +// NewJWTProfileSource returns an implementation of oauth2.TokenSource +// It will request a token using the OAuth2 JWT Profile Grant, +// therefore sending an `assertion` by singing a JWT with the provided private key. +// +// The passed context is only used for the call to the Discover endpoint. +func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) { signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) if err != nil { return nil, err @@ -55,7 +76,7 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes opt(source) } if source.tokenEndpoint == "" { - config, err := client.Discover(issuer, source.httpClient) + config, err := client.Discover(ctx, issuer, source.httpClient) if err != nil { return nil, err } @@ -64,13 +85,13 @@ func NewJWTProfileTokenSource(issuer, clientID, keyID string, key []byte, scopes return source, nil } -func WithHTTPClient(client *http.Client) func(*jwtProfileTokenSource) { +func WithHTTPClient(client *http.Client) func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) { source.httpClient = client } } -func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*jwtProfileTokenSource) { +func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(source *jwtProfileTokenSource) { return func(source *jwtProfileTokenSource) { source.tokenEndpoint = tokenEndpoint } @@ -85,9 +106,13 @@ func (j *jwtProfileTokenSource) HttpClient() *http.Client { } func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) { + return j.TokenCtx(context.Background()) +} + +func (j *jwtProfileTokenSource) TokenCtx(ctx context.Context) (*oauth2.Token, error) { assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer) if err != nil { return nil, err } - return client.JWTProfileExchange(oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) + return client.JWTProfileExchange(ctx, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j) } diff --git a/pkg/client/rp/cli/cli.go b/pkg/client/rp/cli/cli.go index 91b200d..10edaa7 100644 --- a/pkg/client/rp/cli/cli.go +++ b/pkg/client/rp/cli/cli.go @@ -4,9 +4,9 @@ import ( "context" "net/http" - "github.com/zitadel/oidc/v2/pkg/client/rp" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) const ( diff --git a/pkg/client/rp/delegation.go b/pkg/client/rp/delegation.go index b16a39e..fb4fc63 100644 --- a/pkg/client/rp/delegation.go +++ b/pkg/client/rp/delegation.go @@ -1,7 +1,7 @@ package rp import ( - "github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc/grants/tokenexchange" ) // DelegationTokenRequest is an implementation of TokenExchangeRequest diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go index 73b67ca..1fadd56 100644 --- a/pkg/client/rp/device.go +++ b/pkg/client/rp/device.go @@ -5,14 +5,13 @@ import ( "fmt" "time" - "github.com/zitadel/oidc/v2/pkg/client" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) { confg := rp.OAuthConfig() req := &oidc.ClientCredentialsRequest{ - GrantType: oidc.GrantTypeDeviceCode, Scope: scopes, ClientID: confg.ClientID, ClientSecret: confg.ClientSecret, @@ -33,19 +32,27 @@ func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc. // DeviceAuthorization starts a new Device Authorization flow as defined // in RFC 8628, section 3.1 and 3.2: // https://www.rfc-editor.org/rfc/rfc8628#section-3.1 -func DeviceAuthorization(scopes []string, rp RelyingParty) (*oidc.DeviceAuthorizationResponse, error) { +func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) { + ctx, span := client.Tracer.Start(ctx, "DeviceAuthorization") + defer span.End() + + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization") req, err := newDeviceClientCredentialsRequest(scopes, rp) if err != nil { return nil, err } - return client.CallDeviceAuthorizationEndpoint(req, rp) + return client.CallDeviceAuthorizationEndpoint(ctx, req, rp, authFn) } // DeviceAccessToken attempts to obtain tokens from a Device Authorization, // by means of polling as defined in RFC, section 3.3 and 3.4: // https://www.rfc-editor.org/rfc/rfc8628#section-3.4 func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) { + ctx, span := client.Tracer.Start(ctx, "DeviceAccessToken") + defer span.End() + + ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken") req := &client.DeviceAccessTokenRequest{ DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{ GrantType: oidc.GrantTypeDeviceCode, diff --git a/pkg/client/rp/errors.go b/pkg/client/rp/errors.go new file mode 100644 index 0000000..b95420b --- /dev/null +++ b/pkg/client/rp/errors.go @@ -0,0 +1,5 @@ +package rp + +import "errors" + +var ErrRelyingPartyNotSupportRevokeCaller = errors.New("RelyingParty does not support RevokeCaller") diff --git a/pkg/client/rp/jwks.go b/pkg/client/rp/jwks.go index 3438bd6..0ccbad2 100644 --- a/pkg/client/rp/jwks.go +++ b/pkg/client/rp/jwks.go @@ -7,10 +7,11 @@ import ( "net/http" "sync" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet { @@ -83,6 +84,9 @@ func (i *inflight) result() ([]jose.JSONWebKey, error) { } func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { + ctx, span := client.Tracer.Start(ctx, "VerifySignature") + defer span.End() + keyID, alg := oidc.GetKeyIDAndAlg(jws) if alg == "" { alg = r.defaultAlg @@ -135,6 +139,9 @@ func (r *remoteKeySet) exactMatch(jwkID, jwsID string) bool { } func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) { + ctx, span := client.Tracer.Start(ctx, "verifySignatureRemote") + defer span.End() + keys, err := r.keysFromRemote(ctx) if err != nil { return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err) @@ -159,6 +166,9 @@ func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) { // keysFromRemote syncs the key set from the remote set, records the values in the // cache, and returns the key set. func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) { + ctx, span := client.Tracer.Start(ctx, "keysFromRemote") + defer span.End() + // Need to lock to inspect the inflight request field. r.mu.Lock() // If there's not a current inflight request, create one. @@ -182,6 +192,9 @@ func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, e } func (r *remoteKeySet) updateKeys(ctx context.Context) { + ctx, span := client.Tracer.Start(ctx, "updateKeys") + defer span.End() + // Sync keys and finish inflight when that's done. keys, err := r.fetchRemoteKeys(ctx) @@ -201,7 +214,10 @@ func (r *remoteKeySet) updateKeys(ctx context.Context) { } func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) { - req, err := http.NewRequest("GET", r.jwksURL, nil) + ctx, span := client.Tracer.Start(ctx, "fetchRemoteKeys") + defer span.End() + + req, err := http.NewRequestWithContext(ctx, "GET", r.jwksURL, nil) if err != nil { return nil, fmt.Errorf("oidc: can't create request: %v", err) } diff --git a/pkg/client/rp/log.go b/pkg/client/rp/log.go new file mode 100644 index 0000000..556220c --- /dev/null +++ b/pkg/client/rp/log.go @@ -0,0 +1,17 @@ +package rp + +import ( + "context" + "log/slog" + + "github.com/zitadel/logging" +) + +func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context { + logger, ok := rp.Logger(ctx) + if !ok { + return ctx + } + logger = logger.With(slog.Group("rp", attrs...)) + return logging.ToContext(ctx, logger) +} diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index ede7453..c2759a2 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -4,19 +4,20 @@ import ( "context" "encoding/base64" "errors" - "fmt" + "log/slog" "net/http" "net/url" - "strings" "time" + "github.com/go-jose/go-jose/v4" "github.com/google/uuid" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" + "golang.org/x/oauth2/clientcredentials" - "github.com/zitadel/oidc/v2/pkg/client" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/zitadel/logging" ) const ( @@ -59,38 +60,55 @@ type RelyingParty interface { // UserinfoEndpoint returns the userinfo UserinfoEndpoint() string - // GetDeviceAuthorizationEndpoint returns the enpoint which can + // GetDeviceAuthorizationEndpoint returns the endpoint which can // be used to start a DeviceAuthorization flow. GetDeviceAuthorizationEndpoint() string - // IDTokenVerifier returns the verifier interface used for oidc id_token verification - IDTokenVerifier() IDTokenVerifier - // ErrorHandler returns the handler used for callback errors + // IDTokenVerifier returns the verifier used for oidc id_token verification + IDTokenVerifier() *IDTokenVerifier + // ErrorHandler returns the handler used for callback errors ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) + + // Logger from the context, or a fallback if set. + Logger(context.Context) (logger *slog.Logger, ok bool) +} + +type HasUnauthorizedHandler interface { + // UnauthorizedHandler returns the handler used for unauthorized errors + UnauthorizedHandler() func(w http.ResponseWriter, r *http.Request, desc string, state string) } type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) +type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string) var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) { http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError) } +var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter, r *http.Request, desc string, state string) { + http.Error(w, desc, http.StatusUnauthorized) +} type relyingParty struct { - issuer string - DiscoveryEndpoint string - endpoints Endpoints - oauthConfig *oauth2.Config - oauth2Only bool - pkce bool + issuer string + DiscoveryEndpoint string + endpoints Endpoints + oauthConfig *oauth2.Config + oauth2Only bool + pkce bool + useSigningAlgsFromDiscovery bool httpClient *http.Client cookieHandler *httphelper.CookieHandler - errorHandler func(http.ResponseWriter, *http.Request, string, string, string) - idTokenVerifier IDTokenVerifier - verifierOpts []VerifierOption - signer jose.Signer + oauthAuthStyle oauth2.AuthStyle + + errorHandler func(http.ResponseWriter, *http.Request, string, string, string) + unauthorizedHandler func(http.ResponseWriter, *http.Request, string, string) + idTokenVerifier *IDTokenVerifier + verifierOpts []VerifierOption + signer jose.Signer + logger *slog.Logger } func (rp *relyingParty) OAuthConfig() *oauth2.Config { @@ -137,7 +155,7 @@ func (rp *relyingParty) GetRevokeEndpoint() string { return rp.endpoints.RevokeURL } -func (rp *relyingParty) IDTokenVerifier() IDTokenVerifier { +func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier { if rp.idTokenVerifier == nil { rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...) } @@ -151,14 +169,31 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, return rp.errorHandler } +func (rp *relyingParty) UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) { + if rp.unauthorizedHandler == nil { + rp.unauthorizedHandler = DefaultUnauthorizedHandler + } + return rp.unauthorizedHandler +} + +func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) { + logger, ok = logging.FromContext(ctx) + if ok { + return logger, ok + } + return rp.logger, rp.logger != nil +} + // NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given // OAuth2 Config and possible configOptions // it will use the AuthURL and TokenURL set in config func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingParty, error) { rp := &relyingParty{ - oauthConfig: config, - httpClient: httphelper.DefaultHTTPClient, - oauth2Only: true, + oauthConfig: config, + httpClient: httphelper.DefaultHTTPClient, + oauth2Only: true, + unauthorizedHandler: DefaultUnauthorizedHandler, + oauthAuthStyle: oauth2.AuthStyleAutoDetect, } for _, optFunc := range options { @@ -167,9 +202,12 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart } } + rp.oauthConfig.Endpoint.AuthStyle = rp.oauthAuthStyle + // avoid races by calling these early - _ = rp.IDTokenVerifier() // sets idTokenVerifier - _ = rp.ErrorHandler() // sets errorHandler + _ = rp.IDTokenVerifier() // sets idTokenVerifier + _ = rp.ErrorHandler() // sets errorHandler + _ = rp.UnauthorizedHandler() // sets unauthorizedHandler return rp, nil } @@ -177,7 +215,7 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart // NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given // issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions // it will run discovery on the provided issuer and use the found endpoints -func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) { +func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) { rp := &relyingParty{ issuer: issuer, oauthConfig: &oauth2.Config{ @@ -186,8 +224,9 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco RedirectURL: redirectURI, Scopes: scopes, }, - httpClient: httphelper.DefaultHTTPClient, - oauth2Only: false, + httpClient: httphelper.DefaultHTTPClient, + oauth2Only: false, + oauthAuthStyle: oauth2.AuthStyleAutoDetect, } for _, optFunc := range options { @@ -195,17 +234,25 @@ func NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI string, sco return nil, err } } - discoveryConfiguration, err := client.Discover(rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) + ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC") + discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint) if err != nil { return nil, err } + if rp.useSigningAlgsFromDiscovery { + rp.verifierOpts = append(rp.verifierOpts, WithSupportedSigningAlgorithms(discoveryConfiguration.IDTokenSigningAlgValuesSupported...)) + } endpoints := GetEndpoints(discoveryConfiguration) rp.oauthConfig.Endpoint = endpoints.Endpoint rp.endpoints = endpoints + rp.oauthConfig.Endpoint.AuthStyle = rp.oauthAuthStyle + rp.endpoints.Endpoint.AuthStyle = rp.oauthAuthStyle + // avoid races by calling these early - _ = rp.IDTokenVerifier() // sets idTokenVerifier - _ = rp.ErrorHandler() // sets errorHandler + _ = rp.IDTokenVerifier() // sets idTokenVerifier + _ = rp.ErrorHandler() // sets errorHandler + _ = rp.UnauthorizedHandler() // sets unauthorizedHandler return rp, nil } @@ -254,6 +301,20 @@ func WithErrorHandler(errorHandler ErrorHandler) Option { } } +func WithUnauthorizedHandler(unauthorizedHandler UnauthorizedHandler) Option { + return func(rp *relyingParty) error { + rp.unauthorizedHandler = unauthorizedHandler + return nil + } +} + +func WithAuthStyle(oauthAuthStyle oauth2.AuthStyle) Option { + return func(rp *relyingParty) error { + rp.oauthAuthStyle = oauthAuthStyle + return nil + } +} + func WithVerifierOpts(opts ...VerifierOption) Option { return func(rp *relyingParty) error { rp.verifierOpts = opts @@ -282,6 +343,24 @@ func WithJWTProfile(signerFromKey SignerFromKey) Option { } } +// WithLogger sets a logger that is used +// in case the request context does not contain a logger. +func WithLogger(logger *slog.Logger) Option { + return func(rp *relyingParty) error { + rp.logger = logger + return nil + } +} + +// WithSigningAlgsFromDiscovery appends the [WithSupportedSigningAlgorithms] option to the Verifier Options. +// The algorithms returned in the `id_token_signing_alg_values_supported` from the discovery response will be set. +func WithSigningAlgsFromDiscovery() Option { + return func(rp *relyingParty) error { + rp.useSigningAlgsFromDiscovery = true + return nil + } +} + type SignerFromKey func() (jose.Signer, error) func SignerFromKeyPath(path string) SignerFromKey { @@ -310,26 +389,6 @@ func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey { } } -// Discover calls the discovery endpoint of the provided issuer and returns the found endpoints -// -// deprecated: use client.Discover -func Discover(issuer string, httpClient *http.Client) (Endpoints, error) { - wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint - req, err := http.NewRequest("GET", wellKnown, nil) - if err != nil { - return Endpoints{}, err - } - discoveryConfig := new(oidc.DiscoveryConfiguration) - err = httphelper.HttpRequest(httpClient, req, &discoveryConfig) - if err != nil { - return Endpoints{}, err - } - if discoveryConfig.Issuer != issuer { - return Endpoints{}, oidc.ErrIssuerInvalid - } - return GetEndpoints(discoveryConfig), nil -} - // AuthURL returns the auth request url // (wrapping the oauth2 `AuthCodeURL`) func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string { @@ -342,7 +401,7 @@ func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string { // AuthURLHandler extends the `AuthURL` method with a http redirect handler // including handling setting cookie for secure `state` transfer. -// Custom paramaters can optionally be set to the redirect URL. +// Custom parameters can optionally be set to the redirect URL. func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { opts := make([]AuthURLOpt, len(urlParam)) @@ -352,13 +411,13 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam state := stateFn() if err := trySetStateCookie(w, state, rp); err != nil { - http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp) return } if rp.IsPKCE() { codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp) if err != nil { - http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp) return } opts = append(opts, WithCodeChallenge(codeChallenge)) @@ -377,35 +436,73 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri return oidc.NewSHACodeChallenge(codeVerifier), nil } +// ErrMissingIDToken is returned when an id_token was expected, +// but not received in the token response. +var ErrMissingIDToken = errors.New("id_token missing") + +func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) { + ctx, span := client.Tracer.Start(ctx, "verifyTokenResponse") + defer span.End() + + if rp.IsOAuth2Only() { + return &oidc.Tokens[C]{Token: token}, nil + } + idTokenString, ok := token.Extra(idTokenKey).(string) + if !ok { + return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken + } + idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) + if err != nil { + return nil, err + } + return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil +} + // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // returning it parsed together with the oauth2 tokens (access, refresh) func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { + ctx, codeExchangeSpan := client.Tracer.Start(ctx, "CodeExchange") + defer codeExchangeSpan.End() + + ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange") ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) codeOpts := make([]oauth2.AuthCodeOption, 0) for _, opt := range opts { codeOpts = append(codeOpts, opt()...) } + ctx, oauthExchangeSpan := client.Tracer.Start(ctx, "OAuthExchange") token, err := rp.OAuthConfig().Exchange(ctx, code, codeOpts...) if err != nil { return nil, err } + oauthExchangeSpan.End() + return verifyTokenResponse[C](ctx, token, rp) +} - if rp.IsOAuth2Only() { - return &oidc.Tokens[C]{Token: token}, nil +// ClientCredentials requests an access token using the `client_credentials` grant, +// as defined in [RFC 6749, section 4.4]. +// +// As there is no user associated to the request an ID Token can never be returned. +// Client Credentials are undefined in OpenID Connect and is a pure OAuth2 grant. +// Furthermore the server SHOULD NOT return a refresh token. +// +// [RFC 6749, section 4.4]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 +func ClientCredentials(ctx context.Context, rp RelyingParty, endpointParams url.Values) (token *oauth2.Token, err error) { + ctx = logCtxWithRPData(ctx, rp, "function", "ClientCredentials") + ctx, span := client.Tracer.Start(ctx, "ClientCredentials") + defer span.End() + + ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient()) + config := clientcredentials.Config{ + ClientID: rp.OAuthConfig().ClientID, + ClientSecret: rp.OAuthConfig().ClientSecret, + TokenURL: rp.OAuthConfig().Endpoint.TokenURL, + Scopes: rp.OAuthConfig().Scopes, + EndpointParams: endpointParams, + AuthStyle: rp.OAuthConfig().Endpoint.AuthStyle, } - - idTokenString, ok := token.Extra(idTokenKey).(string) - if !ok { - return nil, errors.New("id_token missing") - } - - idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) - if err != nil { - return nil, err - } - - return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil + return config.Token(ctx) } type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) @@ -413,17 +510,20 @@ type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.R // CodeExchangeHandler extends the `CodeExchange` method with a http handler // including cookie handling for secure `state` transfer // and optional PKCE code verifier checking. -// Custom paramaters can optionally be set to the token URL. +// Custom parameters can optionally be set to the token URL. func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + ctx, span := client.Tracer.Start(r.Context(), "CodeExchangeHandler") + r = r.WithContext(ctx) + defer span.End() + state, err := tryReadStateCookie(w, r, rp) if err != nil { - http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp) return } - params := r.URL.Query() - if params.Get("error") != "" { - rp.ErrorHandler()(w, r, params.Get("error"), params.Get("error_description"), state) + if errValue := r.FormValue("error"); errValue != "" { + rp.ErrorHandler()(w, r, errValue, r.FormValue("error_description"), state) return } codeOpts := make([]CodeExchangeOpt, len(urlParam)) @@ -434,57 +534,75 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R if rp.IsPKCE() { codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode) if err != nil { - http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to get code verifier: "+err.Error(), state, rp) return } codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier)) + rp.CookieHandler().DeleteCookie(w, pkceCode) } if rp.Signer() != nil { - assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer()) + assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer(), rp.OAuthConfig().Endpoint.TokenURL}, time.Hour, rp.Signer()) if err != nil { - http.Error(w, "failed to build assertion: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to build assertion: "+err.Error(), state, rp) return } codeOpts = append(codeOpts, WithClientAssertionJWT(assertion)) } - tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...) + tokens, err := CodeExchange[C](r.Context(), r.FormValue("code"), rp, codeOpts...) if err != nil { - http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp) return } callback(w, r, tokens, state, rp) } } -type CodeExchangeUserinfoCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info *oidc.UserInfo) +type SubjectGetter interface { + GetSubject() string +} + +type CodeExchangeUserinfoCallback[C oidc.IDClaims, U SubjectGetter] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info U) // UserinfoCallback wraps the callback function of the CodeExchangeHandler // and calls the userinfo endpoint with the access token // on success it will pass the userinfo into its callback function as well -func UserinfoCallback[C oidc.IDClaims](f CodeExchangeUserinfoCallback[C]) CodeExchangeCallback[C] { +func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] { return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) { - info, err := Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) + ctx, span := client.Tracer.Start(r.Context(), "UserinfoCallback") + r = r.WithContext(ctx) + defer span.End() + + info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp) if err != nil { - http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized) + unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp) return } f(w, r, tokens, state, rp, info) } } -// Userinfo will call the OIDC Userinfo Endpoint with the provided token -func Userinfo(token, tokenType, subject string, rp RelyingParty) (*oidc.UserInfo, error) { - req, err := http.NewRequest("GET", rp.UserinfoEndpoint(), nil) +// Userinfo will call the OIDC [UserInfo] Endpoint with the provided token and returns +// the response in an instance of type U. +// [*oidc.UserInfo] can be used as a good example, or use a custom type if type-safe +// access to custom claims is needed. +// +// [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo +func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) { + var nilU U + ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo") + ctx, span := client.Tracer.Start(ctx, "Userinfo") + defer span.End() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil) if err != nil { - return nil, err + return nilU, err } req.Header.Set("authorization", tokenType+" "+token) - userinfo := new(oidc.UserInfo) if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil { - return nil, err + return nilU, err } - if userinfo.Subject != subject { - return nil, ErrUserInfoSubNotMatching + if userinfo.GetSubject() != subject { + return nilU, ErrUserInfoSubNotMatching } return userinfo, nil } @@ -525,9 +643,8 @@ type Endpoints struct { func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { return Endpoints{ Endpoint: oauth2.Endpoint{ - AuthURL: discoveryConfig.AuthorizationEndpoint, - AuthStyle: oauth2.AuthStyleAutoDetect, - TokenURL: discoveryConfig.TokenEndpoint, + AuthURL: discoveryConfig.AuthorizationEndpoint, + TokenURL: discoveryConfig.TokenEndpoint, }, IntrospectURL: discoveryConfig.IntrospectionEndpoint, UserinfoURL: discoveryConfig.UserinfoEndpoint, @@ -538,7 +655,7 @@ func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { } } -// withURLParam sets custom url paramaters. +// withURLParam sets custom url parameters. // This is the generalized, unexported, function used by both // URLParamOpt and AuthURLOpt. func withURLParam(key, value string) func() []oauth2.AuthCodeOption { @@ -553,7 +670,7 @@ func withURLParam(key, value string) func() []oauth2.AuthCodeOption { // This is the generalized, unexported, function used by both // URLParamOpt and AuthURLOpt. func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption { - return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).Encode()) + return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String()) } type URLParamOpt func() []oauth2.AuthCodeOption @@ -569,6 +686,11 @@ func WithPromptURLParam(prompt ...string) URLParamOpt { return withPrompt(prompt...) } +// WithResponseModeURLParam sets the `response_mode` parameter in a URL. +func WithResponseModeURLParam(mode oidc.ResponseMode) URLParamOpt { + return withURLParam("response_mode", string(mode)) +} + type AuthURLOpt func() []oauth2.AuthCodeOption // WithCodeChallenge sets the `code_challenge` params in the auth request @@ -612,15 +734,26 @@ func (t tokenEndpointCaller) TokenEndpoint() string { type RefreshTokenRequest struct { RefreshToken string `schema:"refresh_token"` - Scopes oidc.SpaceDelimitedArray `schema:"scope"` - ClientID string `schema:"client_id"` - ClientSecret string `schema:"client_secret"` - ClientAssertion string `schema:"client_assertion"` - ClientAssertionType string `schema:"client_assertion_type"` + Scopes oidc.SpaceDelimitedArray `schema:"scope,omitempty"` + ClientID string `schema:"client_id,omitempty"` + ClientSecret string `schema:"client_secret,omitempty"` + ClientAssertion string `schema:"client_assertion,omitempty"` + ClientAssertionType string `schema:"client_assertion_type,omitempty"` GrantType oidc.GrantType `schema:"grant_type"` } -func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { +// RefreshTokens performs a token refresh. If it doesn't error, it will always +// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then +// the old one should be considered invalid. +// +// In case the RP is not OAuth2 only and an IDToken was part of the response, +// the IDToken and AccessToken will be verified +// and the IDToken and IDTokenClaims fields will be populated in the returned object. +func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) { + ctx, span := client.Tracer.Start(ctx, "RefreshTokens") + defer span.End() + + ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens") request := RefreshTokenRequest{ RefreshToken: refreshToken, Scopes: rp.OAuthConfig().Scopes, @@ -630,17 +763,31 @@ func RefreshAccessToken(rp RelyingParty, refreshToken, clientAssertion, clientAs ClientAssertionType: clientAssertionType, GrantType: oidc.GrantTypeRefreshToken, } - return client.CallTokenEndpoint(request, tokenEndpointCaller{RelyingParty: rp}) + newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp}) + if err != nil { + return nil, err + } + tokens, err := verifyTokenResponse[C](ctx, newToken, rp) + if err == nil || errors.Is(err, ErrMissingIDToken) { + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse + // ...except that it might not contain an id_token. + return tokens, nil + } + return nil, err } -func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { +func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { + ctx = logCtxWithRPData(ctx, rp, "function", "EndSession") + ctx, span := client.Tracer.Start(ctx, "RefreshTokens") + defer span.End() + request := oidc.EndSessionRequest{ IdTokenHint: idToken, ClientID: rp.OAuthConfig().ClientID, PostLogoutRedirectURI: optionalRedirectURI, State: optionalState, } - return client.CallEndSessionEndpoint(request, nil, rp) + return client.CallEndSessionEndpoint(ctx, request, nil, rp) } // RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty @@ -648,7 +795,10 @@ func EndSession(rp RelyingParty, idToken, optionalRedirectURI, optionalState str // NewRelyingPartyOAuth() does not. // // tokenTypeHint should be either "id_token" or "refresh_token". -func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error { +func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error { + ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken") + ctx, span := client.Tracer.Start(ctx, "RefreshTokens") + defer span.End() request := client.RevokeRequest{ Token: token, TokenTypeHint: tokenTypeHint, @@ -656,7 +806,15 @@ func RevokeToken(rp RelyingParty, token string, tokenTypeHint string) error { ClientSecret: rp.OAuthConfig().ClientSecret, } if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" { - return client.CallRevokeEndpoint(request, nil, rc) + return client.CallRevokeEndpoint(ctx, request, nil, rc) } - return fmt.Errorf("RelyingParty does not support RevokeCaller") + return ErrRelyingPartyNotSupportRevokeCaller +} + +func unauthorizedError(w http.ResponseWriter, r *http.Request, desc string, state string, rp RelyingParty) { + if rp, ok := rp.(HasUnauthorizedHandler); ok { + rp.UnauthorizedHandler()(w, r, desc, state) + return + } + http.Error(w, desc, http.StatusUnauthorized) } diff --git a/pkg/client/rp/relying_party_test.go b/pkg/client/rp/relying_party_test.go new file mode 100644 index 0000000..b3bb6ee --- /dev/null +++ b/pkg/client/rp/relying_party_test.go @@ -0,0 +1,107 @@ +package rp + +import ( + "context" + "testing" + "time" + + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func Test_verifyTokenResponse(t *testing.T) { + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + ClientID: tu.ValidClientID, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, + } + tests := []struct { + name string + oauth2Only bool + tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims]) + wantErr error + }{ + { + name: "succes, oauth2 only", + oauth2Only: true, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + } + }, + }, + { + name: "id_token missing error", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + } + }, + wantErr: ErrMissingIDToken, + }, + { + name: "verify tokens error", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + token = token.WithExtra(map[string]any{ + "id_token": "foobar", + }) + return token, nil + }, + wantErr: oidc.ErrParse, + }, + { + name: "success, with id_token", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + idToken, claims := tu.ValidIDToken() + token = token.WithExtra(map[string]any{ + "id_token": idToken, + }) + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + IDTokenClaims: claims, + IDToken: idToken, + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rp := &relyingParty{ + oauth2Only: tt.oauth2Only, + idTokenVerifier: verifier, + } + token, want := tt.tokens() + got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, want, got) + }) + } +} diff --git a/pkg/client/rp/tockenexchange.go b/pkg/client/rp/tockenexchange.go index c1ac88d..aa2cf99 100644 --- a/pkg/client/rp/tockenexchange.go +++ b/pkg/client/rp/tockenexchange.go @@ -5,7 +5,7 @@ import ( "golang.org/x/oauth2" - "github.com/zitadel/oidc/v2/pkg/oidc/grants/tokenexchange" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc/grants/tokenexchange" ) // TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange` diff --git a/pkg/client/rp/userinfo_example_test.go b/pkg/client/rp/userinfo_example_test.go new file mode 100644 index 0000000..78e014e --- /dev/null +++ b/pkg/client/rp/userinfo_example_test.go @@ -0,0 +1,45 @@ +package rp_test + +import ( + "context" + "fmt" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" +) + +type UserInfo struct { + Subject string `json:"sub,omitempty"` + oidc.UserInfoProfile + oidc.UserInfoEmail + oidc.UserInfoPhone + Address *oidc.UserInfoAddress `json:"address,omitempty"` + + // Foo and Bar are custom claims + Foo string `json:"foo,omitempty"` + Bar struct { + Val1 string `json:"val_1,omitempty"` + Val2 string `json:"val_2,omitempty"` + } `json:"bar,omitempty"` + + // Claims are all the combined claims, including custom. + Claims map[string]any `json:"-,omitempty"` +} + +func (u *UserInfo) GetSubject() string { + return u.Subject +} + +func ExampleUserinfo_custom() { + rpo, err := rp.NewRelyingPartyOIDC(context.TODO(), "http://localhost:8080", "clientid", "clientsecret", "http://example.com/redirect", []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone}) + if err != nil { + panic(err) + } + + info, err := rp.Userinfo[*UserInfo](context.TODO(), "accesstokenstring", "Bearer", "userid", rpo) + if err != nil { + panic(err) + } + + fmt.Println(info) +} diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go index 75d149b..0088b81 100644 --- a/pkg/client/rp/verifier.go +++ b/pkg/client/rp/verifier.go @@ -4,24 +4,18 @@ import ( "context" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) -type IDTokenVerifier interface { - oidc.Verifier - ClientID() string - SupportedSignAlgs() []string - KeySet() oidc.KeySet - Nonce(context.Context) string - ACR() oidc.ACRVerifier - MaxAge() time.Duration -} - // VerifyTokens implement the Token Response Validation as defined in OIDC specification // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation -func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v IDTokenVerifier) (claims C, err error) { +func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) { + ctx, span := client.Tracer.Start(ctx, "VerifyTokens") + defer span.End() + var nilClaims C claims, err = VerifyIDToken[C](ctx, idToken, v) @@ -36,7 +30,10 @@ func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken str // VerifyIDToken validates the id token according to // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVerifier) (claims C, err error) { +func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) { + ctx, span := client.Tracer.Start(ctx, "VerifyIDToken") + defer span.End() + var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -52,44 +49,48 @@ func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v IDTokenVe return nilClaims, err } - if err = oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err = oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckAudience(claims, v.ClientID()); err != nil { + if err = oidc.CheckAudience(claims, v.ClientID); err != nil { return nilClaims, err } - if err = oidc.CheckAuthorizedParty(claims, v.ClientID()); err != nil { + if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { return nilClaims, err } - if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil { + if v.Nonce != nil { + if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil { + return nilClaims, err + } + } + + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil { return nilClaims, err } - if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { - return nilClaims, err - } - - if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { + if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { return nilClaims, err } return claims, nil } +type IDTokenVerifier oidc.Verifier + // VerifyAccessToken validates the access token according to // https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { @@ -107,15 +108,14 @@ func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAl return nil } -// NewIDTokenVerifier returns an implementation of `IDTokenVerifier` -// for `VerifyTokens` and `VerifyIDToken` -func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) IDTokenVerifier { - v := &idTokenVerifier{ - issuer: issuer, - clientID: clientID, - keySet: keySet, - offset: time.Second, - nonce: func(_ context.Context) string { +// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification. +func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier { + v := &IDTokenVerifier{ + Issuer: issuer, + ClientID: clientID, + KeySet: keySet, + Offset: time.Second, + Nonce: func(_ context.Context) string { return "" }, } @@ -128,95 +128,47 @@ func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ... } // VerifierOption is the type for providing dynamic options to the IDTokenVerifier -type VerifierOption func(*idTokenVerifier) +type VerifierOption func(*IDTokenVerifier) // WithIssuedAtOffset mitigates the risk of iat to be in the future // because of clock skews with the ability to add an offset to the current time -func WithIssuedAtOffset(offset time.Duration) func(*idTokenVerifier) { - return func(v *idTokenVerifier) { - v.offset = offset +func WithIssuedAtOffset(offset time.Duration) VerifierOption { + return func(v *IDTokenVerifier) { + v.Offset = offset } } // WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now -func WithIssuedAtMaxAge(maxAge time.Duration) func(*idTokenVerifier) { - return func(v *idTokenVerifier) { - v.maxAgeIAT = maxAge +func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption { + return func(v *IDTokenVerifier) { + v.MaxAgeIAT = maxAge } } // WithNonce sets the function to check the nonce func WithNonce(nonce func(context.Context) string) VerifierOption { - return func(v *idTokenVerifier) { - v.nonce = nonce + return func(v *IDTokenVerifier) { + v.Nonce = nonce } } // WithACRVerifier sets the verifier for the acr claim func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption { - return func(v *idTokenVerifier) { - v.acr = verifier + return func(v *IDTokenVerifier) { + v.ACR = verifier } } // WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption { - return func(v *idTokenVerifier) { - v.maxAge = maxAge + return func(v *IDTokenVerifier) { + v.MaxAge = maxAge } } // WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm func WithSupportedSigningAlgorithms(algs ...string) VerifierOption { - return func(v *idTokenVerifier) { - v.supportedSignAlgs = algs + return func(v *IDTokenVerifier) { + v.SupportedSignAlgs = algs } } - -type idTokenVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - clientID string - supportedSignAlgs []string - keySet oidc.KeySet - acr oidc.ACRVerifier - maxAge time.Duration - nonce func(ctx context.Context) string -} - -func (i *idTokenVerifier) Issuer() string { - return i.issuer -} - -func (i *idTokenVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -func (i *idTokenVerifier) Offset() time.Duration { - return i.offset -} - -func (i *idTokenVerifier) ClientID() string { - return i.clientID -} - -func (i *idTokenVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -func (i *idTokenVerifier) KeySet() oidc.KeySet { - return i.keySet -} - -func (i *idTokenVerifier) Nonce(ctx context.Context) string { - return i.nonce(ctx) -} - -func (i *idTokenVerifier) ACR() oidc.ACRVerifier { - return i.acr -} - -func (i *idTokenVerifier) MaxAge() time.Duration { - return i.maxAge -} diff --git a/pkg/client/rp/verifier_test.go b/pkg/client/rp/verifier_test.go index 7588c1f..38f5a4a 100644 --- a/pkg/client/rp/verifier_test.go +++ b/pkg/client/rp/verifier_test.go @@ -5,24 +5,24 @@ import ( "testing" "time" + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + jose "github.com/go-jose/go-jose/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" - "gopkg.in/square/go-jose.v2" ) func TestVerifyTokens(t *testing.T) { - verifier := &idTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - nonce: func(context.Context) string { return tu.ValidNonce }, - clientID: tu.ValidClientID, + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, + ClientID: tu.ValidClientID, } accessToken, _ := tu.ValidAccessToken() atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm) @@ -91,43 +91,64 @@ func TestVerifyTokens(t *testing.T) { } func TestVerifyIDToken(t *testing.T) { - verifier := &idTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - nonce: func(context.Context) string { return tu.ValidNonce }, + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, + ClientID: tu.ValidClientID, } tests := []struct { - name string - clientID string - tokenClaims func() (string, *oidc.IDTokenClaims) - wantErr bool + name string + tokenClaims func() (string, *oidc.IDTokenClaims) + customVerifier func(verifier *IDTokenVerifier) + wantErr bool }{ { name: "success", - clientID: tu.ValidClientID, tokenClaims: tu.ValidIDToken, }, + { + name: "custom claims", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return tu.NewIDTokenCustom( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + map[string]any{"some": "thing"}, + ) + }, + }, + { + name: "skip nonce check", + customVerifier: func(verifier *IDTokenVerifier) { + verifier.Nonce = nil + }, + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return tu.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, "foo", + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + }, { name: "parse err", - clientID: tu.ValidClientID, tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil }, wantErr: true, }, { name: "invalid signature", - clientID: tu.ValidClientID, tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil }, wantErr: true, }, { - name: "empty subject", - clientID: tu.ValidClientID, + name: "empty subject", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, "", tu.ValidAudience, @@ -138,8 +159,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong issuer", - clientID: tu.ValidClientID, + name: "wrong issuer", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( "foo", tu.ValidSubject, tu.ValidAudience, @@ -150,14 +170,15 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong clientID", - clientID: "foo", + name: "wrong clientID", + customVerifier: func(verifier *IDTokenVerifier) { + verifier.ClientID = "foo" + }, tokenClaims: tu.ValidIDToken, wantErr: true, }, { - name: "expired", - clientID: tu.ValidClientID, + name: "expired", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -168,8 +189,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong IAT", - clientID: tu.ValidClientID, + name: "wrong IAT", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -180,8 +200,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong acr", - clientID: tu.ValidClientID, + name: "wrong acr", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -192,8 +211,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "expired auth", - clientID: tu.ValidClientID, + name: "expired auth", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -204,8 +222,7 @@ func TestVerifyIDToken(t *testing.T) { wantErr: true, }, { - name: "wrong nonce", - clientID: tu.ValidClientID, + name: "wrong nonce", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.NewIDToken( tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, @@ -219,7 +236,10 @@ func TestVerifyIDToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { token, want := tt.tokenClaims() - verifier.clientID = tt.clientID + if tt.customVerifier != nil { + tt.customVerifier(verifier) + } + got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier) if tt.wantErr { assert.Error(t, err) @@ -300,7 +320,7 @@ func TestNewIDTokenVerifier(t *testing.T) { tests := []struct { name string args args - want IDTokenVerifier + want *IDTokenVerifier }{ { name: "nil nonce", // otherwise assert.Equal will fail on the function @@ -317,16 +337,16 @@ func TestNewIDTokenVerifier(t *testing.T) { WithSupportedSigningAlgorithms("ABC", "DEF"), }, }, - want: &idTokenVerifier{ - issuer: tu.ValidIssuer, - offset: time.Minute, - maxAgeIAT: time.Hour, - clientID: tu.ValidClientID, - keySet: tu.KeySet{}, - nonce: nil, - acr: nil, - maxAge: 2 * time.Hour, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + Offset: time.Minute, + MaxAgeIAT: time.Hour, + ClientID: tu.ValidClientID, + KeySet: tu.KeySet{}, + Nonce: nil, + ACR: nil, + MaxAge: 2 * time.Hour, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } diff --git a/pkg/client/rp/verifier_tokens_example_test.go b/pkg/client/rp/verifier_tokens_example_test.go index c297efe..7ae68d6 100644 --- a/pkg/client/rp/verifier_tokens_example_test.go +++ b/pkg/client/rp/verifier_tokens_example_test.go @@ -4,9 +4,9 @@ import ( "context" "fmt" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/client/rp" - "github.com/zitadel/oidc/v2/pkg/oidc" + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) // MyCustomClaims extends the TokenClaims base, diff --git a/pkg/client/rs/introspect_example_test.go b/pkg/client/rs/introspect_example_test.go new file mode 100644 index 0000000..1f67d11 --- /dev/null +++ b/pkg/client/rs/introspect_example_test.go @@ -0,0 +1,52 @@ +package rs_test + +import ( + "context" + "fmt" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rs" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" +) + +type IntrospectionResponse struct { + Active bool `json:"active"` + Scope oidc.SpaceDelimitedArray `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + TokenType string `json:"token_type,omitempty"` + Expiration oidc.Time `json:"exp,omitempty"` + IssuedAt oidc.Time `json:"iat,omitempty"` + NotBefore oidc.Time `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` + Audience oidc.Audience `json:"aud,omitempty"` + Issuer string `json:"iss,omitempty"` + JWTID string `json:"jti,omitempty"` + Username string `json:"username,omitempty"` + oidc.UserInfoProfile + oidc.UserInfoEmail + oidc.UserInfoPhone + Address *oidc.UserInfoAddress `json:"address,omitempty"` + + // Foo and Bar are custom claims + Foo string `json:"foo,omitempty"` + Bar struct { + Val1 string `json:"val_1,omitempty"` + Val2 string `json:"val_2,omitempty"` + } `json:"bar,omitempty"` + + // Claims are all the combined claims, including custom. + Claims map[string]any `json:"-,omitempty"` +} + +func ExampleIntrospect_custom() { + rss, err := rs.NewResourceServerClientCredentials(context.TODO(), "http://localhost:8080", "clientid", "clientsecret") + if err != nil { + panic(err) + } + + resp, err := rs.Introspect[*IntrospectionResponse](context.TODO(), rss, "accesstokenstring") + if err != nil { + panic(err) + } + + fmt.Println(resp) +} diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go index 4e0353c..993796e 100644 --- a/pkg/client/rs/resource_server.go +++ b/pkg/client/rs/resource_server.go @@ -6,16 +6,16 @@ import ( "net/http" "time" - "github.com/zitadel/oidc/v2/pkg/client" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type ResourceServer interface { IntrospectionURL() string TokenEndpoint() string HttpClient() *http.Client - AuthFn() (interface{}, error) + AuthFn() (any, error) } type resourceServer struct { @@ -23,7 +23,7 @@ type resourceServer struct { tokenURL string introspectURL string httpClient *http.Client - authFn func() (interface{}, error) + authFn func() (any, error) } func (r *resourceServer) IntrospectionURL() string { @@ -38,33 +38,33 @@ func (r *resourceServer) HttpClient() *http.Client { return r.httpClient } -func (r *resourceServer) AuthFn() (interface{}, error) { +func (r *resourceServer) AuthFn() (any, error) { return r.authFn() } -func NewResourceServerClientCredentials(issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { - authorizer := func() (interface{}, error) { +func NewResourceServerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) { + authorizer := func() (any, error) { return httphelper.AuthorizeBasic(clientID, clientSecret), nil } - return newResourceServer(issuer, authorizer, option...) + return newResourceServer(ctx, issuer, authorizer, option...) } -func NewResourceServerJWTProfile(issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) { +func NewResourceServerJWTProfile(ctx context.Context, issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) { signer, err := client.NewSignerFromPrivateKeyByte(key, keyID) if err != nil { return nil, err } - authorizer := func() (interface{}, error) { + authorizer := func() (any, error) { assertion, err := client.SignedJWTProfileAssertion(clientID, []string{issuer}, time.Hour, signer) if err != nil { return nil, err } return client.ClientAssertionFormAuthorization(assertion), nil } - return newResourceServer(issuer, authorizer, options...) + return newResourceServer(ctx, issuer, authorizer, options...) } -func newResourceServer(issuer string, authorizer func() (interface{}, error), options ...Option) (*resourceServer, error) { +func newResourceServer(ctx context.Context, issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) { rs := &resourceServer{ issuer: issuer, httpClient: httphelper.DefaultHTTPClient, @@ -73,26 +73,30 @@ func newResourceServer(issuer string, authorizer func() (interface{}, error), op optFunc(rs) } if rs.introspectURL == "" || rs.tokenURL == "" { - config, err := client.Discover(rs.issuer, rs.httpClient) + config, err := client.Discover(ctx, rs.issuer, rs.httpClient) if err != nil { return nil, err } - rs.tokenURL = config.TokenEndpoint - rs.introspectURL = config.IntrospectionEndpoint + if rs.tokenURL == "" { + rs.tokenURL = config.TokenEndpoint + } + if rs.introspectURL == "" { + rs.introspectURL = config.IntrospectionEndpoint + } } - if rs.introspectURL == "" || rs.tokenURL == "" { - return nil, errors.New("introspectURL and/or tokenURL is empty: please provide with either `WithStaticEndpoints` or a discovery url") + if rs.tokenURL == "" { + return nil, errors.New("tokenURL is empty: please provide with either `WithStaticEndpoints` or a discovery url") } rs.authFn = authorizer return rs, nil } -func NewResourceServerFromKeyFile(issuer, path string, options ...Option) (ResourceServer, error) { +func NewResourceServerFromKeyFile(ctx context.Context, issuer, path string, options ...Option) (ResourceServer, error) { c, err := client.ConfigFromKeyFile(path) if err != nil { return nil, err } - return NewResourceServerJWTProfile(issuer, c.ClientID, c.KeyID, []byte(c.Key), options...) + return NewResourceServerJWTProfile(ctx, issuer, c.ClientID, c.KeyID, []byte(c.Key), options...) } type Option func(*resourceServer) @@ -112,18 +116,30 @@ func WithStaticEndpoints(tokenURL, introspectURL string) Option { } } -func Introspect(ctx context.Context, rp ResourceServer, token string) (*oidc.IntrospectionResponse, error) { +// Introspect calls the [RFC7662] Token Introspection +// endpoint and returns the response in an instance of type R. +// [*oidc.IntrospectionResponse] can be used as a good example, or use a custom type if type-safe +// access to custom claims is needed. +// +// [RFC7662]: https://www.rfc-editor.org/rfc/rfc7662 +func Introspect[R any](ctx context.Context, rp ResourceServer, token string) (resp R, err error) { + ctx, span := client.Tracer.Start(ctx, "Introspect") + defer span.End() + + if rp.IntrospectionURL() == "" { + return resp, errors.New("resource server: introspection URL is empty") + } authFn, err := rp.AuthFn() if err != nil { - return nil, err + return resp, err } - req, err := httphelper.FormRequest(rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) + req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn) if err != nil { - return nil, err + return resp, err } - resp := new(oidc.IntrospectionResponse) - if err := httphelper.HttpRequest(rp.HttpClient(), req, resp); err != nil { - return nil, err + + if err := httphelper.HttpRequest(rp.HttpClient(), req, &resp); err != nil { + return resp, err } return resp, nil } diff --git a/pkg/client/rs/resource_server_test.go b/pkg/client/rs/resource_server_test.go new file mode 100644 index 0000000..afd7441 --- /dev/null +++ b/pkg/client/rs/resource_server_test.go @@ -0,0 +1,221 @@ +package rs + +import ( + "context" + "testing" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewResourceServer(t *testing.T) { + type args struct { + issuer string + authorizer func() (any, error) + options []Option + } + type wantFields struct { + issuer string + tokenURL string + introspectURL string + authFn func() (any, error) + } + tests := []struct { + name string + args args + wantFields *wantFields + wantErr bool + }{ + { + name: "spotify-full-discovery", + args: args{ + issuer: "https://accounts.spotify.com", + authorizer: nil, + options: []Option{}, + }, + wantFields: &wantFields{ + issuer: "https://accounts.spotify.com", + tokenURL: "https://accounts.spotify.com/api/token", + introspectURL: "", + authFn: nil, + }, + wantErr: false, + }, + { + name: "spotify-with-static-tokenurl", + args: args{ + issuer: "https://accounts.spotify.com", + authorizer: nil, + options: []Option{ + WithStaticEndpoints( + "https://some.host/token-url", + "", + ), + }, + }, + wantFields: &wantFields{ + issuer: "https://accounts.spotify.com", + tokenURL: "https://some.host/token-url", + introspectURL: "", + authFn: nil, + }, + wantErr: false, + }, + { + name: "spotify-with-static-introspecturl", + args: args{ + issuer: "https://accounts.spotify.com", + authorizer: nil, + options: []Option{ + WithStaticEndpoints( + "", + "https://some.host/instrospect-url", + ), + }, + }, + wantFields: &wantFields{ + issuer: "https://accounts.spotify.com", + tokenURL: "https://accounts.spotify.com/api/token", + introspectURL: "https://some.host/instrospect-url", + authFn: nil, + }, + wantErr: false, + }, + { + name: "spotify-with-all-static-endpoints", + args: args{ + issuer: "https://accounts.spotify.com", + authorizer: nil, + options: []Option{ + WithStaticEndpoints( + "https://some.host/token-url", + "https://some.host/instrospect-url", + ), + }, + }, + wantFields: &wantFields{ + issuer: "https://accounts.spotify.com", + tokenURL: "https://some.host/token-url", + introspectURL: "https://some.host/instrospect-url", + authFn: nil, + }, + wantErr: false, + }, + { + name: "bad-discovery", + args: args{ + issuer: "https://127.0.0.1:65535", + authorizer: nil, + options: []Option{}, + }, + wantFields: nil, + wantErr: true, + }, + { + name: "bad-discovery-with-static-tokenurl", + args: args{ + issuer: "https://127.0.0.1:65535", + authorizer: nil, + options: []Option{ + WithStaticEndpoints( + "https://some.host/token-url", + "", + ), + }, + }, + wantFields: nil, + wantErr: true, + }, + { + name: "bad-discovery-with-static-introspecturl", + args: args{ + issuer: "https://127.0.0.1:65535", + authorizer: nil, + options: []Option{ + WithStaticEndpoints( + "", + "https://some.host/instrospect-url", + ), + }, + }, + wantFields: nil, + wantErr: true, + }, + { + name: "bad-discovery-with-all-static-endpoints", + args: args{ + issuer: "https://127.0.0.1:65535", + authorizer: nil, + options: []Option{ + WithStaticEndpoints( + "https://some.host/token-url", + "https://some.host/instrospect-url", + ), + }, + }, + wantFields: &wantFields{ + issuer: "https://127.0.0.1:65535", + tokenURL: "https://some.host/token-url", + introspectURL: "https://some.host/instrospect-url", + authFn: nil, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newResourceServer(context.Background(), tt.args.issuer, tt.args.authorizer, tt.args.options...) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + if tt.wantFields == nil { + return + } + assert.Equal(t, tt.wantFields.issuer, got.issuer) + assert.Equal(t, tt.wantFields.tokenURL, got.tokenURL) + assert.Equal(t, tt.wantFields.introspectURL, got.introspectURL) + }) + } +} + +func TestIntrospect(t *testing.T) { + type args struct { + ctx context.Context + rp ResourceServer + token string + } + rp, err := newResourceServer( + context.Background(), + "https://accounts.spotify.com", + nil, + ) + require.NoError(t, err) + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "missing-introspect-url", + args: args{ + ctx: context.Background(), + rp: rp, + token: "my-token", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Introspect[*oidc.IntrospectionResponse](tt.args.ctx, tt.args.rp, tt.args.token) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} diff --git a/pkg/client/tokenexchange/tokenexchange.go b/pkg/client/tokenexchange/tokenexchange.go index 1375f68..9cc1328 100644 --- a/pkg/client/tokenexchange/tokenexchange.go +++ b/pkg/client/tokenexchange/tokenexchange.go @@ -1,38 +1,52 @@ package tokenexchange import ( + "context" "errors" "net/http" + "time" - "github.com/zitadel/oidc/v2/pkg/client" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/go-jose/go-jose/v4" ) type TokenExchanger interface { TokenEndpoint() string HttpClient() *http.Client - AuthFn() (interface{}, error) + AuthFn() (any, error) } type OAuthTokenExchange struct { httpClient *http.Client tokenEndpoint string - authFn func() (interface{}, error) + authFn func() (any, error) } -func NewTokenExchanger(issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { - return newOAuthTokenExchange(issuer, nil, options...) +func NewTokenExchanger(ctx context.Context, issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { + return newOAuthTokenExchange(ctx, issuer, nil, options...) } -func NewTokenExchangerClientCredentials(issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { - authorizer := func() (interface{}, error) { +func NewTokenExchangerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { + authorizer := func() (any, error) { return httphelper.AuthorizeBasic(clientID, clientSecret), nil } - return newOAuthTokenExchange(issuer, authorizer, options...) + return newOAuthTokenExchange(ctx, issuer, authorizer, options...) } -func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { +func NewTokenExchangerJWTProfile(ctx context.Context, issuer, clientID string, signer jose.Signer, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) { + authorizer := func() (any, error) { + assertion, err := client.SignedJWTProfileAssertion(clientID, []string{issuer}, time.Hour, signer) + if err != nil { + return nil, err + } + return client.ClientAssertionFormAuthorization(assertion), nil + } + return newOAuthTokenExchange(ctx, issuer, authorizer, options...) +} + +func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) { te := &OAuthTokenExchange{ httpClient: httphelper.DefaultHTTPClient, } @@ -41,7 +55,7 @@ func newOAuthTokenExchange(issuer string, authorizer func() (interface{}, error) } if te.tokenEndpoint == "" { - config, err := client.Discover(issuer, te.httpClient) + config, err := client.Discover(ctx, issuer, te.httpClient) if err != nil { return nil, err } @@ -78,7 +92,7 @@ func (te *OAuthTokenExchange) HttpClient() *http.Client { return te.httpClient } -func (te *OAuthTokenExchange) AuthFn() (interface{}, error) { +func (te *OAuthTokenExchange) AuthFn() (any, error) { if te.authFn != nil { return te.authFn() } @@ -89,6 +103,7 @@ func (te *OAuthTokenExchange) AuthFn() (interface{}, error) { // ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint. // SubjectToken and SubjectTokenType are required parameters. func ExchangeToken( + ctx context.Context, te TokenExchanger, SubjectToken string, SubjectTokenType oidc.TokenType, @@ -99,6 +114,9 @@ func ExchangeToken( Scopes []string, RequestedTokenType oidc.TokenType, ) (*oidc.TokenExchangeResponse, error) { + ctx, span := client.Tracer.Start(ctx, "ExchangeToken") + defer span.End() + if SubjectToken == "" { return nil, errors.New("empty subject_token") } @@ -123,5 +141,5 @@ func ExchangeToken( RequestedTokenType: RequestedTokenType, } - return client.CallTokenExchangeEndpoint(request, authFn, te) + return client.CallTokenExchangeEndpoint(ctx, request, authFn, te) } diff --git a/pkg/crypto/hash.go b/pkg/crypto/hash.go index 6fcc71f..14acdee 100644 --- a/pkg/crypto/hash.go +++ b/pkg/crypto/hash.go @@ -8,7 +8,7 @@ import ( "fmt" "hash" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" ) var ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm") @@ -21,6 +21,14 @@ func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { return sha512.New384(), nil case jose.RS512, jose.ES512, jose.PS512: return sha512.New(), nil + + // There is no published spec for this yet, but we have confirmation it will get published. + // There is consensus here: https://bitbucket.org/openid/connect/issues/1125/_hash-algorithm-for-eddsa-id-tokens + // Currently Go and go-jose only supports the ed25519 curve key for EdDSA, so we can safely assume sha512 here. + // It is unlikely ed448 will ever be supported: https://github.com/golang/go/issues/29390 + case jose.EdDSA: + return sha512.New(), nil + default: return nil, fmt.Errorf("%w: %q", ErrUnsupportedAlgorithm, sigAlgorithm) } diff --git a/pkg/crypto/key.go b/pkg/crypto/key.go index d75d1ab..12bca28 100644 --- a/pkg/crypto/key.go +++ b/pkg/crypto/key.go @@ -1,17 +1,45 @@ package crypto import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" + + "github.com/go-jose/go-jose/v4" ) -func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) { - block, _ := pem.Decode(priv) - b := block.Bytes - key, err := x509.ParsePKCS1PrivateKey(b) - if err != nil { - return nil, err +var ( + ErrPEMDecode = errors.New("PEM decode failed") + ErrUnsupportedFormat = errors.New("key is neither in PKCS#1 nor PKCS#8 format") + ErrUnsupportedPrivateKey = errors.New("unsupported key type, must be RSA, ECDSA or ED25519 private key") +) + +func BytesToPrivateKey(b []byte) (crypto.PublicKey, jose.SignatureAlgorithm, error) { + block, _ := pem.Decode(b) + if block == nil { + return nil, "", ErrPEMDecode + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err == nil { + return privateKey, jose.RS256, nil + } + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, "", ErrUnsupportedFormat + } + switch privateKey := key.(type) { + case *rsa.PrivateKey: + return privateKey, jose.RS256, nil + case ed25519.PrivateKey: + return privateKey, jose.EdDSA, nil + case *ecdsa.PrivateKey: + return privateKey, jose.ES256, nil + default: + return nil, "", ErrUnsupportedPrivateKey } - return key, nil } diff --git a/pkg/crypto/key_test.go b/pkg/crypto/key_test.go new file mode 100644 index 0000000..a6fa493 --- /dev/null +++ b/pkg/crypto/key_test.go @@ -0,0 +1,134 @@ +package crypto_test + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "testing" + + "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/assert" + + zcrypto "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto" +) + +func TestBytesToPrivateKey(t *testing.T) { + type args struct { + key []byte + } + type want struct { + key crypto.Signer + algorithm jose.SignatureAlgorithm + err error + } + tests := []struct { + name string + args args + want want + }{ + { + name: "PEMDecodeError", + args: args{ + key: []byte("The non-PEM sequence"), + }, + want: want{ + err: zcrypto.ErrPEMDecode, + }, + }, + { + name: "PKCS#1 RSA", + args: args{ + key: []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu +KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm +o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k +TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7 +9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy +v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs +/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00 +-----END RSA PRIVATE KEY-----`), + }, + want: want{ + key: &rsa.PrivateKey{}, + algorithm: jose.RS256, + err: nil, + }, + }, + { + name: "PKCS#8 RSA", + args: args{ + key: []byte(`-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCfaDB7pK/fmP/I +7IusSK8lTCBnPZghqIbVLt2QHYAMoEF1CaF4F4rxo2vl1Mt8gwsq4T3osQFZMvnL +YHb7KNyUoJgTjLxJQADv2u4Q3U38heAzK5Tp4ry4MCnuyJIqAPK1GiruwEq4zQrx ++WzVix8otO37SuW9tzklqlNGMiAYBL0TBKHvS5XMbjP1idBMB8erMz29w/TVQnEB +Kj0vCdZjrbVPKygptt5kcSrL5f4xCZwU+ufz7cp0GLwpRMJ+shG9YJJFBxb0itPF +sy51vAyEtdBC7jgAU96ZVeQ06nryDq1D2EpoVMElqNyL46Jo3lnKbGquGKzXzQYU +BN32/scDAgMBAAECggEBAJE/mo3PLgILo2YtQ8ekIxNVHmF0Gl7w9IrjvTdH6hmX +HI3MTLjkmtI7GmG9V/0IWvCjdInGX3grnrjWGRQZ04QKIQgPQLFuBGyJjEsJm7nx +MqztlS7YTyV1nX/aenSTkJO8WEpcJLnm+4YoxCaAMdAhrIdBY71OamALpv1bRysa +FaiCGcemT2yqZn0GqIS8O26Tz5zIqrTN2G1eSmgh7DG+7FoddMz35cute8R10xUG +hF5YU+6fcXiRQ/Kh7nlxelPGqdZFPMk7LpVHzkQKwdJ+N0P23lPDIfNsvpG1n0OP +3g5km7gHSrSU2yZ3eFl6DB9x1IFNS9BaQQuSxYJtKwECgYEA1C8jjzpXZDLvlYsV +2jlMzkrbsIrX2dzblVrNsPs2jRbjYU8mg2DUDO6lOhtxHfqZG6sO+gmWi/zvoy9l +yolGbXe1Jqx66p9fznIcecSwar8+ACa356Wk74Nt1PlBOfCMqaJnYLOLaFJa29Vy +u5ClZVzKd5AVXl7yFVd4XfLv/WECgYEAwFMMtFoasdF92c0d31rZ1uoPOtFz6xq6 +uQggdm5zzkhnfwUAGqppS/u1CHcJ7T/74++jLbFTsaohGr4jEzWSGvJpomEUChy3 +r25YofMclUhJ5pCEStsLtqiCR1Am6LlI8HMdBEP1QDgEC5q8bQW4+UHuew1E1zxz +osZOhe09WuMCgYEA0G9aFCnwjUqIFjQiDFP7gi8BLqTFs4uE3Wvs4W11whV42i+B +ms90nxuTjchFT3jMDOT1+mOO0wdudLRr3xEI8SIF/u6ydGaJG+j21huEXehtxIJE +aDdNFcfbDbqo+3y1ATK7MMBPMvSrsoY0hdJq127WqasNgr3sO1DIuima3SECgYEA +nkM5TyhekzlbIOHD1UsDu/D7+2DkzPE/+oePfyXBMl0unb3VqhvVbmuBO6gJiSx/ +8b//PdiQkMD5YPJaFrKcuoQFHVRZk0CyfzCEyzAts0K7XXpLAvZiGztriZeRjSz7 +srJnjF0H8oKmAY6hw+1Tm/n/b08p+RyL48TgVSE2vhUCgYA3BWpkD4PlCcn/FZsq +OrLFyFXI6jIaxskFtsRW1IxxIlAdZmxfB26P/2gx6VjLdxJI/RRPkJyEN2dP7CbR +BDjb565dy1O9D6+UrY70Iuwjz+OcALRBBGTaiF2pLn6IhSzNI2sy/tXX8q8dBlg9 +OFCrqT/emes3KytTPfa5NZtYeQ== +-----END PRIVATE KEY-----`), + }, + want: want{ + key: &rsa.PrivateKey{}, + algorithm: jose.RS256, + err: nil, + }, + }, + { + name: "PKCS#8 ECDSA", + args: args{ + key: []byte(`-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgwwOZSU4GlP7ps/Wp +V6o0qRwxultdfYo/uUuj48QZjSuhRANCAATMiI2Han+ABKmrk5CNlxRAGC61w4d3 +G4TAeuBpyzqJ7x/6NjCxoQzJzZHtNjIfjVATI59XFZWF59GhtSZbShAr +-----END PRIVATE KEY-----`), + }, + want: want{ + key: &ecdsa.PrivateKey{}, + algorithm: jose.ES256, + err: nil, + }, + }, + { + name: "PKCS#8 ED25519", + args: args{ + key: []byte(`-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEIHu6ZtDsjjauMasBxnS9Fg87UJwKfcT/oiq6S0ktbky8 +-----END PRIVATE KEY-----`), + }, + want: want{ + key: ed25519.PrivateKey{}, + algorithm: jose.EdDSA, + err: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key, algorithm, err := zcrypto.BytesToPrivateKey(tt.args.key) + assert.IsType(t, tt.want.key, key) + assert.Equal(t, tt.want.algorithm, algorithm) + assert.ErrorIs(t, tt.want.err, err) + }) + + } +} diff --git a/pkg/crypto/sign.go b/pkg/crypto/sign.go index a0b9cae..937a846 100644 --- a/pkg/crypto/sign.go +++ b/pkg/crypto/sign.go @@ -4,10 +4,10 @@ import ( "encoding/json" "errors" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" ) -func Sign(object interface{}, signer jose.Signer) (string, error) { +func Sign(object any, signer jose.Signer) (string, error) { payload, err := json.Marshal(object) if err != nil { return "", err diff --git a/pkg/http/http.go b/pkg/http/http.go index d3c5b4f..aa0ff6f 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -10,6 +10,8 @@ import ( "net/url" "strings" "time" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) var DefaultHTTPClient = &http.Client{ @@ -17,11 +19,11 @@ var DefaultHTTPClient = &http.Client{ } type Decoder interface { - Decode(dst interface{}, src map[string][]string) error + Decode(dst any, src map[string][]string) error } type Encoder interface { - Encode(src interface{}, dst map[string][]string) error + Encode(src any, dst map[string][]string) error } type FormAuthorization func(url.Values) @@ -33,7 +35,7 @@ func AuthorizeBasic(user, password string) RequestAuthorization { } } -func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) { +func FormRequest(ctx context.Context, endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) { form := url.Values{} if err := encoder.Encode(request, form); err != nil { return nil, err @@ -42,7 +44,7 @@ func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn i fn(form) } body := strings.NewReader(form.Encode()) - req, err := http.NewRequest("POST", endpoint, body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body) if err != nil { return nil, err } @@ -53,7 +55,7 @@ func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn i return req, nil } -func HttpRequest(client *http.Client, req *http.Request, response interface{}) error { +func HttpRequest(client *http.Client, req *http.Request, response any) error { resp, err := client.Do(req) if err != nil { return err @@ -66,7 +68,12 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e } if resp.StatusCode != http.StatusOK { - return fmt.Errorf("http status not ok: %s %s", resp.Status, body) + var oidcErr oidc.Error + err = json.Unmarshal(body, &oidcErr) + if err != nil || oidcErr.ErrorType == "" { + return fmt.Errorf("http status not ok: %s %s", resp.Status, body) + } + return &oidcErr } err = json.Unmarshal(body, response) @@ -76,7 +83,7 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e return nil } -func URLEncodeParams(resp interface{}, encoder Encoder) (url.Values, error) { +func URLEncodeParams(resp any, encoder Encoder) (url.Values, error) { values := make(map[string][]string) err := encoder.Encode(resp, values) if err != nil { diff --git a/pkg/http/marshal.go b/pkg/http/marshal.go index 794a28a..71ed2c2 100644 --- a/pkg/http/marshal.go +++ b/pkg/http/marshal.go @@ -8,11 +8,11 @@ import ( "reflect" ) -func MarshalJSON(w http.ResponseWriter, i interface{}) { +func MarshalJSON(w http.ResponseWriter, i any) { MarshalJSONWithStatus(w, i, http.StatusOK) } -func MarshalJSONWithStatus(w http.ResponseWriter, i interface{}, status int) { +func MarshalJSONWithStatus(w http.ResponseWriter, i any, status int) { w.Header().Set("content-type", "application/json") w.WriteHeader(status) if i == nil || (reflect.ValueOf(i).Kind() == reflect.Ptr && reflect.ValueOf(i).IsNil()) { diff --git a/pkg/http/marshal_test.go b/pkg/http/marshal_test.go index 3838a44..dcc7fdd 100644 --- a/pkg/http/marshal_test.go +++ b/pkg/http/marshal_test.go @@ -94,7 +94,7 @@ func TestConcatenateJSON(t *testing.T) { func TestMarshalJSONWithStatus(t *testing.T) { type args struct { - i interface{} + i any status int } type res struct { diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go index f620ecb..fa37dbf 100644 --- a/pkg/oidc/authorization.go +++ b/pkg/oidc/authorization.go @@ -1,5 +1,9 @@ package oidc +import ( + "log/slog" +) + const ( // ScopeOpenID defines the scope `openid` // OpenID Connect requests MUST contain the `openid` scope value @@ -44,6 +48,7 @@ const ( ResponseModeQuery ResponseMode = "query" ResponseModeFragment ResponseMode = "fragment" + ResponseModeFormPost ResponseMode = "form_post" // PromptNone (`none`) disallows the Authorization Server to display any authentication or consent user interface pages. // An error (login_required, interaction_required, ...) will be returned if the user is not already authenticated or consent is needed @@ -60,7 +65,7 @@ const ( ) // AuthRequest according to: -//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest +// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest type AuthRequest struct { Scopes SpaceDelimitedArray `json:"scope" schema:"scope"` ResponseType ResponseType `json:"response_type" schema:"response_type"` @@ -77,7 +82,7 @@ type AuthRequest struct { UILocales Locales `json:"ui_locales" schema:"ui_locales"` IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"` LoginHint string `json:"login_hint" schema:"login_hint"` - ACRValues []string `json:"acr_values" schema:"acr_values"` + ACRValues SpaceDelimitedArray `json:"acr_values" schema:"acr_values"` CodeChallenge string `json:"code_challenge" schema:"code_challenge"` CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"` @@ -86,6 +91,15 @@ type AuthRequest struct { RequestParam string `schema:"request"` } +func (a *AuthRequest) LogValue() slog.Value { + return slog.GroupValue( + slog.Any("scopes", a.Scopes), + slog.String("response_type", string(a.ResponseType)), + slog.String("client_id", a.ClientID), + slog.String("redirect_uri", a.RedirectURI), + ) +} + // GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface func (a *AuthRequest) GetRedirectURI() string { return a.RedirectURI @@ -100,3 +114,8 @@ func (a *AuthRequest) GetResponseType() ResponseType { func (a *AuthRequest) GetState() string { return a.State } + +// GetResponseMode returns the optional ResponseMode +func (a *AuthRequest) GetResponseMode() ResponseMode { + return a.ResponseMode +} diff --git a/pkg/oidc/authorization_test.go b/pkg/oidc/authorization_test.go new file mode 100644 index 0000000..1446efa --- /dev/null +++ b/pkg/oidc/authorization_test.go @@ -0,0 +1,27 @@ +//go:build go1.20 + +package oidc + +import ( + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAuthRequest_LogValue(t *testing.T) { + a := &AuthRequest{ + Scopes: SpaceDelimitedArray{"a", "b"}, + ResponseType: "respType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + } + want := slog.GroupValue( + slog.Any("scopes", SpaceDelimitedArray{"a", "b"}), + slog.String("response_type", "respType"), + slog.String("client_id", "123"), + slog.String("redirect_uri", "http://example.com/callback"), + ) + got := a.LogValue() + assert.Equal(t, want, got) +} diff --git a/pkg/oidc/code_challenge.go b/pkg/oidc/code_challenge.go index 37c1783..0c593df 100644 --- a/pkg/oidc/code_challenge.go +++ b/pkg/oidc/code_challenge.go @@ -3,7 +3,7 @@ package oidc import ( "crypto/sha256" - "github.com/zitadel/oidc/v2/pkg/crypto" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto" ) const ( diff --git a/pkg/oidc/device_authorization.go b/pkg/oidc/device_authorization.go index 68b8efa..a6417ba 100644 --- a/pkg/oidc/device_authorization.go +++ b/pkg/oidc/device_authorization.go @@ -1,5 +1,7 @@ package oidc +import "encoding/json" + // DeviceAuthorizationRequest implements // https://www.rfc-editor.org/rfc/rfc8628#section-3.1, // 3.1 Device Authorization Request. @@ -20,6 +22,26 @@ type DeviceAuthorizationResponse struct { Interval int `json:"interval,omitempty"` } +func (resp *DeviceAuthorizationResponse) UnmarshalJSON(data []byte) error { + type Alias DeviceAuthorizationResponse + aux := &struct { + // workaround misspelling of verification_uri + // https://stackoverflow.com/q/76696956/5690223 + // https://developers.google.com/identity/protocols/oauth2/limited-input-device?hl=fr#success-response + VerificationURL string `json:"verification_url"` + *Alias + }{ + Alias: (*Alias)(resp), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + if resp.VerificationURI == "" { + resp.VerificationURI = aux.VerificationURL + } + return nil +} + // DeviceAccessTokenRequest implements // https://www.rfc-editor.org/rfc/rfc8628#section-3.4, // Device Access Token Request. diff --git a/pkg/oidc/device_authorization_test.go b/pkg/oidc/device_authorization_test.go new file mode 100644 index 0000000..c4c6637 --- /dev/null +++ b/pkg/oidc/device_authorization_test.go @@ -0,0 +1,30 @@ +package oidc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDeviceAuthorizationResponse_UnmarshalJSON(t *testing.T) { + jsonStr := `{ + "device_code": "deviceCode", + "user_code": "userCode", + "verification_url": "http://example.com/verify", + "expires_in": 3600, + "interval": 5 + }` + + expected := &DeviceAuthorizationResponse{ + DeviceCode: "deviceCode", + UserCode: "userCode", + VerificationURI: "http://example.com/verify", + ExpiresIn: 3600, + Interval: 5, + } + + var resp DeviceAuthorizationResponse + err := resp.UnmarshalJSON([]byte(jsonStr)) + assert.NoError(t, err) + assert.Equal(t, expected, &resp) +} diff --git a/pkg/oidc/discovery.go b/pkg/oidc/discovery.go index 3574101..62288d1 100644 --- a/pkg/oidc/discovery.go +++ b/pkg/oidc/discovery.go @@ -1,9 +1,5 @@ package oidc -import ( - "golang.org/x/text/language" -) - const ( DiscoveryEndpoint = "/.well-known/openid-configuration" ) @@ -130,10 +126,10 @@ type DiscoveryConfiguration struct { ServiceDocumentation string `json:"service_documentation,omitempty"` // ClaimsLocalesSupported contains a list of BCP47 language tag values that the OP supports for values of Claims returned. - ClaimsLocalesSupported []language.Tag `json:"claims_locales_supported,omitempty"` + ClaimsLocalesSupported Locales `json:"claims_locales_supported,omitempty"` // UILocalesSupported contains a list of BCP47 language tag values that the OP supports for the user interface. - UILocalesSupported []language.Tag `json:"ui_locales_supported,omitempty"` + UILocalesSupported Locales `json:"ui_locales_supported,omitempty"` // RequestParameterSupported specifies whether the OP supports use of the `request` parameter. If omitted, the default value is false. RequestParameterSupported bool `json:"request_parameter_supported,omitempty"` @@ -149,6 +145,14 @@ type DiscoveryConfiguration struct { // OPTermsOfServiceURI is a URL the OpenID Provider provides to the person registering the Client to read about OpenID Provider's terms of service. OPTermsOfServiceURI string `json:"op_tos_uri,omitempty"` + + // BackChannelLogoutSupported specifies whether the OP supports back-channel logout (https://openid.net/specs/openid-connect-backchannel-1_0.html), + // with true indicating support. If omitted, the default value is false. + BackChannelLogoutSupported bool `json:"backchannel_logout_supported,omitempty"` + + // BackChannelLogoutSessionSupported specifies whether the OP can pass a sid (session ID) Claim in the Logout Token to identify the RP session with the OP. + // If supported, the sid Claim is also included in ID Tokens issued by the OP. If omitted, the default value is false. + BackChannelLogoutSessionSupported bool `json:"backchannel_logout_session_supported,omitempty"` } type AuthMethod string diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go index 79acecd..d93cf44 100644 --- a/pkg/oidc/error.go +++ b/pkg/oidc/error.go @@ -1,8 +1,10 @@ package oidc import ( + "encoding/json" "errors" "fmt" + "log/slog" ) type errorType string @@ -26,6 +28,11 @@ const ( SlowDown errorType = "slow_down" AccessDenied errorType = "access_denied" ExpiredToken errorType = "expired_token" + + // InvalidTarget error is returned by Token Exchange if + // the requested target or audience is invalid. + // [RFC 8693, Section 2.2.2: Error Response](https://www.rfc-editor.org/rfc/rfc8693#section-2.2.2) + InvalidTarget errorType = "invalid_target" ) var ( @@ -111,6 +118,14 @@ var ( Description: "The \"device_code\" has expired.", } } + + // Token exchange error + ErrInvalidTarget = func() *Error { + return &Error{ + ErrorType: InvalidTarget, + Description: "The requested audience or target is invalid.", + } + } ) type Error struct { @@ -118,7 +133,28 @@ type Error struct { ErrorType errorType `json:"error" schema:"error"` Description string `json:"error_description,omitempty" schema:"error_description,omitempty"` State string `json:"state,omitempty" schema:"state,omitempty"` + SessionState string `json:"session_state,omitempty" schema:"session_state,omitempty"` redirectDisabled bool `schema:"-"` + returnParent bool `schema:"-"` +} + +func (e *Error) MarshalJSON() ([]byte, error) { + m := struct { + Error errorType `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + State string `json:"state,omitempty"` + SessionState string `json:"session_state,omitempty"` + Parent string `json:"parent,omitempty"` + }{ + Error: e.ErrorType, + ErrorDescription: e.Description, + State: e.State, + SessionState: e.SessionState, + } + if e.returnParent { + m.Parent = e.Parent.Error() + } + return json.Marshal(m) } func (e *Error) Error() string { @@ -143,7 +179,8 @@ func (e *Error) Is(target error) bool { } return e.ErrorType == t.ErrorType && (e.Description == t.Description || t.Description == "") && - (e.State == t.State || t.State == "") + (e.State == t.State || t.State == "") && + (e.SessionState == t.SessionState || t.SessionState == "") } func (e *Error) WithParent(err error) *Error { @@ -151,7 +188,19 @@ func (e *Error) WithParent(err error) *Error { return e } -func (e *Error) WithDescription(desc string, args ...interface{}) *Error { +// WithReturnParentToClient allows returning the set parent error to the HTTP client. +// Currently it only supports setting the parent inside JSON responses, not redirect URLs. +// As Go errors don't unmarshal well, only the marshaller is implemented for the moment. +// +// Warning: parent errors may contain sensitive data or unwanted details about the server status. +// Also, the `parent` field is not a standard error field and might confuse certain clients +// that require fully compliant responses. +func (e *Error) WithReturnParentToClient(b bool) *Error { + e.returnParent = b + return e +} + +func (e *Error) WithDescription(desc string, args ...any) *Error { e.Description = fmt.Sprintf(desc, args...) return e } @@ -171,3 +220,37 @@ func DefaultToServerError(err error, description string) *Error { } return oauth } + +func (e *Error) LogLevel() slog.Level { + level := slog.LevelWarn + if e.ErrorType == ServerError { + level = slog.LevelError + } + if e.ErrorType == AuthorizationPending { + level = slog.LevelInfo + } + return level +} + +func (e *Error) LogValue() slog.Value { + attrs := make([]slog.Attr, 0, 5) + if e.Parent != nil { + attrs = append(attrs, slog.Any("parent", e.Parent)) + } + if e.Description != "" { + attrs = append(attrs, slog.String("description", e.Description)) + } + if e.ErrorType != "" { + attrs = append(attrs, slog.String("type", string(e.ErrorType))) + } + if e.State != "" { + attrs = append(attrs, slog.String("state", e.State)) + } + if e.SessionState != "" { + attrs = append(attrs, slog.String("session_state", e.SessionState)) + } + if e.redirectDisabled { + attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled)) + } + return slog.GroupValue(attrs...) +} diff --git a/pkg/oidc/error_test.go b/pkg/oidc/error_test.go new file mode 100644 index 0000000..40d30b1 --- /dev/null +++ b/pkg/oidc/error_test.go @@ -0,0 +1,192 @@ +package oidc + +import ( + "encoding/json" + "errors" + "io" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultToServerError(t *testing.T) { + type args struct { + err error + description string + } + tests := []struct { + name string + args args + want *Error + }{ + { + name: "default", + args: args{ + err: io.ErrClosedPipe, + description: "oops", + }, + want: &Error{ + ErrorType: ServerError, + Description: "oops", + Parent: io.ErrClosedPipe, + }, + }, + { + name: "our Error", + args: args{ + err: ErrAccessDenied(), + description: "oops", + }, + want: &Error{ + ErrorType: AccessDenied, + Description: "The authorization request was denied.", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DefaultToServerError(tt.args.err, tt.args.description) + assert.ErrorIs(t, got, tt.want) + }) + } +} + +func TestError_LogLevel(t *testing.T) { + tests := []struct { + name string + err *Error + want slog.Level + }{ + { + name: "server error", + err: ErrServerError(), + want: slog.LevelError, + }, + { + name: "authorization pending", + err: ErrAuthorizationPending(), + want: slog.LevelInfo, + }, + { + name: "some other error", + err: ErrAccessDenied(), + want: slog.LevelWarn, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.err.LogLevel() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestError_LogValue(t *testing.T) { + type fields struct { + Parent error + ErrorType errorType + Description string + State string + redirectDisabled bool + } + tests := []struct { + name string + fields fields + want slog.Value + }{ + { + name: "parent", + fields: fields{ + Parent: io.EOF, + }, + want: slog.GroupValue(slog.Any("parent", io.EOF)), + }, + { + name: "description", + fields: fields{ + Description: "oops", + }, + want: slog.GroupValue(slog.String("description", "oops")), + }, + { + name: "errorType", + fields: fields{ + ErrorType: ExpiredToken, + }, + want: slog.GroupValue(slog.String("type", string(ExpiredToken))), + }, + { + name: "state", + fields: fields{ + State: "123", + }, + want: slog.GroupValue(slog.String("state", "123")), + }, + { + name: "all fields", + fields: fields{ + Parent: io.EOF, + Description: "oops", + ErrorType: ExpiredToken, + State: "123", + }, + want: slog.GroupValue( + slog.Any("parent", io.EOF), + slog.String("description", "oops"), + slog.String("type", string(ExpiredToken)), + slog.String("state", "123"), + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Error{ + Parent: tt.fields.Parent, + ErrorType: tt.fields.ErrorType, + Description: tt.fields.Description, + State: tt.fields.State, + redirectDisabled: tt.fields.redirectDisabled, + } + got := e.LogValue() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestError_MarshalJSON(t *testing.T) { + tests := []struct { + name string + e *Error + want string + }{ + { + name: "simple error", + e: ErrAccessDenied(), + want: `{"error":"access_denied","error_description":"The authorization request was denied."}`, + }, + { + name: "with description", + e: ErrAccessDenied().WithDescription("oops"), + want: `{"error":"access_denied","error_description":"oops"}`, + }, + { + name: "with parent", + e: ErrServerError().WithParent(errors.New("oops")), + want: `{"error":"server_error"}`, + }, + { + name: "with return parent", + e: ErrServerError().WithParent(errors.New("oops")).WithReturnParentToClient(true), + want: `{"error":"server_error","parent":"oops"}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.e) + require.NoError(t, err) + assert.JSONEq(t, tt.want, string(got)) + }) + } +} diff --git a/pkg/oidc/introspection.go b/pkg/oidc/introspection.go index 8313dc4..1a200eb 100644 --- a/pkg/oidc/introspection.go +++ b/pkg/oidc/introspection.go @@ -16,18 +16,21 @@ type ClientAssertionParams struct { // https://www.rfc-editor.org/rfc/rfc7662.html#section-2.2. // https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims. type IntrospectionResponse struct { - Active bool `json:"active"` - Scope SpaceDelimitedArray `json:"scope,omitempty"` - ClientID string `json:"client_id,omitempty"` - TokenType string `json:"token_type,omitempty"` - Expiration Time `json:"exp,omitempty"` - IssuedAt Time `json:"iat,omitempty"` - NotBefore Time `json:"nbf,omitempty"` - Subject string `json:"sub,omitempty"` - Audience Audience `json:"aud,omitempty"` - Issuer string `json:"iss,omitempty"` - JWTID string `json:"jti,omitempty"` - Username string `json:"username,omitempty"` + Active bool `json:"active"` + Scope SpaceDelimitedArray `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + TokenType string `json:"token_type,omitempty"` + Expiration Time `json:"exp,omitempty"` + IssuedAt Time `json:"iat,omitempty"` + AuthTime Time `json:"auth_time,omitempty"` + NotBefore Time `json:"nbf,omitempty"` + Subject string `json:"sub,omitempty"` + Audience Audience `json:"aud,omitempty"` + AuthenticationMethodsReferences []string `json:"amr,omitempty"` + Issuer string `json:"iss,omitempty"` + JWTID string `json:"jti,omitempty"` + Username string `json:"username,omitempty"` + Actor *ActorClaims `json:"act,omitempty"` UserInfoProfile UserInfoEmail UserInfoPhone diff --git a/pkg/oidc/keyset.go b/pkg/oidc/keyset.go index c6e865b..a8b89b0 100644 --- a/pkg/oidc/keyset.go +++ b/pkg/oidc/keyset.go @@ -6,8 +6,9 @@ import ( "crypto/ed25519" "crypto/rsa" "errors" + "strings" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" ) const ( @@ -46,8 +47,8 @@ func GetKeyIDAndAlg(jws *jose.JSONWebSignature) (string, string) { // // will return false none or multiple match // -//deprecated: use FindMatchingKey which will return an error (more specific) instead of just a bool -//moved implementation already to FindMatchingKey +// deprecated: use FindMatchingKey which will return an error (more specific) instead of just a bool +// moved implementation already to FindMatchingKey func FindKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (jose.JSONWebKey, bool) { key, err := FindMatchingKey(keyID, use, expectedAlg, keys...) return key, err == nil @@ -91,18 +92,18 @@ func FindMatchingKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (k return key, ErrKeyNone } -func algToKeyType(key interface{}, alg string) bool { - switch alg[0] { - case 'R', 'P': +func algToKeyType(key any, alg string) bool { + if strings.HasPrefix(alg, "RS") || strings.HasPrefix(alg, "PS") { _, ok := key.(*rsa.PublicKey) return ok - case 'E': + } + if strings.HasPrefix(alg, "ES") { _, ok := key.(*ecdsa.PublicKey) return ok - case 'O': - _, ok := key.(*ed25519.PublicKey) - return ok - default: - return false } + if alg == string(jose.EdDSA) { + _, ok := key.(ed25519.PublicKey) + return ok + } + return false } diff --git a/pkg/oidc/keyset_test.go b/pkg/oidc/keyset_test.go index 82b3ee8..e01074e 100644 --- a/pkg/oidc/keyset_test.go +++ b/pkg/oidc/keyset_test.go @@ -7,7 +7,7 @@ import ( "reflect" "testing" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" ) func TestFindKey(t *testing.T) { diff --git a/pkg/oidc/regression_test.go b/pkg/oidc/regression_test.go index 5d33bb6..9cb3ff9 100644 --- a/pkg/oidc/regression_test.go +++ b/pkg/oidc/regression_test.go @@ -17,7 +17,7 @@ const dataDir = "regression_data" // jsonFilename builds a filename for the regression testdata. // dataDir/.json -func jsonFilename(obj interface{}) string { +func jsonFilename(obj any) string { name := fmt.Sprintf("%T.json", obj) return path.Join( dataDir, @@ -25,13 +25,13 @@ func jsonFilename(obj interface{}) string { ) } -func encodeJSON(t *testing.T, w io.Writer, obj interface{}) { +func encodeJSON(t *testing.T, w io.Writer, obj any) { enc := json.NewEncoder(w) enc.SetIndent("", "\t") require.NoError(t, enc.Encode(obj)) } -var regressionData = []interface{}{ +var regressionData = []any{ accessTokenData, idTokenData, introspectionResponseData, diff --git a/pkg/oidc/session.go b/pkg/oidc/session.go index b470d1e..39f9f08 100644 --- a/pkg/oidc/session.go +++ b/pkg/oidc/session.go @@ -1,10 +1,12 @@ package oidc // EndSessionRequest for the RP-Initiated Logout according to: -//https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout +// https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout type EndSessionRequest struct { - IdTokenHint string `schema:"id_token_hint"` - ClientID string `schema:"client_id"` - PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"` - State string `schema:"state"` + IdTokenHint string `schema:"id_token_hint"` + LogoutHint string `schema:"logout_hint"` + ClientID string `schema:"client_id"` + PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"` + State string `schema:"state"` + UILocales Locales `schema:"ui_locales"` } diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go index 5283eb5..4b43dcb 100644 --- a/pkg/oidc/token.go +++ b/pkg/oidc/token.go @@ -5,11 +5,12 @@ import ( "os" "time" + jose "github.com/go-jose/go-jose/v4" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" "github.com/muhlemmer/gu" - "github.com/zitadel/oidc/v2/pkg/crypto" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto" ) const ( @@ -34,19 +35,20 @@ type Tokens[C IDClaims] struct { // TokenClaims implements the Claims interface, // and can be used to extend larger claim types by embedding. type TokenClaims struct { - Issuer string `json:"iss,omitempty"` - Subject string `json:"sub,omitempty"` - Audience Audience `json:"aud,omitempty"` - Expiration Time `json:"exp,omitempty"` - IssuedAt Time `json:"iat,omitempty"` - AuthTime Time `json:"auth_time,omitempty"` - NotBefore Time `json:"nbf,omitempty"` - Nonce string `json:"nonce,omitempty"` - AuthenticationContextClassReference string `json:"acr,omitempty"` - AuthenticationMethodsReferences []string `json:"amr,omitempty"` - AuthorizedParty string `json:"azp,omitempty"` - ClientID string `json:"client_id,omitempty"` - JWTID string `json:"jti,omitempty"` + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Audience Audience `json:"aud,omitempty"` + Expiration Time `json:"exp,omitempty"` + IssuedAt Time `json:"iat,omitempty"` + AuthTime Time `json:"auth_time,omitempty"` + NotBefore Time `json:"nbf,omitempty"` + Nonce string `json:"nonce,omitempty"` + AuthenticationContextClassReference string `json:"acr,omitempty"` + AuthenticationMethodsReferences []string `json:"amr,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` + ClientID string `json:"client_id,omitempty"` + JWTID string `json:"jti,omitempty"` + Actor *ActorClaims `json:"act,omitempty"` // Additional information set by this framework SignatureAlg jose.SignatureAlgorithm `json:"-"` @@ -115,6 +117,7 @@ func NewAccessTokenClaims(issuer, subject string, audience []string, expiration Expiration: FromTime(expiration), IssuedAt: FromTime(now), NotBefore: FromTime(now), + ClientID: clientID, JWTID: jwtid, }, } @@ -204,13 +207,36 @@ func (i *IDTokenClaims) UnmarshalJSON(data []byte) error { return unmarshalJSONMulti(data, (*itcAlias)(i), &i.Claims) } +// ActorClaims provides the `act` claims used for impersonation or delegation Token Exchange. +// +// An actor can be nested in case an obtained token is used as actor token to obtain impersonation or delegation. +// This allows creating a chain of actors. +// See [RFC 8693, section 4.1](https://www.rfc-editor.org/rfc/rfc8693#name-act-actor-claim). +type ActorClaims struct { + Actor *ActorClaims `json:"act,omitempty"` + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Claims map[string]any `json:"-"` +} + +type acAlias ActorClaims + +func (c *ActorClaims) MarshalJSON() ([]byte, error) { + return mergeAndMarshalClaims((*acAlias)(c), c.Claims) +} + +func (c *ActorClaims) UnmarshalJSON(data []byte) error { + return unmarshalJSONMulti(data, (*acAlias)(c), &c.Claims) +} + type AccessTokenResponse struct { - AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` - TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` - RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` - ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` - IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` - State string `json:"state,omitempty" schema:"state,omitempty"` + AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` + TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` + ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` + IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` + State string `json:"state,omitempty" schema:"state,omitempty"` + Scope SpaceDelimitedArray `json:"scope,omitempty" schema:"scope,omitempty"` } type JWTProfileAssertionClaims struct { @@ -222,7 +248,7 @@ type JWTProfileAssertionClaims struct { Expiration Time `json:"exp"` IssuedAt Time `json:"iat"` - Claims map[string]interface{} `json:"-"` + Claims map[string]any `json:"-"` } type jpaAlias JWTProfileAssertionClaims @@ -262,7 +288,7 @@ func JWTProfileDelegatedSubject(sub string) func(*JWTProfileAssertionClaims) { } } -func JWTProfileCustomClaim(key string, value interface{}) func(*JWTProfileAssertionClaims) { +func JWTProfileCustomClaim(key string, value any) func(*JWTProfileAssertionClaims) { return func(j *JWTProfileAssertionClaims) { j.Claims[key] = value } @@ -292,7 +318,7 @@ func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, IssuedAt: FromTime(time.Now().UTC()), Expiration: FromTime(time.Now().Add(1 * time.Hour).UTC()), Audience: audience, - Claims: make(map[string]interface{}), + Claims: make(map[string]any), } for _, opt := range opts { @@ -321,12 +347,12 @@ func AppendClientIDToAudience(clientID string, audience []string) []string { } func GenerateJWTProfileToken(assertion *JWTProfileAssertionClaims) (string, error) { - privateKey, err := crypto.BytesToPrivateKey(assertion.PrivateKey) + privateKey, algorithm, err := crypto.BytesToPrivateKey(assertion.PrivateKey) if err != nil { return "", err } key := jose.SigningKey{ - Algorithm: jose.RS256, + Algorithm: algorithm, Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.PrivateKeyID}, } signer, err := jose.NewSigner(key, &jose.SignerOptions{}) @@ -352,4 +378,45 @@ type TokenExchangeResponse struct { ExpiresIn uint64 `json:"expires_in,omitempty"` Scopes SpaceDelimitedArray `json:"scope,omitempty"` RefreshToken string `json:"refresh_token,omitempty"` + + // IDToken field allows returning an additional ID token + // if the requested_token_type was Access Token and scope contained openid. + IDToken string `json:"id_token,omitempty"` +} + +type LogoutTokenClaims struct { + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Audience Audience `json:"aud,omitempty"` + IssuedAt Time `json:"iat,omitempty"` + Expiration Time `json:"exp,omitempty"` + JWTID string `json:"jti,omitempty"` + Events map[string]any `json:"events,omitempty"` + SessionID string `json:"sid,omitempty"` + Claims map[string]any `json:"-"` +} + +type ltcAlias LogoutTokenClaims + +func (i *LogoutTokenClaims) MarshalJSON() ([]byte, error) { + return mergeAndMarshalClaims((*ltcAlias)(i), i.Claims) +} + +func (i *LogoutTokenClaims) UnmarshalJSON(data []byte) error { + return unmarshalJSONMulti(data, (*ltcAlias)(i), &i.Claims) +} + +func NewLogoutTokenClaims(issuer, subject string, audience Audience, expiration time.Time, jwtID, sessionID string, skew time.Duration) *LogoutTokenClaims { + return &LogoutTokenClaims{ + Issuer: issuer, + Subject: subject, + Audience: audience, + IssuedAt: FromTime(time.Now().Add(-skew)), + Expiration: FromTime(expiration), + JWTID: jwtID, + Events: map[string]any{ + "http://schemas.openid.net/event/backchannel-logout": struct{}{}, + }, + SessionID: sessionID, + } } diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 6b6945a..dadb205 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -3,9 +3,10 @@ package oidc import ( "encoding/json" "fmt" + "slices" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" ) const ( @@ -57,13 +58,7 @@ var AllTokenTypes = []TokenType{ type TokenType string func (t TokenType) IsSupported() bool { - for _, tt := range AllTokenTypes { - if t == tt { - return true - } - } - - return false + return slices.Contains(AllTokenTypes, t) } type TokenRequest interface { @@ -77,10 +72,10 @@ type AccessTokenRequest struct { Code string `schema:"code"` RedirectURI string `schema:"redirect_uri"` ClientID string `schema:"client_id"` - ClientSecret string `schema:"client_secret"` - CodeVerifier string `schema:"code_verifier"` - ClientAssertion string `schema:"client_assertion"` - ClientAssertionType string `schema:"client_assertion_type"` + ClientSecret string `schema:"client_secret,omitempty"` + CodeVerifier string `schema:"code_verifier,omitempty"` + ClientAssertion string `schema:"client_assertion,omitempty"` + ClientAssertionType string `schema:"client_assertion_type,omitempty"` } func (a *AccessTokenRequest) GrantType() GrantType { @@ -130,7 +125,7 @@ type JWTTokenRequest struct { IssuedAt Time `json:"iat"` ExpiresAt Time `json:"exp"` - private map[string]interface{} + private map[string]any } func (j *JWTTokenRequest) MarshalJSON() ([]byte, error) { @@ -171,7 +166,7 @@ func (j *JWTTokenRequest) UnmarshalJSON(data []byte) error { return nil } -func (j *JWTTokenRequest) GetCustomClaim(key string) interface{} { +func (j *JWTTokenRequest) GetCustomClaim(key string) any { return j.private[key] } @@ -241,7 +236,7 @@ type TokenExchangeRequest struct { } type ClientCredentialsRequest struct { - GrantType GrantType `schema:"grant_type"` + GrantType GrantType `schema:"grant_type,omitempty"` Scope SpaceDelimitedArray `schema:"scope"` ClientID string `schema:"client_id"` ClientSecret string `schema:"client_secret"` diff --git a/pkg/oidc/token_test.go b/pkg/oidc/token_test.go index ef1e77f..621cdbc 100644 --- a/pkg/oidc/token_test.go +++ b/pkg/oidc/token_test.go @@ -4,9 +4,9 @@ import ( "testing" "time" + jose "github.com/go-jose/go-jose/v4" "github.com/stretchr/testify/assert" "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" ) var ( @@ -29,7 +29,7 @@ var ( accessTokenData = &AccessTokenClaims{ TokenClaims: tokenClaimsData, Scopes: []string{"email", "phone"}, - Claims: map[string]interface{}{ + Claims: map[string]any{ "foo": "bar", }, } @@ -43,7 +43,7 @@ var ( UserInfoEmail: userInfoData.UserInfoEmail, UserInfoPhone: userInfoData.UserInfoPhone, Address: userInfoData.Address, - Claims: map[string]interface{}{ + Claims: map[string]any{ "foo": "bar", }, } @@ -64,7 +64,7 @@ var ( UserInfoEmail: userInfoData.UserInfoEmail, UserInfoPhone: userInfoData.UserInfoPhone, Address: userInfoData.Address, - Claims: map[string]interface{}{ + Claims: map[string]any{ "foo": "bar", }, } @@ -102,7 +102,7 @@ var ( PostalCode: "666-666", Country: "Moon", }, - Claims: map[string]interface{}{ + Claims: map[string]any{ "foo": "bar", }, } @@ -114,7 +114,7 @@ var ( Audience: Audience{"foo", "bar"}, Expiration: 12345, IssuedAt: 12000, - Claims: map[string]interface{}{ + Claims: map[string]any{ "foo": "bar", }, } @@ -145,6 +145,7 @@ func TestNewAccessTokenClaims(t *testing.T) { Subject: "hello@me.com", Audience: Audience{"foo"}, Expiration: 12345, + ClientID: "foo", JWTID: "900", }, } @@ -181,7 +182,7 @@ func TestIDTokenClaims_SetUserInfo(t *testing.T) { UserInfoEmail: userInfoData.UserInfoEmail, UserInfoPhone: userInfoData.UserInfoPhone, Address: userInfoData.Address, - Claims: map[string]interface{}{ + Claims: map[string]any{ "foo": "bar", }, } @@ -241,3 +242,39 @@ func TestIDTokenClaims_GetUserInfo(t *testing.T) { got := idTokenData.GetUserInfo() assert.Equal(t, want, got) } + +func TestNewLogoutTokenClaims(t *testing.T) { + want := &LogoutTokenClaims{ + Issuer: "zitadel", + Subject: "hello@me.com", + Audience: Audience{"foo", "just@me.com"}, + Expiration: 12345, + JWTID: "jwtID", + Events: map[string]any{ + "http://schemas.openid.net/event/backchannel-logout": struct{}{}, + }, + SessionID: "sessionID", + Claims: nil, + } + + got := NewLogoutTokenClaims( + want.Issuer, + want.Subject, + want.Audience, + want.Expiration.AsTime(), + want.JWTID, + want.SessionID, + 1*time.Second, + ) + + // test if the dynamic timestamp is around now, + // allowing for a delta of 1, just in case we flip on + // either side of a second boundry. + nowMinusSkew := NowTime() - 1 + assert.InDelta(t, int64(nowMinusSkew), int64(got.IssuedAt), 1) + + // Make equal not fail on dynamic timestamp + got.IssuedAt = 0 + + assert.Equal(t, want, got) +} diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index 167f8b7..5d063b1 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -3,26 +3,28 @@ package oidc import ( "database/sql/driver" "encoding/json" + "errors" "fmt" "reflect" "strings" "time" - "github.com/gorilla/schema" + jose "github.com/go-jose/go-jose/v4" + "github.com/muhlemmer/gu" + "github.com/zitadel/schema" "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" ) type Audience []string func (a *Audience) UnmarshalJSON(text []byte) error { - var i interface{} + var i any err := json.Unmarshal(text, &i) if err != nil { return err } switch aud := i.(type) { - case []interface{}: + case []any: *a = make([]string, len(aud)) for i, audience := range aud { (*a)[i] = audience.(string) @@ -33,6 +35,17 @@ func (a *Audience) UnmarshalJSON(text []byte) error { return nil } +func (a *Audience) MarshalJSON() ([]byte, error) { + len := len(*a) + if len > 1 { + return json.Marshal(*a) + } else if len == 1 { + return json.Marshal((*a)[0]) + } + + return nil, errors.New("aud is empty") +} + type Display string func (d *Display) UnmarshalText(text []byte) error { @@ -75,20 +88,90 @@ func (l *Locale) MarshalJSON() ([]byte, error) { return json.Marshal(tag) } +// UnmarshalJSON implements json.Unmarshaler. +// When [language.ValueError] is encountered, the containing tag will be set +// to an empty value (language "und") and no error will be returned. +// This state can be checked with the `l.Tag().IsRoot()` method. func (l *Locale) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &l.tag) + if len(data) == 0 || string(data) == "\"\"" { + return nil + } + err := json.Unmarshal(data, &l.tag) + if err == nil { + return nil + } + + // catch "well-formed but unknown" errors + var target language.ValueError + if errors.As(err, &target) { + l.tag = language.Tag{} + return nil + } + return err } type Locales []language.Tag -func (l *Locales) UnmarshalText(text []byte) error { - locales := strings.Split(string(text), " ") +// ParseLocales parses a slice of strings into Locales. +// If an entry causes a parse error or is undefined, +// it is ignored and not set to Locales. +func ParseLocales(locales []string) Locales { + out := make(Locales, 0, len(locales)) for _, locale := range locales { tag, err := language.Parse(locale) if err == nil && !tag.IsRoot() { - *l = append(*l, tag) + out = append(out, tag) } } + return out +} + +func (l Locales) String() string { + tags := make([]string, len(l)) + for i, tag := range l { + tags[i] = tag.String() + } + return strings.Join(tags, " ") +} + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +// It decodes an unquoted space seperated string into Locales. +// Undefined language tags in the input are ignored and ommited from +// the resulting Locales. +func (l *Locales) UnmarshalText(text []byte) error { + *l = ParseLocales( + strings.Split(string(text), " "), + ) + return nil +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +// It decodes a json array or a space seperated string into Locales. +// Undefined language tags in the input are ignored and ommited from +// the resulting Locales. +func (l *Locales) UnmarshalJSON(data []byte) error { + var dst any + if err := json.Unmarshal(data, &dst); err != nil { + return fmt.Errorf("oidc locales: %w", err) + } + + // We catch the posibility of a space seperated string here, + // because UnmarshalText might have been implicetely called + // by the json library before we added UnmarshalJSON. + switch v := dst.(type) { + case nil: + *l = nil + case string: + *l = ParseLocales(strings.Split(v, " ")) + case []any: + locales, err := gu.AssertInterfaces[string](v) + if err != nil { + return fmt.Errorf("oidc locales: %w", err) + } + *l = ParseLocales(locales) + default: + return fmt.Errorf("oidc locales: unsupported type: %T", v) + } return nil } @@ -106,7 +189,7 @@ type ResponseType string type ResponseMode string -func (s SpaceDelimitedArray) Encode() string { +func (s SpaceDelimitedArray) String() string { return strings.Join(s, " ") } @@ -116,11 +199,11 @@ func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error { } func (s SpaceDelimitedArray) MarshalText() ([]byte, error) { - return []byte(s.Encode()), nil + return []byte(s.String()), nil } func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) { - return json.Marshal((s).Encode()) + return json.Marshal((s).String()) } func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error { @@ -132,7 +215,7 @@ func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error { return nil } -func (s *SpaceDelimitedArray) Scan(src interface{}) error { +func (s *SpaceDelimitedArray) Scan(src any) error { if src == nil { *s = nil return nil @@ -165,7 +248,10 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) { func NewEncoder() *schema.Encoder { e := schema.NewEncoder() e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string { - return value.Interface().(SpaceDelimitedArray).Encode() + return value.Interface().(SpaceDelimitedArray).String() + }) + e.RegisterEncoder(Locales{}, func(value reflect.Value) string { + return value.Interface().(Locales).String() }) return e } diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index 64f07f1..53a9779 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -9,9 +9,9 @@ import ( "testing" "time" - "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/schema" "golang.org/x/text/language" ) @@ -208,20 +208,78 @@ func TestLocale_MarshalJSON(t *testing.T) { } func TestLocale_UnmarshalJSON(t *testing.T) { - type a struct { + type dst struct { Locale *Locale `json:"locale,omitempty"` } - want := a{ - Locale: NewLocale(language.Afrikaans), + tests := []struct { + name string + input string + want dst + wantErr bool + }{ + { + name: "value not present", + input: `{}`, + wantErr: false, + want: dst{ + Locale: nil, + }, + }, + { + name: "null", + input: `{"locale": null}`, + wantErr: false, + want: dst{ + Locale: nil, + }, + }, + { + name: "empty, ignored", + input: `{"locale": ""}`, + wantErr: false, + want: dst{ + Locale: &Locale{}, + }, + }, + { + name: "afrikaans, ok", + input: `{"locale": "af"}`, + want: dst{ + Locale: NewLocale(language.Afrikaans), + }, + }, + { + name: "gb, ignored", + input: `{"locale": "gb"}`, + want: dst{ + Locale: &Locale{}, + }, + }, + { + name: "bad form, error", + input: `{"locale": "g!!!!!"}`, + wantErr: true, + }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got dst + err := json.Unmarshal([]byte(tt.input), &got) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} - const input = `{"locale": "af"}` - var got a - - require.NoError(t, - json.Unmarshal([]byte(input), &got), - ) - assert.Equal(t, want, got) +func TestParseLocales(t *testing.T) { + in := []string{language.Afrikaans.String(), language.Danish.String(), "foobar", language.Und.String()} + want := Locales{language.Afrikaans, language.Danish} + got := ParseLocales(in) + assert.ElementsMatch(t, want, got) } func TestLocales_UnmarshalText(t *testing.T) { @@ -281,6 +339,80 @@ func TestLocales_UnmarshalText(t *testing.T) { } } +func TestLocales_UnmarshalJSON(t *testing.T) { + in := []string{language.Afrikaans.String(), language.Danish.String(), "foobar", language.Und.String()} + spaceSepStr := strconv.Quote(strings.Join(in, " ")) + jsonArray, err := json.Marshal(in) + require.NoError(t, err) + + out := Locales{language.Afrikaans, language.Danish} + + type args struct { + data []byte + } + tests := []struct { + name string + args args + want Locales + wantErr bool + }{ + { + name: "invalid JSON", + args: args{ + data: []byte("~~~"), + }, + wantErr: true, + }, + { + name: "null", + args: args{ + data: []byte("null"), + }, + want: nil, + }, + { + name: "space seperated string", + args: args{ + data: []byte(spaceSepStr), + }, + want: out, + }, + { + name: "json string array", + args: args{ + data: jsonArray, + }, + want: out, + }, + { + name: "json invalid array", + args: args{ + data: []byte(`[1,2,3]`), + }, + wantErr: true, + }, + { + name: "invalid type (float64)", + args: args{ + data: []byte("22"), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Locales + err := got.UnmarshalJSON([]byte(tt.args.data)) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + func TestScopes_UnmarshalText(t *testing.T) { type args struct { text []byte diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go index caff58e..ef8ebe4 100644 --- a/pkg/oidc/userinfo.go +++ b/pkg/oidc/userinfo.go @@ -29,6 +29,11 @@ func (u *UserInfo) GetAddress() *UserInfoAddress { return u.Address } +// GetSubject implements [rp.SubjectGetter] +func (u *UserInfo) GetSubject() string { + return u.Subject +} + type uiAlias UserInfo func (u *UserInfo) MarshalJSON() ([]byte, error) { diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go index c4ee95e..d5e0213 100644 --- a/pkg/oidc/verifier.go +++ b/pkg/oidc/verifier.go @@ -7,12 +7,11 @@ import ( "encoding/json" "errors" "fmt" + "slices" "strings" "time" - "gopkg.in/square/go-jose.v2" - - str "github.com/zitadel/oidc/v2/pkg/strings" + jose "github.com/go-jose/go-jose/v4" ) type Claims interface { @@ -41,6 +40,7 @@ type IDClaims interface { var ( ErrParse = errors.New("parsing of request failed") ErrIssuerInvalid = errors.New("issuer does not match") + ErrDiscoveryFailed = errors.New("OpenID Provider Configuration Discovery has failed") ErrSubjectMissing = errors.New("subject missing") ErrAudience = errors.New("audience is not valid") ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty") @@ -57,14 +57,23 @@ var ( ErrNonceInvalid = errors.New("nonce does not match") ErrAcrInvalid = errors.New("acr is invalid") ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing") - ErrAuthTimeToOld = errors.New("auth time of token is to old") + ErrAuthTimeToOld = errors.New("auth time of token is too old") ErrAtHash = errors.New("at_hash does not correspond to access token") ) -type Verifier interface { - Issuer() string - MaxAgeIAT() time.Duration - Offset() time.Duration +// Verifier caries configuration for the various token verification +// functions. Use package specific constructor functions to know +// which values need to be set. +type Verifier struct { + Issuer string + MaxAgeIAT time.Duration + Offset time.Duration + ClientID string + SupportedSignAlgs []string + MaxAge time.Duration + ACR ACRVerifier + KeySet KeySet + Nonce func(ctx context.Context) string } // ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim @@ -74,7 +83,7 @@ type ACRVerifier func(string) error // if none of the provided values matches the acr claim func DefaultACRVerifier(possibleValues []string) ACRVerifier { return func(acr string) error { - if !str.Contains(possibleValues, acr) { + if !slices.Contains(possibleValues, acr) { return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr) } return nil @@ -85,7 +94,7 @@ func DecryptToken(tokenString string) (string, error) { return tokenString, nil // TODO: impl } -func ParseToken(tokenString string, claims interface{}) ([]byte, error) { +func ParseToken(tokenString string, claims any) ([]byte, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { return nil, fmt.Errorf("%w: token contains an invalid number of segments", ErrParse) @@ -113,7 +122,7 @@ func CheckIssuer(claims Claims, issuer string) error { } func CheckAudience(claims Claims, clientID string) error { - if !str.Contains(claims.GetAudience(), clientID) { + if !slices.Contains(claims.GetAudience(), clientID) { return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID) } @@ -121,6 +130,11 @@ func CheckAudience(claims Claims, clientID string) error { return nil } +// CheckAuthorizedParty checks azp (authorized party) claim requirements. +// +// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present. +// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value. +// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation func CheckAuthorizedParty(claims Claims, clientID string) error { if len(claims.GetAudience()) > 1 { if claims.GetAuthorizedParty() == "" { @@ -134,8 +148,13 @@ func CheckAuthorizedParty(claims Claims, clientID string) error { } func CheckSignature(ctx context.Context, token string, payload []byte, claims ClaimsSignature, supportedSigAlgs []string, set KeySet) error { - jws, err := jose.ParseSigned(token) + jws, err := jose.ParseSigned(token, toJoseSignatureAlgorithms(supportedSigAlgs)) if err != nil { + if strings.HasPrefix(err.Error(), "go-jose/go-jose: unexpected signature algorithm") { + // TODO(v4): we should wrap errors instead of returning static ones. + // This is a workaround so we keep returning the same error for now. + return ErrSignatureUnsupportedAlg + } return ErrParse } if len(jws.Signatures) == 0 { @@ -145,12 +164,6 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl return ErrSignatureMultiple } sig := jws.Signatures[0] - if len(supportedSigAlgs) == 0 { - supportedSigAlgs = []string{"RS256"} - } - if !str.Contains(supportedSigAlgs, sig.Header.Algorithm) { - return fmt.Errorf("%w: id token signed with unsupported algorithm, expected %q got %q", ErrSignatureUnsupportedAlg, supportedSigAlgs, sig.Header.Algorithm) - } signedPayload, err := set.VerifySignature(ctx, jws) if err != nil { @@ -166,27 +179,39 @@ func CheckSignature(ctx context.Context, token string, payload []byte, claims Cl return nil } +// TODO(v4): Use the new jose.SignatureAlgorithm type directly, instead of string. +func toJoseSignatureAlgorithms(algorithms []string) []jose.SignatureAlgorithm { + out := make([]jose.SignatureAlgorithm, len(algorithms)) + for i := range algorithms { + out[i] = jose.SignatureAlgorithm(algorithms[i]) + } + if len(out) == 0 { + out = append(out, jose.RS256, jose.ES256, jose.PS256) + } + return out +} + func CheckExpiration(claims Claims, offset time.Duration) error { - expiration := claims.GetExpiration().Round(time.Second) - if !time.Now().UTC().Add(offset).Before(expiration) { + expiration := claims.GetExpiration() + if !time.Now().Add(offset).Before(expiration) { return ErrExpired } return nil } func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error { - issuedAt := claims.GetIssuedAt().Round(time.Second) + issuedAt := claims.GetIssuedAt() if issuedAt.IsZero() { return ErrIatMissing } - nowWithOffset := time.Now().UTC().Add(offset).Round(time.Second) + nowWithOffset := time.Now().Add(offset).Round(time.Second) if issuedAt.After(nowWithOffset) { return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset) } if maxAgeIAT == 0 { return nil } - maxAge := time.Now().UTC().Add(-maxAgeIAT).Round(time.Second) + maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second) if issuedAt.Before(maxAge) { return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt)) } @@ -216,8 +241,8 @@ func CheckAuthTime(claims Claims, maxAge time.Duration) error { if claims.GetAuthTime().IsZero() { return ErrAuthTimeNotPresent } - authTime := claims.GetAuthTime().Round(time.Second) - maxAuthTime := time.Now().UTC().Add(-maxAge).Round(time.Second) + authTime := claims.GetAuthTime() + maxAuthTime := time.Now().Add(-maxAge).Round(time.Second) if authTime.Before(maxAuthTime) { return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime)) } diff --git a/pkg/oidc/verifier_parse_test.go b/pkg/oidc/verifier_parse_test.go new file mode 100644 index 0000000..9cf5c1e --- /dev/null +++ b/pkg/oidc/verifier_parse_test.go @@ -0,0 +1,128 @@ +package oidc_test + +import ( + "context" + "encoding/json" + "testing" + + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseToken(t *testing.T) { + token, wantClaims := tu.ValidIDToken() + wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload + + wantPayload, err := json.Marshal(wantClaims) + require.NoError(t, err) + + tests := []struct { + name string + tokenString string + wantErr bool + }{ + { + name: "split error", + tokenString: "nope", + wantErr: true, + }, + { + name: "base64 error", + tokenString: "foo.~.bar", + wantErr: true, + }, + { + name: "success", + tokenString: token, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotClaims := new(oidc.IDTokenClaims) + gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, wantClaims, gotClaims) + assert.JSONEq(t, string(wantPayload), string(gotPayload)) + }) + } +} + +func TestCheckSignature(t *testing.T) { + errCtx, cancel := context.WithCancel(context.Background()) + cancel() + + token, _ := tu.ValidIDToken() + payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{}) + require.NoError(t, err) + + type args struct { + ctx context.Context + token string + payload []byte + supportedSigAlgs []string + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "parse error", + args: args{ + ctx: context.Background(), + token: "~", + payload: payload, + }, + wantErr: oidc.ErrParse, + }, + { + name: "default sigAlg", + args: args{ + ctx: context.Background(), + token: token, + payload: payload, + }, + }, + { + name: "unsupported sigAlg", + args: args{ + ctx: context.Background(), + token: token, + payload: payload, + supportedSigAlgs: []string{"foo", "bar"}, + }, + wantErr: oidc.ErrSignatureUnsupportedAlg, + }, + { + name: "verify error", + args: args{ + ctx: errCtx, + token: token, + payload: payload, + }, + wantErr: oidc.ErrSignatureInvalid, + }, + { + name: "inequal payloads", + args: args{ + ctx: context.Background(), + token: token, + payload: []byte{0, 1, 2}, + }, + wantErr: oidc.ErrSignatureInvalidPayload, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := new(oidc.TokenClaims) + err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{}) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} diff --git a/pkg/oidc/verifier_test.go b/pkg/oidc/verifier_test.go new file mode 100644 index 0000000..93e7157 --- /dev/null +++ b/pkg/oidc/verifier_test.go @@ -0,0 +1,374 @@ +package oidc + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecryptToken(t *testing.T) { + const tokenString = "ABC" + got, err := DecryptToken(tokenString) + require.NoError(t, err) + assert.Equal(t, tokenString, got) +} + +func TestDefaultACRVerifier(t *testing.T) { + acrVerfier := DefaultACRVerifier([]string{"foo", "bar"}) + + tests := []struct { + name string + acr string + wantErr string + }{ + { + name: "ok", + acr: "bar", + }, + { + name: "error", + acr: "hello", + wantErr: "expected one of: [foo bar], got: \"hello\"", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := acrVerfier(tt.acr) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + return + } + require.NoError(t, err) + }) + } +} + +func TestCheckSubject(t *testing.T) { + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrSubjectMissing, + }, + { + name: "ok", + claims: &TokenClaims{ + Subject: "foo", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckSubject(tt.claims) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckIssuer(t *testing.T) { + const issuer = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrIssuerInvalid, + }, + { + name: "wrong", + claims: &TokenClaims{ + Issuer: "wrong", + }, + wantErr: ErrIssuerInvalid, + }, + { + name: "ok", + claims: &TokenClaims{ + Issuer: issuer, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckIssuer(tt.claims, issuer) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAudience(t *testing.T) { + const clientID = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrAudience, + }, + { + name: "wrong", + claims: &TokenClaims{ + Audience: []string{"wrong"}, + }, + wantErr: ErrAudience, + }, + { + name: "ok", + claims: &TokenClaims{ + Audience: []string{clientID}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAudience(tt.claims, clientID) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthorizedParty(t *testing.T) { + const clientID = "foo.bar" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "single audience, no azp", + claims: &TokenClaims{ + Audience: []string{clientID}, + }, + }, + { + name: "multiple audience, no azp", + claims: &TokenClaims{ + Audience: []string{clientID, "other"}, + }, + wantErr: ErrAzpMissing, + }, + { + name: "single audience, with azp", + claims: &TokenClaims{ + Audience: []string{clientID}, + AuthorizedParty: clientID, + }, + }, + { + name: "multiple audience, with azp", + claims: &TokenClaims{ + Audience: []string{clientID, "other"}, + AuthorizedParty: clientID, + }, + }, + { + name: "wrong azp", + claims: &TokenClaims{ + AuthorizedParty: "wrong", + }, + wantErr: ErrAzpInvalid, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthorizedParty(tt.claims, clientID) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckExpiration(t *testing.T) { + const offset = time.Minute + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrExpired, + }, + { + name: "expired", + claims: &TokenClaims{ + Expiration: FromTime(time.Now().Add(-2 * offset)), + }, + wantErr: ErrExpired, + }, + { + name: "valid", + claims: &TokenClaims{ + Expiration: FromTime(time.Now().Add(2 * offset)), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckExpiration(tt.claims, offset) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckIssuedAt(t *testing.T) { + const offset = time.Minute + tests := []struct { + name string + maxAgeIAT time.Duration + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrIatMissing, + }, + { + name: "future", + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now().Add(time.Hour)), + }, + wantErr: ErrIatInFuture, + }, + { + name: "no max", + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now()), + }, + }, + { + name: "past max", + maxAgeIAT: time.Minute, + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now().Add(-time.Hour)), + }, + wantErr: ErrIatToOld, + }, + { + name: "within max", + maxAgeIAT: time.Hour, + claims: &TokenClaims{ + IssuedAt: FromTime(time.Now()), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckNonce(t *testing.T) { + const nonce = "123" + tests := []struct { + name string + claims Claims + wantErr error + }{ + { + name: "missing", + claims: &TokenClaims{}, + wantErr: ErrNonceInvalid, + }, + { + name: "wrong", + claims: &TokenClaims{ + Nonce: "wrong", + }, + wantErr: ErrNonceInvalid, + }, + { + name: "ok", + claims: &TokenClaims{ + Nonce: nonce, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckNonce(tt.claims, nonce) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthorizationContextClassReference(t *testing.T) { + tests := []struct { + name string + acr ACRVerifier + wantErr error + }{ + { + name: "error", + acr: func(s string) error { return errors.New("oops") }, + wantErr: ErrAcrInvalid, + }, + { + name: "ok", + acr: func(s string) error { return nil }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckAuthTime(t *testing.T) { + tests := []struct { + name string + claims Claims + maxAge time.Duration + wantErr error + }{ + { + name: "no max age", + claims: &TokenClaims{}, + }, + { + name: "missing", + claims: &TokenClaims{}, + maxAge: time.Minute, + wantErr: ErrAuthTimeNotPresent, + }, + { + name: "expired", + maxAge: time.Minute, + claims: &TokenClaims{ + AuthTime: FromTime(time.Now().Add(-time.Hour)), + }, + wantErr: ErrAuthTimeToOld, + }, + { + name: "ok", + maxAge: time.Minute, + claims: &TokenClaims{ + AuthTime: NowTime(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckAuthTime(tt.claims, tt.maxAge) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go index 1f9fc45..b1434cc 100644 --- a/pkg/op/auth_request.go +++ b/pkg/op/auth_request.go @@ -1,20 +1,23 @@ package op import ( + "bytes" "context" + _ "embed" + "errors" "fmt" + "html/template" + "log/slog" "net" "net/http" "net/url" - "path" + "slices" "strings" "time" - "github.com/gorilla/mux" - - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - str "github.com/zitadel/oidc/v2/pkg/strings" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/bmatcuk/doublestar/v4" ) type AuthRequest interface { @@ -35,20 +38,34 @@ type AuthRequest interface { Done() bool } +// AuthRequestSessionState should be implemented if [OpenID Connect Session Management](https://openid.net/specs/openid-connect-session-1_0.html) is supported +type AuthRequestSessionState interface { + // GetSessionState returns session_state. + // session_state is related to OpenID Connect Session Management. + GetSessionState() string +} + type Authorizer interface { Storage() Storage Decoder() httphelper.Decoder Encoder() httphelper.Encoder - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier Crypto() Crypto RequestObjectSupported() bool + Logger() *slog.Logger } // AuthorizeValidator is an extension of Authorizer interface // implementing its own validation mechanism for the auth request type AuthorizeValidator interface { Authorizer - ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, IDTokenHintVerifier) (string, error) + ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error) +} + +type CodeResponseType struct { + Code string `schema:"code"` + State string `schema:"state,omitempty"` + SessionState string `schema:"session_state,omitempty"` } func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) { @@ -57,7 +74,7 @@ func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Req } } -func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) { +func AuthorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { AuthorizeCallback(w, r, authorizer) } @@ -66,48 +83,54 @@ func authorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, * // Authorize handles the authorization request, including // parsing, validating, storing and finally redirecting to the login handler func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { + ctx, span := tracer.Start(r.Context(), "Authorize") + r = r.WithContext(ctx) + defer span.End() + authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } - ctx := r.Context() if authReq.RequestParam != "" && authorizer.RequestObjectSupported() { - authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) + err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } } if authReq.ClientID == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder()) + AuthRequestError(w, r, nil, fmt.Errorf("auth request is missing client_id"), authorizer) return } if authReq.RedirectURI == "" { - AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder()) + AuthRequestError(w, r, nil, fmt.Errorf("auth request is missing redirect_uri"), authorizer) return } - validation := ValidateAuthRequest - if validater, ok := authorizer.(AuthorizeValidator); ok { - validation = validater.ValidateAuthRequest + + var client Client + validation := func(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) { + client, err = authorizer.Storage().GetClientByClientID(ctx, authReq.ClientID) + if err != nil { + return "", oidc.ErrInvalidRequestRedirectURI().WithDescription("unable to retrieve client by id").WithParent(err) + } + return ValidateAuthRequestClient(ctx, authReq, client, verifier) + } + if validator, ok := authorizer.(AuthorizeValidator); ok { + validation = validator.ValidateAuthRequest } userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx)) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.RequestParam != "" { - AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer) return } req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID) if err != nil { - AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder()) - return - } - client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID()) - if err != nil { - AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder()) + AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer) return } RedirectToLogin(req.GetID(), client, w, r) @@ -129,37 +152,37 @@ func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.A // ParseRequestObject parse the `request` parameter, validates the token including the signature // and copies the token claims into the auth request -func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) (*oidc.AuthRequest, error) { +func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error { requestObject := new(oidc.RequestObject) payload, err := oidc.ParseToken(authReq.RequestParam, requestObject) if err != nil { - return nil, err + return err } if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("missing or wrong client id in request") } if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("missing or wrong response type in request") } if requestObject.Issuer != requestObject.ClientID { - return authReq, oidc.ErrInvalidRequest() + return oidc.ErrInvalidRequest().WithDescription("missing or wrong issuer in request") } - if !str.Contains(requestObject.Audience, issuer) { - return authReq, oidc.ErrInvalidRequest() + if !slices.Contains(requestObject.Audience, issuer) { + return oidc.ErrInvalidRequest().WithDescription("issuer missing in audience") } keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer} if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil { - return authReq, err + return oidc.ErrInvalidRequest().WithParent(err).WithDescription(err.Error()) } CopyRequestObjectToAuthRequest(authReq, requestObject) - return authReq, nil + return nil } // CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request // and clears the `RequestParam` of the auth request func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oidc.RequestObject) { - if str.Contains(authReq.Scopes, oidc.ScopeOpenID) && len(requestObject.Scopes) > 0 { + if slices.Contains(authReq.Scopes, oidc.ScopeOpenID) && len(requestObject.Scopes) > 0 { authReq.Scopes = requestObject.Scopes } if requestObject.RedirectURI != "" { @@ -204,23 +227,37 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi authReq.RequestParam = "" } -// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed -func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) { +// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed. +// +// Deprecated: Use [ValidateAuthRequestClient] to prevent querying for the Client twice. +func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) { + ctx, span := tracer.Start(ctx, "ValidateAuthRequest") + defer span.End() + + client, err := storage.GetClientByClientID(ctx, authReq.ClientID) + if err != nil { + return "", oidc.ErrInvalidRequestRedirectURI().WithDescription("unable to retrieve client by id").WithParent(err) + } + return ValidateAuthRequestClient(ctx, authReq, client, verifier) +} + +// ValidateAuthRequestClient validates the Auth request against the passed client. +// If id_token_hint is part of the request, the subject of the token is returned. +func ValidateAuthRequestClient(ctx context.Context, authReq *oidc.AuthRequest, client Client, verifier *IDTokenHintVerifier) (sub string, err error) { + ctx, span := tracer.Start(ctx, "ValidateAuthRequestClient") + defer span.End() + + if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil { + return "", err + } authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) if err != nil { return "", err } - client, err := storage.GetClientByClientID(ctx, authReq.ClientID) - if err != nil { - return "", oidc.DefaultToServerError(err, "unable to retrieve client by id") - } authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes) if err != nil { return "", err } - if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil { - return "", err - } if err := ValidateAuthReqResponseType(client, authReq.ResponseType); err != nil { return "", err } @@ -240,49 +277,35 @@ func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) return maxAge, nil } -// ValidateAuthReqScopes validates the passed scopes +// ValidateAuthReqScopes validates the passed scopes and deletes any unsupported scopes. +// An error is returned if scopes is empty. func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) { if len(scopes) == 0 { return nil, oidc.ErrInvalidRequest(). WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " + "If you have any questions, you may contact the administrator of the application.") } - openID := false - for i := len(scopes) - 1; i >= 0; i-- { - scope := scopes[i] - if scope == oidc.ScopeOpenID { - openID = true - continue - } - if !(scope == oidc.ScopeProfile || + scopes = slices.DeleteFunc(scopes, func(scope string) bool { + return !(scope == oidc.ScopeOpenID || + scope == oidc.ScopeProfile || scope == oidc.ScopeEmail || scope == oidc.ScopePhone || scope == oidc.ScopeAddress || scope == oidc.ScopeOfflineAccess) && - !client.IsScopeAllowed(scope) { - scopes[i] = scopes[len(scopes)-1] - scopes[len(scopes)-1] = "" - scopes = scopes[:len(scopes)-1] - } - } - if !openID { - return nil, oidc.ErrInvalidScope().WithDescription("The scope openid is missing in your request. " + - "Please ensure the scope openid is added to the request. " + - "If you have any questions, you may contact the administrator of the application.") - } - + !client.IsScopeAllowed(scope) + }) return scopes, nil } // checkURIAgainstRedirects just checks aginst the valid redirect URIs and ignores // other factors. func checkURIAgainstRedirects(client Client, uri string) error { - if str.Contains(client.RedirectURIs(), uri) { + if slices.Contains(client.RedirectURIs(), uri) { return nil } if globClient, ok := client.(HasRedirectGlobs); ok { for _, uriGlob := range globClient.RedirectURIGlobs() { - isMatch, err := path.Match(uriGlob, uri) + isMatch, err := doublestar.Match(uriGlob, uri) if err != nil { return oidc.ErrServerError().WithParent(err) } @@ -302,12 +325,12 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res return oidc.ErrInvalidRequestRedirectURI().WithDescription("The redirect_uri is missing in the request. " + "Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.") } + if client.ApplicationType() == ApplicationTypeNative { + return validateAuthReqRedirectURINative(client, uri) + } if strings.HasPrefix(uri, "https://") { return checkURIAgainstRedirects(client, uri) } - if client.ApplicationType() == ApplicationTypeNative { - return validateAuthReqRedirectURINative(client, uri, responseType) - } if err := checkURIAgainstRedirects(client, uri); err != nil { return err } @@ -326,14 +349,17 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res } // ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type -func validateAuthReqRedirectURINative(client Client, uri string, responseType oidc.ResponseType) error { +func validateAuthReqRedirectURINative(client Client, uri string) error { parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri) - isCustomSchema := !strings.HasPrefix(uri, "http://") + isCustomSchema := !(strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://")) if err := checkURIAgainstRedirects(client, uri); err == nil { if client.DevMode() { return nil } - // The RedirectURIs are only valid for native clients when localhost or non-"http://" + if !isLoopback && strings.HasPrefix(uri, "https://") { + return nil + } + // The RedirectURIs are only valid for native clients when localhost or non-"http://" and "https://" if isLoopback || isCustomSchema { return nil } @@ -358,16 +384,16 @@ func equalURI(url1, url2 *url.URL) bool { return url1.Path == url2.Path && url1.RawQuery == url2.RawQuery } -func HTTPLoopbackOrLocalhost(rawurl string) (*url.URL, bool) { - parsedURL, err := url.Parse(rawurl) +func HTTPLoopbackOrLocalhost(rawURL string) (*url.URL, bool) { + parsedURL, err := url.Parse(rawURL) if err != nil { return nil, false } - if parsedURL.Scheme != "http" { - return nil, false + if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" { + hostName := parsedURL.Hostname() + return parsedURL, hostName == "localhost" || net.ParseIP(hostName).IsLoopback() } - hostName := parsedURL.Hostname() - return parsedURL, hostName == "localhost" || net.ParseIP(hostName).IsLoopback() + return nil, false } // ValidateAuthReqResponseType validates the passed response_type to the registered response types @@ -385,14 +411,14 @@ func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) // ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request) // and returns the `sub` claim -func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier IDTokenHintVerifier) (string, error) { +func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) { if idTokenHint == "" { return "", nil } claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, idTokenHint, verifier) - if err != nil { + if err != nil && !errors.As(err, &IDTokenHintExpiredError{}) { return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " + - "If you have any questions, you may contact the administrator of the application.") + "If you have any questions, you may contact the administrator of the application.").WithParent(err) } return claims.GetSubject(), nil } @@ -405,32 +431,49 @@ func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r * // AuthorizeCallback handles the callback after authentication in the Login UI func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { - params := mux.Vars(r) - id := params["id"] - if id == "" { - AuthRequestError(w, r, nil, fmt.Errorf("auth request callback is missing id"), authorizer.Encoder()) + ctx, span := tracer.Start(r.Context(), "AuthorizeCallback") + r = r.WithContext(ctx) + defer span.End() + + id, err := ParseAuthorizeCallbackRequest(r) + if err != nil { + AuthRequestError(w, r, nil, err, authorizer) return } - authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) if err != nil { - AuthRequestError(w, r, nil, err, authorizer.Encoder()) + AuthRequestError(w, r, nil, err, authorizer) return } if !authReq.Done() { AuthRequestError(w, r, authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."), - authorizer.Encoder()) + authorizer) return } AuthResponse(authReq, authorizer, w, r) } +func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) { + if err = r.ParseForm(); err != nil { + return "", fmt.Errorf("cannot parse form: %w", err) + } + id = r.Form.Get("id") + if id == "" { + return "", errors.New("auth request callback is missing id") + } + return id, nil +} + // AuthResponse creates the successful authentication response (either code or tokens) func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { + ctx, span := tracer.Start(r.Context(), "AuthResponse") + r = r.WithContext(ctx) + defer span.End() + client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } if authReq.GetResponseType() == oidc.ResponseTypeCode { @@ -440,39 +483,98 @@ func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWri AuthResponseToken(w, r, authReq, authorizer, client) } -// AuthResponseCode creates the successful code authentication response +// AuthResponseCode handles the creation of a successful authentication response using an authorization code func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { - code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto()) + ctx, span := tracer.Start(r.Context(), "AuthResponseCode") + defer span.End() + r = r.WithContext(ctx) + + var err error + if authReq.GetResponseMode() == oidc.ResponseModeFormPost { + err = handleFormPostResponse(w, r, authReq, authorizer) + } else { + err = handleRedirectResponse(w, r, authReq, authorizer) + } + if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) - return + AuthRequestError(w, r, authReq, err, authorizer) } - codeResponse := struct { - code string - state string - }{ - code: code, - state: authReq.GetState(), - } - callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) +} + +// handleFormPostResponse processes the authentication response using form post method +func handleFormPostResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error { + codeResponse, err := BuildAuthResponseCodeResponsePayload(r.Context(), authReq, authorizer) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) - return + return err } - http.Redirect(w, r, callback, http.StatusFound) + return AuthResponseFormPost(w, authReq.GetRedirectURI(), codeResponse, authorizer.Encoder()) +} + +// handleRedirectResponse processes the authentication response using the redirect method +func handleRedirectResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error { + callbackURL, err := BuildAuthResponseCallbackURL(r.Context(), authReq, authorizer) + if err != nil { + return err + } + http.Redirect(w, r, callbackURL, http.StatusFound) + return nil +} + +// BuildAuthResponseCodeResponsePayload generates the authorization code response payload for the authentication request +func BuildAuthResponseCodeResponsePayload(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (*CodeResponseType, error) { + code, err := CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto()) + if err != nil { + return nil, err + } + + sessionState := "" + if authRequestSessionState, ok := authReq.(AuthRequestSessionState); ok { + sessionState = authRequestSessionState.GetSessionState() + } + + return &CodeResponseType{ + Code: code, + State: authReq.GetState(), + SessionState: sessionState, + }, nil +} + +// BuildAuthResponseCallbackURL generates the callback URL for a successful authorization code response +func BuildAuthResponseCallbackURL(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (string, error) { + codeResponse, err := BuildAuthResponseCodeResponsePayload(ctx, authReq, authorizer) + if err != nil { + return "", err + } + + return AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), codeResponse, authorizer.Encoder()) } // AuthResponseToken creates the successful token(s) authentication response func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) { + ctx, span := tracer.Start(r.Context(), "AuthResponseToken") + defer span.End() + r = r.WithContext(ctx) + createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "") if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } + + if authReq.GetResponseMode() == oidc.ResponseModeFormPost { + err := AuthResponseFormPost(w, authReq.GetRedirectURI(), resp, authorizer.Encoder()) + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer) + return + } + + return + } + callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder()) if err != nil { - AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + AuthRequestError(w, r, authReq, err, authorizer) return } http.Redirect(w, r, callback, http.StatusFound) @@ -480,6 +582,9 @@ func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthReque // CreateAuthRequestCode creates and stores a code for the auth code response func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) { + ctx, span := tracer.Start(ctx, "CreateAuthRequestCode") + defer span.End() + code, err := BuildAuthRequestCode(authReq, crypto) if err != nil { return "", err @@ -497,7 +602,7 @@ func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) { // AuthResponseURL encodes the authorization response (successful and error) and sets it as query or fragment values // depending on the response_mode and response_type -func AuthResponseURL(redirectURI string, responseType oidc.ResponseType, responseMode oidc.ResponseMode, response interface{}, encoder httphelper.Encoder) (string, error) { +func AuthResponseURL(redirectURI string, responseType oidc.ResponseType, responseMode oidc.ResponseMode, response any, encoder httphelper.Encoder) (string, error) { uri, err := url.Parse(redirectURI) if err != nil { return "", oidc.ErrServerError().WithParent(err) @@ -521,6 +626,43 @@ func AuthResponseURL(redirectURI string, responseType oidc.ResponseType, respons return mergeQueryParams(uri, params), nil } +//go:embed form_post.html.tmpl +var formPostHtmlTemplate string + +var formPostTmpl = template.Must(template.New("form_post").Parse(formPostHtmlTemplate)) + +// AuthResponseFormPost responds a html page that automatically submits the form which contains the auth response parameters +func AuthResponseFormPost(res http.ResponseWriter, redirectURI string, response any, encoder httphelper.Encoder) error { + values := make(map[string][]string) + err := encoder.Encode(response, values) + if err != nil { + return oidc.ErrServerError().WithParent(err) + } + + params := &struct { + RedirectURI string + Params any + }{ + RedirectURI: redirectURI, + Params: values, + } + + var buf bytes.Buffer + err = formPostTmpl.Execute(&buf, params) + if err != nil { + return oidc.ErrServerError().WithParent(err) + } + + res.Header().Set("Cache-Control", "no-store") + res.WriteHeader(http.StatusOK) + _, err = buf.WriteTo(res) + if err != nil { + return oidc.ErrServerError().WithParent(err) + } + + return nil +} + func setFragment(uri *url.URL, params url.Values) string { uri.Fragment = params.Encode() return uri.String() diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go index 7a9701b..d1ea965 100644 --- a/pkg/op/auth_request_test.go +++ b/pkg/op/auth_request_test.go @@ -3,76 +3,54 @@ package op_test import ( "context" "errors" + "io" + "log/slog" "net/http" "net/http/httptest" "net/url" "reflect" "testing" - "github.com/gorilla/schema" + "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage" + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op/mock" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + "github.com/zitadel/schema" ) -// -// TOOD: tests will be implemented in branch for service accounts -// func TestAuthorize(t *testing.T) { -// // testCallback := func(t *testing.T, clienID string) callbackHandler { -// // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { -// // // require.Equal(t, clientID, client.) -// // } -// // } -// // testErr := func(t *testing.T, expected error) errorHandler { -// // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { -// // require.Equal(t, expected, err) -// // } -// // } -// type args struct { -// w http.ResponseWriter -// r *http.Request -// authorizer op.Authorizer -// } -// tests := []struct { -// name string -// args args -// }{ -// { -// "parsing fails", -// args{ -// httptest.NewRecorder(), -// &http.Request{Method: "POST", Body: nil}, -// mock.NewAuthorizerExpectValid(t, true), -// // testCallback(t, ""), -// // testErr(t, ErrInvalidRequest("cannot parse form")), -// }, -// }, -// { -// "decoding fails", -// args{ -// httptest.NewRecorder(), -// func() *http.Request { -// r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo")) -// r.Header.Set("Content-Type", "application/x-www-form-urlencoded") -// return r -// }(), -// mock.NewAuthorizerExpectValid(t, true), -// // testCallback(t, ""), -// // testErr(t, ErrInvalidRequest("cannot parse auth request")), -// }, -// }, -// // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}}, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer) -// }) -// } -//} +func TestAuthorize(t *testing.T) { + tests := []struct { + name string + req *http.Request + expect func(a *mock.MockAuthorizerMockRecorder) + }{ + { + name: "parse error", // used to panic, see issue #315 + req: httptest.NewRequest(http.MethodPost, "/?;", nil), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + authorizer := mock.NewMockAuthorizer(gomock.NewController(t)) + + expect := authorizer.EXPECT() + expect.Decoder().Return(schema.NewDecoder()) + expect.Logger().Return(slog.Default()) + + if tt.expect != nil { + tt.expect(expect) + } + + op.Authorize(w, tt.req, authorizer) + }) + } +} func TestParseAuthorizeRequest(t *testing.T) { type args struct { @@ -147,7 +125,7 @@ func TestValidateAuthRequest(t *testing.T) { type args struct { authRequest *oidc.AuthRequest storage op.Storage - verifier op.IDTokenHintVerifier + verifier *op.IDTokenHintVerifier } tests := []struct { name string @@ -159,11 +137,6 @@ func TestValidateAuthRequest(t *testing.T) { args{&oidc.AuthRequest{}, mock.NewMockStorageExpectValidClientID(t), nil}, oidc.ErrInvalidRequest(), }, - { - "scope openid missing fails", - args{&oidc.AuthRequest{Scopes: []string{"profile"}}, mock.NewMockStorageExpectValidClientID(t), nil}, - oidc.ErrInvalidScope(), - }, { "response_type missing fails", args{&oidc.AuthRequest{Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil}, @@ -309,16 +282,6 @@ func TestValidateAuthReqScopes(t *testing.T) { err: true, }, }, - { - "scope openid missing fails", - args{ - mock.NewClientExpectAny(t, op.ApplicationTypeWeb), - []string{"email"}, - }, - res{ - err: true, - }, - }, { "scope ok", args{ @@ -470,6 +433,24 @@ func TestValidateAuthReqRedirectURI(t *testing.T) { }, false, }, + { + "code flow registered https loopback v4 native ok", + args{ + "https://127.0.0.1:4200/callback", + mock.NewClientWithConfig(t, []string{"https://127.0.0.1/callback"}, op.ApplicationTypeNative, nil, false), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow registered https loopback v6 native ok", + args{ + "https://[::1]:4200/callback", + mock.NewClientWithConfig(t, []string{"https://[::1]/callback"}, op.ApplicationTypeNative, nil, false), + oidc.ResponseTypeCode, + }, + false, + }, { "code flow unregistered http native fails", args{ @@ -605,6 +586,60 @@ func TestValidateAuthReqRedirectURI(t *testing.T) { }, false, }, + { + "code flow dev mode has redirect globs regular ok", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://registered.com/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs wildcard ok", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://registered.com/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs double star ok", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs double star ok", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs IPv6 ok", + args{ + "http://[::1]:80/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://\\[::1\\]:80/*"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + false, + }, + { + "code flow dev mode has redirect globs bad pattern", + args{ + "http://registered.com/callback", + mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/\\"}, op.ApplicationTypeUserAgent, nil, true), + oidc.ResponseTypeCode, + }, + true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -769,7 +804,7 @@ func TestAuthResponseURL(t *testing.T) { redirectURI string responseType oidc.ResponseType responseMode oidc.ResponseMode - response interface{} + response any encoder httphelper.Encoder } type res struct { @@ -787,7 +822,7 @@ func TestAuthResponseURL(t *testing.T) { "uri", oidc.ResponseTypeCode, "", - map[string]interface{}{"test": "test"}, + map[string]any{"test": "test"}, &mockEncoder{ errors.New("error encoding"), }, @@ -958,7 +993,7 @@ type mockEncoder struct { err error } -func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error { +func (m *mockEncoder) Encode(src any, dst map[string][]string) error { if m.err != nil { return m.err } @@ -967,3 +1002,611 @@ func (m *mockEncoder) Encode(src interface{}, dst map[string][]string) error { } return nil } + +// mockCrypto implements the op.Crypto interface +// and in always equals out. (It doesn't crypt anything). +// When returnErr != nil, that error is always returned instread. +type mockCrypto struct { + returnErr error +} + +func (c *mockCrypto) Encrypt(s string) (string, error) { + if c.returnErr != nil { + return "", c.returnErr + } + return s, nil +} + +func (c *mockCrypto) Decrypt(s string) (string, error) { + if c.returnErr != nil { + return "", c.returnErr + } + return s, nil +} + +func TestAuthResponseCode(t *testing.T) { + type args struct { + authReq op.AuthRequest + authorizer func(*testing.T) op.Authorizer + } + type res struct { + wantCode int + wantLocationHeader string + wantCacheControlHeader string + wantBody string + } + tests := []struct { + name string + args args + res res + }{ + { + name: "create code error", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{ + returnErr: io.ErrClosedPipe, + }) + authorizer.EXPECT().Logger().Return(slog.Default()) + return authorizer + }, + }, + res: res{ + wantCode: http.StatusBadRequest, + wantBody: "io: read/write on closed pipe\n", + }, + }, + { + name: "success with state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantCode: http.StatusFound, + wantLocationHeader: "/auth/callback/?code=id1&state=state1", + wantBody: "", + }, + }, + { + name: "success with state and session_state", + args: args{ + authReq: &storage.AuthRequestWithSessionState{ + AuthRequest: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + SessionState: "session_state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantCode: http.StatusFound, + wantLocationHeader: "/auth/callback/?code=id1&session_state=session_state1&state=state1", + wantBody: "", + }, + }, + { + name: "success without state", // reproduce issue #415 + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantCode: http.StatusFound, + wantLocationHeader: "/auth/callback/?code=id1", + wantBody: "", + }, + }, + { + name: "success form_post", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback", + TransferState: "state1", + ResponseMode: "form_post", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantCode: http.StatusOK, + wantCacheControlHeader: "no-store", + wantBody: "\n\n\n\n
\n\n\n\n\n\n\n
\n\n", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/auth/callback/", nil) + w := httptest.NewRecorder() + op.AuthResponseCode(w, r, tt.args.authReq, tt.args.authorizer(t)) + resp := w.Result() + defer resp.Body.Close() + assert.Equal(t, tt.res.wantCode, resp.StatusCode) + assert.Equal(t, tt.res.wantLocationHeader, resp.Header.Get("Location")) + assert.Equal(t, tt.res.wantCacheControlHeader, resp.Header.Get("Cache-Control")) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, tt.res.wantBody, string(body)) + }) + } +} + +func Test_parseAuthorizeCallbackRequest(t *testing.T) { + tests := []struct { + name string + url string + wantId string + wantErr bool + }{ + { + name: "parse error", + url: "/?id;=99", + wantErr: true, + }, + { + name: "missing id", + url: "/", + wantErr: true, + }, + { + name: "ok", + url: "/?id=99", + wantId: "99", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, tt.url, nil) + gotId, err := op.ParseAuthorizeCallbackRequest(r) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.wantId, gotId) + }) + } +} + +func TestBuildAuthResponseCodeResponsePayload(t *testing.T) { + type args struct { + authReq op.AuthRequest + authorizer func(*testing.T) op.Authorizer + } + type res struct { + wantCode string + wantState string + wantSessionState string + wantErr bool + } + tests := []struct { + name string + args args + res res + }{ + { + name: "create code error", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{ + returnErr: io.ErrClosedPipe, + }) + return authorizer + }, + }, + res: res{ + wantErr: true, + }, + }, + { + name: "success with state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + return authorizer + }, + }, + res: res{ + wantCode: "id1", + wantState: "state1", + }, + }, + { + name: "success without state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + TransferState: "", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + return authorizer + }, + }, + res: res{ + wantCode: "id1", + wantState: "", + }, + }, + { + name: "success with session_state", + args: args{ + authReq: &storage.AuthRequestWithSessionState{ + AuthRequest: &storage.AuthRequest{ + ID: "id1", + TransferState: "state1", + }, + SessionState: "session_state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + return authorizer + }, + }, + res: res{ + wantCode: "id1", + wantState: "state1", + wantSessionState: "session_state1", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.BuildAuthResponseCodeResponsePayload(context.Background(), tt.args.authReq, tt.args.authorizer(t)) + if tt.res.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.res.wantCode, got.Code) + assert.Equal(t, tt.res.wantState, got.State) + assert.Equal(t, tt.res.wantSessionState, got.SessionState) + }) + } +} + +func TestValidateAuthReqIDTokenHint(t *testing.T) { + token, _ := tu.ValidIDToken() + tests := []struct { + name string + idTokenHint string + want string + wantErr error + }{ + { + name: "empty", + }, + { + name: "verify err", + idTokenHint: "foo", + wantErr: oidc.ErrLoginRequired(), + }, + { + name: "ok", + idTokenHint: token, + want: tu.ValidSubject, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{})) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestBuildAuthResponseCallbackURL(t *testing.T) { + type args struct { + authReq op.AuthRequest + authorizer func(*testing.T) op.Authorizer + } + type res struct { + wantURL string + wantErr bool + } + tests := []struct { + name string + args args + res res + }{ + { + name: "error when generating code response", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{ + returnErr: io.ErrClosedPipe, + }) + return authorizer + }, + }, + res: res{ + wantErr: true, + }, + }, + { + name: "error when generating callback URL", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "://invalid-url", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantErr: true, + }, + }, + { + name: "success with state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback?code=id1&state=state1", + wantErr: false, + }, + }, + { + name: "success without state", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback?code=id1", + wantErr: false, + }, + }, + { + name: "success with session_state", + args: args{ + authReq: &storage.AuthRequestWithSessionState{ + AuthRequest: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback", + TransferState: "state1", + }, + SessionState: "session_state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback?code=id1&session_state=session_state1&state=state1", + wantErr: false, + }, + }, + { + name: "success with existing query parameters", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback?param=value", + TransferState: "state1", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback?param=value&code=id1&state=state1", + wantErr: false, + }, + }, + { + name: "success with fragment response mode", + args: args{ + authReq: &storage.AuthRequest{ + ID: "id1", + CallbackURI: "https://example.com/callback", + TransferState: "state1", + ResponseMode: "fragment", + }, + authorizer: func(t *testing.T) op.Authorizer { + ctrl := gomock.NewController(t) + storage := mock.NewMockStorage(ctrl) + storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1") + + authorizer := mock.NewMockAuthorizer(ctrl) + authorizer.EXPECT().Storage().Return(storage) + authorizer.EXPECT().Crypto().Return(&mockCrypto{}) + authorizer.EXPECT().Encoder().Return(schema.NewEncoder()) + return authorizer + }, + }, + res: res{ + wantURL: "https://example.com/callback#code=id1&state=state1", + wantErr: false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := op.BuildAuthResponseCallbackURL(context.Background(), tt.args.authReq, tt.args.authorizer(t)) + if tt.res.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + + if tt.res.wantURL != "" { + // Parse the URLs to compare components instead of direct string comparison + expectedURL, err := url.Parse(tt.res.wantURL) + require.NoError(t, err) + actualURL, err := url.Parse(got) + require.NoError(t, err) + + // Compare the base parts (scheme, host, path) + assert.Equal(t, expectedURL.Scheme, actualURL.Scheme) + assert.Equal(t, expectedURL.Host, actualURL.Host) + assert.Equal(t, expectedURL.Path, actualURL.Path) + + // Compare the fragment if any + assert.Equal(t, expectedURL.Fragment, actualURL.Fragment) + + // For query parameters, compare them independently of order + expectedQuery := expectedURL.Query() + actualQuery := actualURL.Query() + + assert.Equal(t, len(expectedQuery), len(actualQuery), "Query parameter count does not match") + + for key, expectedValues := range expectedQuery { + actualValues, exists := actualQuery[key] + assert.True(t, exists, "Expected query parameter %s not found", key) + assert.ElementsMatch(t, expectedValues, actualValues, "Values for parameter %s don't match", key) + } + } + }) + } +} diff --git a/pkg/op/client.go b/pkg/op/client.go index 9da44a7..a4f44d3 100644 --- a/pkg/op/client.go +++ b/pkg/op/client.go @@ -7,8 +7,8 @@ import ( "net/url" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) //go:generate go get github.com/dmarkham/enumer @@ -63,6 +63,7 @@ type Client interface { // such as DevMode for the client being enabled. // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest type HasRedirectGlobs interface { + Client RedirectURIGlobs() []string PostLogoutRedirectURIGlobs() []string } @@ -87,10 +88,13 @@ var ( ) type ClientJWTProfile interface { - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) { + ctx, span := tracer.Start(ctx, "ClientJWTAuth") + defer span.End() + if ca.ClientAssertion == "" { return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) } @@ -103,6 +107,10 @@ func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier } func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) { + ctx, span := tracer.Start(r.Context(), "ClientBasicAuth") + r = r.WithContext(ctx) + defer span.End() + clientID, clientSecret, ok := r.BasicAuth() if !ok { return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials) @@ -150,24 +158,44 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err) } + ctx, span := tracer.Start(r.Context(), "ClientIDFromRequest") + r = r.WithContext(ctx) + defer span.End() + data := new(clientData) - if err = p.Decoder().Decode(data, r.PostForm); err != nil { + if err = p.Decoder().Decode(data, r.Form); err != nil { return "", false, err } JWTProfile, ok := p.(ClientJWTProfile) - if ok { + if ok && data.ClientAssertion != "" { + // if JWTProfile is supported and client sent an assertion, check it and use it as response + // regardless if it succeeded or failed clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile) + return clientID, err == nil, err } - if !ok || errors.Is(err, ErrNoClientCredentials) { - clientID, err = ClientBasicAuth(r, p.Storage()) - } + // try basic auth + clientID, err = ClientBasicAuth(r, p.Storage()) + // if that succeeded, use it if err == nil { return clientID, true, nil } + // if the client did not send a Basic Auth Header, ignore the `ErrNoClientCredentials` + // but return other errors immediately + if !errors.Is(err, ErrNoClientCredentials) { + return "", false, err + } + // if the client did not authenticate (public clients) it must at least send a client_id if data.ClientID == "" { return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID) } return data.ClientID, false, nil } + +type ClientCredentials struct { + ClientID string `schema:"client_id"` + ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body + ClientAssertion string `schema:"client_assertion"` // JWT + ClientAssertionType string `schema:"client_assertion_type"` +} diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go index 1af4157..b416630 100644 --- a/pkg/op/client_test.go +++ b/pkg/op/client_test.go @@ -10,19 +10,19 @@ import ( "strings" "testing" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op/mock" "github.com/golang/mock/gomock" - "github.com/gorilla/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + "github.com/zitadel/schema" ) type testClientJWTProfile struct{} -func (testClientJWTProfile) JWTProfileVerifier(context.Context) op.JWTProfileVerifier { return nil } +func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil } func TestClientJWTAuth(t *testing.T) { type args struct { @@ -108,7 +108,7 @@ func TestClientBasicAuth(t *testing.T) { }, storage: func() op.Storage { s := mock.NewMockStorage(gomock.NewController(t)) - s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "wrong").Return(errWrong) + s.EXPECT().AuthorizeClientIDSecret(gomock.Any(), "foo", "wrong").Return(errWrong) return s }(), wantErr: errWrong, @@ -121,7 +121,7 @@ func TestClientBasicAuth(t *testing.T) { }, storage: func() op.Storage { s := mock.NewMockStorage(gomock.NewController(t)) - s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + s.EXPECT().AuthorizeClientIDSecret(gomock.Any(), "foo", "bar").Return(nil) return s }(), wantClientID: "foo", @@ -207,7 +207,7 @@ func TestClientIDFromRequest(t *testing.T) { p: testClientProvider{ storage: func() op.Storage { s := mock.NewMockStorage(gomock.NewController(t)) - s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil) + s.EXPECT().AuthorizeClientIDSecret(gomock.Any(), "foo", "bar").Return(nil) return s }(), }, diff --git a/pkg/op/config.go b/pkg/op/config.go index c40ed39..b271765 100644 --- a/pkg/op/config.go +++ b/pkg/op/config.go @@ -2,10 +2,12 @@ package op import ( "errors" + "log" "net/http" "net/url" "strings" + "github.com/muhlemmer/httpforwarded" "golang.org/x/text/language" ) @@ -20,14 +22,15 @@ var ( type Configuration interface { IssuerFromRequest(r *http.Request) string Insecure() bool - AuthorizationEndpoint() Endpoint - TokenEndpoint() Endpoint - IntrospectionEndpoint() Endpoint - UserinfoEndpoint() Endpoint - RevocationEndpoint() Endpoint - EndSessionEndpoint() Endpoint - KeysEndpoint() Endpoint - DeviceAuthorizationEndpoint() Endpoint + AuthorizationEndpoint() *Endpoint + TokenEndpoint() *Endpoint + IntrospectionEndpoint() *Endpoint + UserinfoEndpoint() *Endpoint + RevocationEndpoint() *Endpoint + EndSessionEndpoint() *Endpoint + KeysEndpoint() *Endpoint + DeviceAuthorizationEndpoint() *Endpoint + CheckSessionIframe() *Endpoint AuthMethodPostSupported() bool CodeMethodS256Supported() bool @@ -47,11 +50,53 @@ type Configuration interface { SupportedUILocales() []language.Tag DeviceAuthorization() DeviceAuthorizationConfig + + BackChannelLogoutSupported() bool + BackChannelLogoutSessionSupported() bool } type IssuerFromRequest func(r *http.Request) string func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) { + return issuerFromForwardedOrHost(path, new(issuerConfig)) +} + +type IssuerFromOption func(c *issuerConfig) + +// WithIssuerFromCustomHeaders can be used to customize the header names used. +// The same rules apply where the first successful host is returned. +func WithIssuerFromCustomHeaders(headers ...string) IssuerFromOption { + return func(c *issuerConfig) { + for i, h := range headers { + headers[i] = http.CanonicalHeaderKey(h) + } + c.headers = headers + } +} + +type issuerConfig struct { + headers []string +} + +// IssuerFromForwardedOrHost tries to establish the Issuer based +// on the Forwarded header host field. +// If multiple Forwarded headers are present, the first mention +// of the host field will be used. +// If the Forwarded header is not present, no host field is found, +// or there is a parser error the Request Host will be used as a fallback. +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Forwarded +func IssuerFromForwardedOrHost(path string, opts ...IssuerFromOption) func(bool) (IssuerFromRequest, error) { + c := &issuerConfig{ + headers: []string{http.CanonicalHeaderKey("forwarded")}, + } + for _, opt := range opts { + opt(c) + } + + return issuerFromForwardedOrHost(path, c) +} + +func issuerFromForwardedOrHost(path string, c *issuerConfig) func(bool) (IssuerFromRequest, error) { return func(allowInsecure bool) (IssuerFromRequest, error) { issuerPath, err := url.Parse(path) if err != nil { @@ -61,11 +106,28 @@ func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) { return nil, err } return func(r *http.Request) string { + if host, ok := hostFromForwarded(r, c.headers); ok { + return dynamicIssuer(host, path, allowInsecure) + } return dynamicIssuer(r.Host, path, allowInsecure) }, nil } } +func hostFromForwarded(r *http.Request, headers []string) (host string, ok bool) { + for _, header := range headers { + hosts, err := httpforwarded.ParseParameter("host", r.Header[header]) + if err != nil { + log.Printf("Err: issuer from forwarded header: %v", err) // TODO change to slog on next branch + continue + } + if len(hosts) > 0 { + return hosts[0], true + } + } + return "", false +} + func StaticIssuer(issuer string) func(bool) (IssuerFromRequest, error) { return func(allowInsecure bool) (IssuerFromRequest, error) { if err := ValidateIssuer(issuer, allowInsecure); err != nil { diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go index cfe4e61..d739348 100644 --- a/pkg/op/config_test.go +++ b/pkg/op/config_test.go @@ -1,11 +1,13 @@ package op import ( + "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestValidateIssuer(t *testing.T) { @@ -234,7 +236,7 @@ func TestIssuerFromHost(t *testing.T) { }, }, { - "custom path unsecure", + "custom path insecure", args{ path: "/custom/", allowInsecure: true, @@ -261,6 +263,132 @@ func TestIssuerFromHost(t *testing.T) { } } +func TestIssuerFromForwardedOrHost(t *testing.T) { + type args struct { + path string + opts []IssuerFromOption + target string + header map[string][]string + } + type res struct { + issuer string + } + tests := []struct { + name string + args args + res res + }{ + { + "header parse error", + args{ + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{"Forwarded": {"~~~~"}}, + }, + res{ + issuer: "https://issuer.com/custom/", + }, + }, + { + "no forwarded header", + args{ + path: "/custom/", + target: "https://issuer.com", + }, + res{ + issuer: "https://issuer.com/custom/", + }, + }, + // by=;for=;host=;proto= + { + "forwarded header without host", + args{ + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{"Forwarded": { + `by=identifier;for=identifier;proto=https`, + }}, + }, + res{ + issuer: "https://issuer.com/custom/", + }, + }, + { + "forwarded header with host", + args{ + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{"Forwarded": { + `by=identifier;for=identifier;host=first.com;proto=https`, + }}, + }, + res{ + issuer: "https://first.com/custom/", + }, + }, + { + "forwarded header with multiple hosts", + args{ + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{"Forwarded": { + `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`, + }}, + }, + res{ + issuer: "https://first.com/custom/", + }, + }, + { + "multiple forwarded headers hosts", + args{ + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{"Forwarded": { + `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`, + `by=identifier;for=identifier;host=third.com;proto=https`, + }}, + }, + res{ + issuer: "https://first.com/custom/", + }, + }, + { + "custom header first", + args{ + path: "/custom/", + target: "https://issuer.com", + header: map[string][]string{ + "Forwarded": { + `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`, + `by=identifier;for=identifier;host=third.com;proto=https`, + }, + "X-Custom-Forwarded": { + `by=identifier;for=identifier;host=custom.com;proto=https,host=custom2.com`, + }, + }, + opts: []IssuerFromOption{ + WithIssuerFromCustomHeaders("x-custom-forwarded"), + }, + }, + res{ + issuer: "https://custom.com/custom/", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + issuer, err := IssuerFromForwardedOrHost(tt.args.path, tt.args.opts...)(false) + require.NoError(t, err) + req := httptest.NewRequest("", tt.args.target, nil) + for k, v := range tt.args.header { + req.Header[http.CanonicalHeaderKey(k)] = v + } + assert.Equal(t, tt.res.issuer, issuer(req)) + }) + } +} + func TestStaticIssuer(t *testing.T) { type args struct { issuer string diff --git a/pkg/op/crypto.go b/pkg/op/crypto.go index 6786022..01aaad3 100644 --- a/pkg/op/crypto.go +++ b/pkg/op/crypto.go @@ -1,7 +1,7 @@ package op import ( - "github.com/zitadel/oidc/v2/pkg/crypto" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto" ) type Crypto interface { diff --git a/pkg/op/device.go b/pkg/op/device.go index 04c06f2..866cbc4 100644 --- a/pkg/op/device.go +++ b/pkg/op/device.go @@ -8,17 +8,26 @@ import ( "fmt" "math/big" "net/http" + "net/url" + "slices" "strings" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type DeviceAuthorizationConfig struct { Lifetime time.Duration PollInterval time.Duration - UserFormURL string // the URL where the user must go to authorize the device + + // UserFormURL is the complete URL where the user must go to authorize the device. + // Deprecated: use UserFormPath instead. + UserFormURL string + + // UserFormPath is the path where the user must go to authorize the device. + // The hostname for the URL is taken from the request by IssuerFromContext. + UserFormPath string UserCode UserCodeConfig } @@ -49,58 +58,94 @@ var ( func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { if err := DeviceAuthorization(w, r, o); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, o.Logger()) } } } func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error { - storage, err := assertDeviceStorage(o.Storage()) - if err != nil { - return err - } + ctx, span := tracer.Start(r.Context(), "DeviceAuthorization") + r = r.WithContext(ctx) + defer span.End() req, err := ParseDeviceCodeRequest(r, o) if err != nil { return err } - - config := o.DeviceAuthorization() - - deviceCode, err := NewDeviceCode(RecommendedDeviceCodeBytes) + response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o) if err != nil { return err } - userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval) - if err != nil { - return err - } - - expires := time.Now().Add(config.Lifetime) - err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes) - if err != nil { - return err - } - - response := &oidc.DeviceAuthorizationResponse{ - DeviceCode: deviceCode, - UserCode: userCode, - VerificationURI: config.UserFormURL, - ExpiresIn: int(config.Lifetime / time.Second), - Interval: int(config.PollInterval / time.Second), - } - - response.VerificationURIComplete = fmt.Sprintf("%s?user_code=%s", config.UserFormURL, userCode) httphelper.MarshalJSON(w, response) return nil } +func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) { + ctx, span := tracer.Start(ctx, "createDeviceAuthorization") + defer span.End() + + storage, err := assertDeviceStorage(o.Storage()) + if err != nil { + return nil, err + } + config := o.DeviceAuthorization() + + deviceCode, _ := NewDeviceCode(RecommendedDeviceCodeBytes) + userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval) + if err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + + expires := time.Now().Add(config.Lifetime) + err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes) + if err != nil { + return nil, NewStatusError(err, http.StatusInternalServerError) + } + + var verification *url.URL + if config.UserFormURL != "" { + if verification, err = url.Parse(config.UserFormURL); err != nil { + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form") + return nil, NewStatusError(err, http.StatusInternalServerError) + } + } else { + if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil { + err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer") + return nil, NewStatusError(err, http.StatusInternalServerError) + } + verification.Path = config.UserFormPath + } + + response := &oidc.DeviceAuthorizationResponse{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationURI: verification.String(), + ExpiresIn: int(config.Lifetime / time.Second), + Interval: int(config.PollInterval / time.Second), + } + + verification.RawQuery = "user_code=" + userCode + response.VerificationURIComplete = verification.String() + return response, nil +} + func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) { + ctx, span := tracer.Start(r.Context(), "ParseDeviceCodeRequest") + r = r.WithContext(ctx) + defer span.End() + clientID, _, err := ClientIDFromRequest(r, o) if err != nil { return nil, err } + client, err := o.Storage().GetClientByClientID(r.Context(), clientID) + if err != nil { + return nil, err + } + if !ValidateGrantType(client, oidc.GrantTypeDeviceCode) { + return nil, oidc.ErrUnauthorizedClient().WithDescription("client missing grant type " + string(oidc.GrantTypeCode)) + } req := new(oidc.DeviceAuthorizationRequest) if err := o.Decoder().Decode(req, r.Form); err != nil { @@ -115,11 +160,14 @@ func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuth // results in a 22 character base64 encoded string. const RecommendedDeviceCodeBytes = 16 +// NewDeviceCode generates a new cryptographically secure device code as a base64 encoded string. +// The length of the string is nBytes * 4 / 3. +// An error is never returned. +// +// TODO(v4): change return type to string alone. func NewDeviceCode(nBytes int) (string, error) { bytes := make([]byte, nBytes) - if _, err := rand.Read(bytes); err != nil { - return "", fmt.Errorf("%w getting entropy for device code", err) - } + rand.Read(bytes) return base64.RawURLEncoding.EncodeToString(bytes), nil } @@ -149,27 +197,13 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) { return buf.String(), nil } -type deviceAccessTokenRequest struct { - subject string - audience []string - scopes []string -} - -func (r *deviceAccessTokenRequest) GetSubject() string { - return r.subject -} - -func (r *deviceAccessTokenRequest) GetAudience() []string { - return r.audience -} - -func (r *deviceAccessTokenRequest) GetScopes() []string { - return r.scopes -} - func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + ctx, span := tracer.Start(r.Context(), "DeviceAccessToken") + defer span.End() + r = r.WithContext(ctx) + if err := deviceAccessToken(w, r, exchanger); err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } } @@ -189,7 +223,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang if err != nil { return err } - state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) + tokenRequest, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger) if err != nil { return err } @@ -203,11 +237,6 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang WithDescription("confidential client requires authentication") } - tokenRequest := &deviceAccessTokenRequest{ - subject: state.Subject, - audience: []string{clientID}, - scopes: state.Scopes, - } resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client) if err != nil { return err @@ -225,7 +254,54 @@ func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc. return req, nil } +// DeviceAuthorizationState describes the current state of +// the device authorization flow. +// It implements the [IDTokenRequest] interface. +type DeviceAuthorizationState struct { + ClientID string + Audience []string + Scopes []string + Expires time.Time // The time after we consider the authorization request timed-out + Done bool // The user authenticated and approved the authorization request + Denied bool // The user authenticated and denied the authorization request + + // The following fields are populated after Done == true + Subject string + AMR []string + AuthTime time.Time +} + +func (r *DeviceAuthorizationState) GetAMR() []string { + return r.AMR +} + +func (r *DeviceAuthorizationState) GetAudience() []string { + if !slices.Contains(r.Audience, r.ClientID) { + r.Audience = append(r.Audience, r.ClientID) + } + return r.Audience +} + +func (r *DeviceAuthorizationState) GetAuthTime() time.Time { + return r.AuthTime +} + +func (r *DeviceAuthorizationState) GetClientID() string { + return r.ClientID +} + +func (r *DeviceAuthorizationState) GetScopes() []string { + return r.Scopes +} + +func (r *DeviceAuthorizationState) GetSubject() string { + return r.Subject +} + func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) { + ctx, span := tracer.Start(ctx, "CheckDeviceAuthorizationState") + defer span.End() + storage, err := assertDeviceStorage(exchanger.Storage()) if err != nil { return nil, err @@ -250,16 +326,34 @@ func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode str return state, oidc.ErrAuthorizationPending() } -func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client AccessTokenClient) (*oidc.AccessTokenResponse, error) { - accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, AccessTokenTypeBearer, creator, client, "") +func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) { + /* TODO(v4): + Change the TokenRequest argument type to *DeviceAuthorizationState. + Breaking change that can not be done for v3. + */ + ctx, span := tracer.Start(ctx, "CreateDeviceTokenResponse") + defer span.End() + + accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "") if err != nil { return nil, err } - return &oidc.AccessTokenResponse{ + response := &oidc.AccessTokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: oidc.BearerToken, ExpiresIn: uint64(validity.Seconds()), - }, nil + Scope: tokenRequest.GetScopes(), + } + + // TODO(v4): remove type assertion + if idTokenRequest, ok := tokenRequest.(IDTokenRequest); ok && slices.Contains(tokenRequest.GetScopes(), oidc.ScopeOpenID) { + response.IDToken, err = CreateIDToken(ctx, IssuerFromContext(ctx), idTokenRequest, client.IDTokenLifetime(), accessToken, "", creator.Storage(), client) + if err != nil { + return nil, err + } + } + + return response, nil } diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go index 69ba102..a7b5c4e 100644 --- a/pkg/op/device_test.go +++ b/pkg/op/device_test.go @@ -13,36 +13,69 @@ import ( "testing" "time" + "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" ) func Test_deviceAuthorizationHandler(t *testing.T) { - req := &oidc.DeviceAuthorizationRequest{ - Scopes: []string{"foo", "bar"}, - ClientID: "web", + type conf struct { + UserFormURL string + UserFormPath string } - values := make(url.Values) - testProvider.Encoder().Encode(req, values) - body := strings.NewReader(values.Encode()) + tests := []struct { + name string + conf conf + }{ + { + name: "UserFormURL", + conf: conf{ + UserFormURL: "https://localhost:9998/device", + }, + }, + { + name: "UserFormPath", + conf: conf{ + UserFormPath: "/device", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf := gu.PtrCopy(testConfig) + conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL + conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath + provider := newTestProvider(conf) - r := httptest.NewRequest(http.MethodPost, "/", body) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req := &oidc.DeviceAuthorizationRequest{ + Scopes: []string{"foo", "bar"}, + ClientID: "device", + } + values := make(url.Values) + testProvider.Encoder().Encode(req, values) + body := strings.NewReader(values.Encode()) - w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/", body) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer)) - runWithRandReader(mr.New(mr.NewSource(1)), func() { - op.DeviceAuthorizationHandler(testProvider)(w, r) - }) + w := httptest.NewRecorder() - result := w.Result() + runWithRandReader(mr.New(mr.NewSource(1)), func() { + op.DeviceAuthorizationHandler(provider)(w, r) + }) - assert.Less(t, result.StatusCode, 300) + result := w.Result() - got, _ := io.ReadAll(result.Body) - assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got)) + assert.Less(t, result.StatusCode, 300) + + got, _ := io.ReadAll(result.Body) + assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got)) + }) + } } func TestParseDeviceCodeRequest(t *testing.T) { @@ -56,11 +89,27 @@ func TestParseDeviceCodeRequest(t *testing.T) { wantErr: true, }, { - name: "success", + name: "missing grant type", req: &oidc.DeviceAuthorizationRequest{ Scopes: oidc.SpaceDelimitedArray{"foo", "bar"}, ClientID: "web", }, + wantErr: true, + }, + { + name: "client not found", + req: &oidc.DeviceAuthorizationRequest{ + Scopes: oidc.SpaceDelimitedArray{"foo", "bar"}, + ClientID: "foobar", + }, + wantErr: true, + }, + { + name: "success", + req: &oidc.DeviceAuthorizationRequest{ + Scopes: oidc.SpaceDelimitedArray{"foo", "bar"}, + ClientID: "device", + }, }, } for _, tt := range tests { @@ -78,8 +127,7 @@ func TestParseDeviceCodeRequest(t *testing.T) { got, err := op.ParseDeviceCodeRequest(r, testProvider) if tt.wantErr { require.Error(t, err) - } else { - require.NoError(t, err) + return } assert.Equal(t, tt.req, got) }) @@ -97,21 +145,11 @@ func runWithRandReader(r io.Reader, f func()) { } func TestNewDeviceCode(t *testing.T) { - t.Run("reader error", func(t *testing.T) { - runWithRandReader(errReader{}, func() { - _, err := op.NewDeviceCode(16) - require.Error(t, err) - }) - }) - - t.Run("different lengths, rand reader", func(t *testing.T) { - for i := 1; i <= 32; i++ { - got, err := op.NewDeviceCode(i) - require.NoError(t, err) - assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i)) - } - }) - + for i := 1; i <= 32; i++ { + got, err := op.NewDeviceCode(i) + require.NoError(t, err) + assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i)) + } } func TestNewUserCode(t *testing.T) { @@ -272,7 +310,7 @@ func BenchmarkNewUserCode(b *testing.B) { } func TestDeviceAccessToken(t *testing.T) { - storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage := testProvider.Storage().(*storage.Storage) storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"}) storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim") @@ -297,7 +335,7 @@ func TestDeviceAccessToken(t *testing.T) { func TestCheckDeviceAuthorizationState(t *testing.T) { now := time.Now() - storage := testProvider.Storage().(op.DeviceAuthorizationStorage) + storage := testProvider.Storage().(*storage.Storage) storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"}) storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"}) @@ -405,3 +443,96 @@ func TestCheckDeviceAuthorizationState(t *testing.T) { }) } } + +func TestCreateDeviceTokenResponse(t *testing.T) { + tests := []struct { + name string + tokenRequest op.TokenRequest + wantAccessToken bool + wantRefreshToken bool + wantIDToken bool + wantErr bool + }{ + { + name: "access token", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "id1", + AMR: []string{"password"}, + AuthTime: time.Now(), + }, + wantAccessToken: true, + }, + { + name: "access and refresh tokens", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "id1", + AMR: []string{"password"}, + AuthTime: time.Now(), + Scopes: []string{oidc.ScopeOfflineAccess}, + }, + wantAccessToken: true, + wantRefreshToken: true, + }, + { + name: "access and id token", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "id1", + AMR: []string{"password"}, + AuthTime: time.Now(), + Scopes: []string{oidc.ScopeOpenID}, + }, + wantAccessToken: true, + wantIDToken: true, + }, + { + name: "access, refresh and id token", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "id1", + AMR: []string{"password"}, + AuthTime: time.Now(), + Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID}, + }, + wantAccessToken: true, + wantRefreshToken: true, + wantIDToken: true, + }, + { + name: "id token creation error", + tokenRequest: &op.DeviceAuthorizationState{ + ClientID: "client1", + Subject: "foobar", + AMR: []string{"password"}, + AuthTime: time.Now(), + Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := testProvider.Storage().GetClientByClientID(context.Background(), "native") + require.NoError(t, err) + + got, err := op.CreateDeviceTokenResponse(context.Background(), tt.tokenRequest, testProvider, client) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.InDelta(t, 300, got.ExpiresIn, 2) + if tt.wantAccessToken { + assert.NotEmpty(t, got.AccessToken, "access token") + } + if tt.wantRefreshToken { + assert.NotEmpty(t, got.RefreshToken, "refresh token") + } + if tt.wantIDToken { + assert.NotEmpty(t, got.IDToken, "id token") + } + }) + } +} diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go index 26f89eb..9b3ddb6 100644 --- a/pkg/op/discovery.go +++ b/pkg/op/discovery.go @@ -4,10 +4,10 @@ import ( "context" "net/http" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type DiscoverStorage interface { @@ -25,7 +25,7 @@ var DefaultSupportedScopes = []string{ func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - Discover(w, CreateDiscoveryConfig(r, c, s)) + Discover(w, CreateDiscoveryConfig(r.Context(), c, s)) } } @@ -33,8 +33,8 @@ func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { httphelper.MarshalJSON(w, config) } -func CreateDiscoveryConfig(r *http.Request, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { - issuer := config.IssuerFromRequest(r) +func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) return &oidc.DiscoveryConfiguration{ Issuer: issuer, AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer), @@ -45,11 +45,12 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer), JwksURI: config.KeysEndpoint().Absolute(issuer), DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer), + CheckSessionIframe: config.CheckSessionIframe().Absolute(issuer), ScopesSupported: Scopes(config), ResponseTypesSupported: ResponseTypes(config), GrantTypesSupported: GrantTypes(config), SubjectTypesSupported: SubjectTypes(config), - IDTokenSigningAlgValuesSupported: SigAlgorithms(r.Context(), storage), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), @@ -61,11 +62,50 @@ func CreateDiscoveryConfig(r *http.Request, config Configuration, storage Discov CodeChallengeMethodsSupported: CodeChallengeMethods(config), UILocalesSupported: config.SupportedUILocales(), RequestParameterSupported: config.RequestObjectSupported(), + BackChannelLogoutSupported: config.BackChannelLogoutSupported(), + BackChannelLogoutSessionSupported: config.BackChannelLogoutSessionSupported(), + } +} + +func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration { + issuer := IssuerFromContext(ctx) + return &oidc.DiscoveryConfiguration{ + Issuer: issuer, + AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer), + TokenEndpoint: endpoints.Token.Absolute(issuer), + IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer), + UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer), + RevocationEndpoint: endpoints.Revocation.Absolute(issuer), + EndSessionEndpoint: endpoints.EndSession.Absolute(issuer), + JwksURI: endpoints.JwksURI.Absolute(issuer), + DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer), + ScopesSupported: Scopes(config), + ResponseTypesSupported: ResponseTypes(config), + GrantTypesSupported: GrantTypes(config), + SubjectTypesSupported: SubjectTypes(config), + IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage), + RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config), + TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config), + TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config), + IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config), + IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config), + RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config), + RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config), + ClaimsSupported: SupportedClaims(config), + CodeChallengeMethodsSupported: CodeChallengeMethods(config), + UILocalesSupported: config.SupportedUILocales(), + RequestParameterSupported: config.RequestObjectSupported(), + BackChannelLogoutSupported: config.BackChannelLogoutSupported(), + BackChannelLogoutSessionSupported: config.BackChannelLogoutSessionSupported(), } } func Scopes(c Configuration) []string { - return DefaultSupportedScopes // TODO: config + provider, ok := c.(*Provider) + if ok && provider.config.SupportedScopes != nil { + return provider.config.SupportedScopes + } + return DefaultSupportedScopes } func ResponseTypes(c Configuration) []string { @@ -100,10 +140,13 @@ func GrantTypes(c Configuration) []oidc.GrantType { } func SubjectTypes(c Configuration) []string { - return []string{"public"} //TODO: config + return []string{"public"} // TODO: config } func SigAlgorithms(ctx context.Context, storage DiscoverStorage) []string { + ctx, span := tracer.Start(ctx, "SigAlgorithms") + defer span.End() + algorithms, err := storage.SignatureAlgorithms(ctx) if err != nil { return nil @@ -182,32 +225,12 @@ func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod { } func SupportedClaims(c Configuration) []string { - return []string{ // TODO: config - "sub", - "aud", - "exp", - "iat", - "iss", - "auth_time", - "nonce", - "acr", - "amr", - "c_hash", - "at_hash", - "act", - "scopes", - "client_id", - "azp", - "preferred_username", - "name", - "family_name", - "given_name", - "locale", - "email", - "email_verified", - "phone_number", - "phone_number_verified", + provider, ok := c.(*Provider) + if ok && provider.config.SupportedClaims != nil { + return provider.config.SupportedClaims } + + return DefaultSupportedClaims } func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod { diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go index 2d0b8af..63f1b98 100644 --- a/pkg/op/discovery_test.go +++ b/pkg/op/discovery_test.go @@ -6,14 +6,14 @@ import ( "net/http/httptest" "testing" + jose "github.com/go-jose/go-jose/v4" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op/mock" ) func TestDiscover(t *testing.T) { @@ -48,9 +48,9 @@ func TestDiscover(t *testing.T) { func TestCreateDiscoveryConfig(t *testing.T) { type args struct { - request *http.Request - c op.Configuration - s op.DiscoverStorage + ctx context.Context + c op.Configuration + s op.DiscoverStorage } tests := []struct { name string @@ -61,7 +61,7 @@ func TestCreateDiscoveryConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := op.CreateDiscoveryConfig(tt.args.request, tt.args.c, tt.args.s) + got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s) assert.Equal(t, tt.want, got) }) } @@ -81,6 +81,11 @@ func Test_scopes(t *testing.T) { args{}, op.DefaultSupportedScopes, }, + { + "custom scopes", + args{newTestProvider(&op.Config{SupportedScopes: []string{"test1", "test2"}})}, + []string{"test1", "test2"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/op/endpoint.go b/pkg/op/endpoint.go index b1e1507..1ac1cad 100644 --- a/pkg/op/endpoint.go +++ b/pkg/op/endpoint.go @@ -1,32 +1,46 @@ package op -import "strings" +import ( + "errors" + "strings" +) type Endpoint struct { path string url string } -func NewEndpoint(path string) Endpoint { - return Endpoint{path: path} +func NewEndpoint(path string) *Endpoint { + return &Endpoint{path: path} } -func NewEndpointWithURL(path, url string) Endpoint { - return Endpoint{path: path, url: url} +func NewEndpointWithURL(path, url string) *Endpoint { + return &Endpoint{path: path, url: url} } -func (e Endpoint) Relative() string { +func (e *Endpoint) Relative() string { + if e == nil { + return "" + } return relativeEndpoint(e.path) } -func (e Endpoint) Absolute(host string) string { +func (e *Endpoint) Absolute(host string) string { + if e == nil { + return "" + } if e.url != "" { return e.url } return absoluteEndpoint(host, e.path) } -func (e Endpoint) Validate() error { +var ErrNilEndpoint = errors.New("nil endpoint") + +func (e *Endpoint) Validate() error { + if e == nil { + return ErrNilEndpoint + } return nil // TODO: } diff --git a/pkg/op/endpoint_test.go b/pkg/op/endpoint_test.go index 50de89c..5b98c6e 100644 --- a/pkg/op/endpoint_test.go +++ b/pkg/op/endpoint_test.go @@ -3,13 +3,14 @@ package op_test import ( "testing" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "github.com/stretchr/testify/require" ) func TestEndpoint_Path(t *testing.T) { tests := []struct { name string - e op.Endpoint + e *op.Endpoint want string }{ { @@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) { op.NewEndpointWithURL("/test", "http://test.com/test"), "/test", }, + { + "nil", + nil, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) { } tests := []struct { name string - e op.Endpoint + e *op.Endpoint args args want string }{ @@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) { args{"https://host"}, "https://test.com/test", }, + { + "nil", + nil, + args{"https://host"}, + "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -91,16 +103,19 @@ func TestEndpoint_Absolute(t *testing.T) { func TestEndpoint_Validate(t *testing.T) { tests := []struct { name string - e op.Endpoint - wantErr bool + e *op.Endpoint + wantErr error }{ - // TODO: Add test cases. + { + "nil", + nil, + op.ErrNilEndpoint, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.e.Validate(); (err != nil) != tt.wantErr { - t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr) - } + err := tt.e.Validate() + require.ErrorIs(t, err, tt.wantErr) }) } } diff --git a/pkg/op/error.go b/pkg/op/error.go index acca4ab..272f85e 100644 --- a/pkg/op/error.go +++ b/pkg/op/error.go @@ -1,10 +1,14 @@ package op import ( + "context" + "errors" + "fmt" + "log/slog" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type ErrAuthRequest interface { @@ -13,34 +17,181 @@ type ErrAuthRequest interface { GetState() string } -func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder httphelper.Encoder) { +// LogAuthRequest is an optional interface, +// that allows logging AuthRequest fields. +// If the AuthRequest does not implement this interface, +// no details shall be printed to the logs. +type LogAuthRequest interface { + ErrAuthRequest + slog.LogValuer +} + +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) { + e := oidc.DefaultToServerError(err, err.Error()) + logger := authorizer.Logger().With("oidc_error", e) + if authReq == nil { + logger.Log(r.Context(), e.LogLevel(), "auth request") http.Error(w, err.Error(), http.StatusBadRequest) return } - e := oidc.DefaultToServerError(err, err.Error()) + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting") http.Error(w, e.Description, http.StatusBadRequest) return } e.State = authReq.GetState() + var sessionState string + authRequestSessionState, ok := authReq.(AuthRequestSessionState) + if ok { + sessionState = authRequestSessionState.GetSessionState() + } + e.SessionState = sessionState + var responseMode oidc.ResponseMode + if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { + responseMode = rm.GetResponseMode() + } + url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder()) + if err != nil { + logger.ErrorContext(r.Context(), "auth response URL", "error", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + logger.Log(r.Context(), e.LogLevel(), "auth request") + http.Redirect(w, r, url, http.StatusFound) +} + +func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { + e := oidc.DefaultToServerError(err, err.Error()) + status := http.StatusBadRequest + if e.ErrorType == oidc.InvalidClient { + status = http.StatusUnauthorized + } + logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e) + httphelper.MarshalJSONWithStatus(w, e, status) +} + +// TryErrorRedirect tries to handle an error by redirecting a client. +// If this attempt fails, an error is returned that must be returned +// to the client instead. +func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) { + e := oidc.DefaultToServerError(parent, parent.Error()) + logger = logger.With("oidc_error", e) + + if authReq == nil { + logger.Log(ctx, e.LogLevel(), "auth request") + return nil, AsStatusError(e, http.StatusBadRequest) + } + + if logAuthReq, ok := authReq.(LogAuthRequest); ok { + logger = logger.With("auth_request", logAuthReq) + } + + if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() { + logger.Log(ctx, e.LogLevel(), "auth request: not redirecting") + return nil, AsStatusError(e, http.StatusBadRequest) + } + + e.State = authReq.GetState() + var sessionState string + authRequestSessionState, ok := authReq.(AuthRequestSessionState) + if ok { + sessionState = authRequestSessionState.GetSessionState() + } + e.SessionState = sessionState var responseMode oidc.ResponseMode if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok { responseMode = rm.GetResponseMode() } url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + logger.ErrorContext(ctx, "auth response URL", "error", err) + return nil, AsStatusError(err, http.StatusBadRequest) } - http.Redirect(w, r, url, http.StatusFound) + logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url) + return NewRedirect(url), nil } -func RequestError(w http.ResponseWriter, r *http.Request, err error) { - e := oidc.DefaultToServerError(err, err.Error()) - status := http.StatusBadRequest - if e.ErrorType == oidc.InvalidClient { - status = 401 - } - httphelper.MarshalJSONWithStatus(w, e, status) +// StatusError wraps an error with a HTTP status code. +// The status code is passed to the handler's writer. +type StatusError struct { + parent error + statusCode int +} + +// NewStatusError sets the parent and statusCode to a new StatusError. +// It is recommended for parent to be an [oidc.Error]. +// +// Typically implementations should only use this to signal something +// very specific, like an internal server error. +// If a returned error is not a StatusError, the framework +// will set a statusCode based on what the standard specifies, +// which is [http.StatusBadRequest] for most of the time. +// If the error encountered can described clearly with a [oidc.Error], +// do not use this function, as it might break standard rules! +func NewStatusError(parent error, statusCode int) StatusError { + return StatusError{ + parent: parent, + statusCode: statusCode, + } +} + +// AsStatusError unwraps a StatusError from err +// and returns it unmodified if found. +// If no StatuError was found, a new one is returned +// with statusCode set to it as a default. +func AsStatusError(err error, statusCode int) (target StatusError) { + if errors.As(err, &target) { + return target + } + return NewStatusError(err, statusCode) +} + +func (e StatusError) Error() string { + return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error()) +} + +func (e StatusError) Unwrap() error { + return e.parent +} + +func (e StatusError) Is(err error) bool { + var target StatusError + if !errors.As(err, &target) { + return false + } + return errors.Is(e.parent, target.parent) && + e.statusCode == target.statusCode +} + +// WriteError asserts for a [StatusError] containing an [oidc.Error]. +// If no `StatusError` is found, the status code will default to [http.StatusBadRequest]. +// If no `oidc.Error` was found in the parent, the error type defaults to [oidc.ServerError]. +// When there was no `StatusError` and the `oidc.Error` is of type `oidc.ServerError`, +// the status code will be set to [http.StatusInternalServerError] +func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) { + var statusError StatusError + if errors.As(err, &statusError) { + writeError(w, r, + oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()), + statusError.statusCode, logger, + ) + return + } + statusCode := http.StatusBadRequest + e := oidc.DefaultToServerError(err, err.Error()) + if e.ErrorType == oidc.ServerError { + statusCode = http.StatusInternalServerError + } + writeError(w, r, e, statusCode, logger) +} + +func writeError(w http.ResponseWriter, r *http.Request, err *oidc.Error, statusCode int, logger *slog.Logger) { + logger.Log(r.Context(), err.LogLevel(), "request error", "oidc_error", err, "status_code", statusCode) + httphelper.MarshalJSONWithStatus(w, err, statusCode) } diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go new file mode 100644 index 0000000..9271cf1 --- /dev/null +++ b/pkg/op/error_test.go @@ -0,0 +1,682 @@ +package op + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/schema" +) + +func TestAuthRequestError(t *testing.T) { + type args struct { + authReq ErrAuthRequest + err error + } + tests := []struct { + name string + args args + wantCode int + wantHeaders map[string]string + wantBody string + wantLog string + }{ + { + name: "nil auth request", + args: args{ + authReq: nil, + err: io.ErrClosedPipe, + }, + wantCode: http.StatusBadRequest, + wantBody: "io: read/write on closed pipe\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth request", + "time":"not", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + } + }`, + }, + { + name: "auth request, no redirect URI", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "sign in\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request, redirect disabled", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), + }, + wantCode: http.StatusBadRequest, + wantBody: "oops\n", + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"oops", + "type":"invalid_request", + "redirect_disabled":true + } + }`, + }, + { + name: "auth request, url parse error", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "can't parse this!\n", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusBadRequest, + wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n", + wantLog: `{ + "level":"ERROR", + "msg":"auth response URL", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"can't parse this!\n", + "response_type":"responseType", + "scopes":"a b" + }, + "error":{ + "type":"server_error", + "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request redirect", + args: args{ + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + err: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantCode: http.StatusFound, + wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"}, + wantLog: `{ + "level":"WARN", + "msg":"auth request", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + authorizer := &Provider{ + encoder: schema.NewEncoder(), + logger: slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ), + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode) + for key, wantHeader := range tt.wantHeaders { + gotHeader := res.Header.Get(key) + assert.Equalf(t, wantHeader, gotHeader, "header %q", key) + } + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.Equal(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestRequestError(t *testing.T) { + tests := []struct { + name string + err error + wantCode int + wantBody string + wantLog string + }{ + { + name: "server error", + err: io.ErrClosedPipe, + wantCode: http.StatusBadRequest, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "time":"not", + "oidc_error":{ + "parent":"io: read/write on closed pipe", + "description":"io: read/write on closed pipe", + "type":"server_error"} + }`, + }, + { + name: "invalid client", + err: oidc.ErrInvalidClient().WithDescription("not good"), + wantCode: http.StatusUnauthorized, + wantBody: `{"error":"invalid_client", "error_description":"not good"}`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "time":"not", + "oidc_error":{ + "description":"not good", + "type":"invalid_client"} + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/path", nil) + RequestError(w, r, tt.err, logger) + + res := w.Result() + defer res.Body.Close() + + assert.Equal(t, tt.wantCode, res.StatusCode, "status code") + + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err, "read result body") + assert.JSONEq(t, tt.wantBody, string(gotBody), "result body") + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestTryErrorRedirect(t *testing.T) { + type args struct { + ctx context.Context + authReq ErrAuthRequest + parent error + } + tests := []struct { + name string + args args + want *Redirect + wantErr error + wantLog string + }{ + { + name: "nil auth request", + args: args{ + ctx: context.Background(), + authReq: nil, + parent: io.ErrClosedPipe, + }, + wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest), + wantLog: `{ + "level":"ERROR", + "msg":"auth request", + "time":"not", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + } + }`, + }, + { + name: "auth request, no redirect URI", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest), + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request, redirect disabled", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), + }, + wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest), + wantLog: `{ + "level":"WARN", + "msg":"auth request: not redirecting", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"oops", + "type":"invalid_request", + "redirect_disabled":true + } + }`, + }, + { + name: "auth request, url parse error", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "can't parse this!\n", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + wantErr: func() error { + //lint:ignore SA1007 just recreating the error for testing + _, err := url.Parse("can't parse this!\n") + err = oidc.ErrServerError().WithParent(err) + return NewStatusError(err, http.StatusBadRequest) + }(), + wantLog: `{ + "level":"ERROR", + "msg":"auth response URL", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"can't parse this!\n", + "response_type":"responseType", + "scopes":"a b" + }, + "error":{ + "type":"server_error", + "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + } + }`, + }, + { + name: "auth request redirect", + args: args{ + ctx: context.Background(), + authReq: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"a", "b"}, + ResponseType: "responseType", + ClientID: "123", + RedirectURI: "http://example.com/callback", + State: "state1", + ResponseMode: oidc.ResponseModeQuery, + }, + parent: oidc.ErrInteractionRequired().WithDescription("sign in"), + }, + want: &Redirect{ + Header: make(http.Header), + URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1", + }, + wantLog: `{ + "level":"WARN", + "msg":"auth request redirect", + "time":"not", + "auth_request":{ + "client_id":"123", + "redirect_uri":"http://example.com/callback", + "response_type":"responseType", + "scopes":"a b" + }, + "oidc_error":{ + "description":"sign in", + "type":"interaction_required" + }, + "url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1" + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + encoder := schema.NewEncoder() + + got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + + gotLog := logOut.String() + t.Log(gotLog) + assert.JSONEq(t, tt.wantLog, gotLog, "log output") + }) + } +} + +func TestNewStatusError(t *testing.T) { + err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + + want := "Internal Server Error: io: read/write on closed pipe" + got := fmt.Sprint(err) + assert.Equal(t, want, got) +} + +func TestAsStatusError(t *testing.T) { + type args struct { + err error + statusCode int + } + tests := []struct { + name string + args args + want string + }{ + { + name: "already status error", + args: args{ + err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError), + statusCode: http.StatusBadRequest, + }, + want: "Internal Server Error: io: read/write on closed pipe", + }, + { + name: "oidc error", + args: args{ + err: oidc.ErrAcrInvalid, + statusCode: http.StatusBadRequest, + }, + want: "Bad Request: acr is invalid", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := AsStatusError(tt.args.err, tt.args.statusCode) + got := fmt.Sprint(err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestStatusError_Unwrap(t *testing.T) { + err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + require.ErrorIs(t, err, io.ErrClosedPipe) +} + +func TestStatusError_Is(t *testing.T) { + type args struct { + err error + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "nil error", + args: args{err: nil}, + want: false, + }, + { + name: "other error", + args: args{err: io.EOF}, + want: false, + }, + { + name: "other parent", + args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)}, + want: false, + }, + { + name: "other status", + args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)}, + want: false, + }, + { + name: "same", + args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)}, + want: true, + }, + { + name: "wrapped", + args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError) + if got := e.Is(tt.args.err); got != tt.want { + t.Errorf("StatusError.Is() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteError(t *testing.T) { + tests := []struct { + name string + err error + wantStatus int + wantBody string + wantLog string + }{ + { + name: "not a status or oidc error", + err: io.ErrClosedPipe, + wantStatus: http.StatusInternalServerError, + wantBody: `{ + "error":"server_error", + "error_description":"io: read/write on closed pipe" + }`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + }, + "status_code":500, + "time":"not" + }`, + }, + { + name: "status error w/o oidc", + err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError), + wantStatus: http.StatusInternalServerError, + wantBody: `{ + "error":"server_error", + "error_description":"io: read/write on closed pipe" + }`, + wantLog: `{ + "level":"ERROR", + "msg":"request error", + "oidc_error":{ + "description":"io: read/write on closed pipe", + "parent":"io: read/write on closed pipe", + "type":"server_error" + }, + "status_code":500, + "time":"not" + }`, + }, + { + name: "oidc error w/o status", + err: oidc.ErrInvalidRequest().WithDescription("oops"), + wantStatus: http.StatusBadRequest, + wantBody: `{ + "error":"invalid_request", + "error_description":"oops" + }`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "oidc_error":{ + "description":"oops", + "type":"invalid_request" + }, + "status_code":400, + "time":"not" + }`, + }, + { + name: "status with oidc error", + err: NewStatusError( + oidc.ErrUnauthorizedClient().WithDescription("oops"), + http.StatusUnauthorized, + ), + wantStatus: http.StatusUnauthorized, + wantBody: `{ + "error":"unauthorized_client", + "error_description":"oops" + }`, + wantLog: `{ + "level":"WARN", + "msg":"request error", + "oidc_error":{ + "description":"oops", + "type":"unauthorized_client" + }, + "status_code":401, + "time":"not" + }`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logOut := new(strings.Builder) + logger := slog.New( + slog.NewJSONHandler(logOut, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }).WithAttrs([]slog.Attr{slog.String("time", "not")}), + ) + r := httptest.NewRequest("GET", "/target", nil) + w := httptest.NewRecorder() + + WriteError(w, r, tt.err, logger) + res := w.Result() + assert.Equal(t, tt.wantStatus, res.StatusCode, "status code") + gotBody, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, tt.wantBody, string(gotBody), "body") + assert.JSONEq(t, tt.wantLog, logOut.String()) + }) + } +} diff --git a/pkg/op/form_post.html.tmpl b/pkg/op/form_post.html.tmpl new file mode 100644 index 0000000..7bc9ab3 --- /dev/null +++ b/pkg/op/form_post.html.tmpl @@ -0,0 +1,14 @@ + + + + +
+{{with .Params.state}}{{end}} +{{with .Params.code}}{{end}} +{{with .Params.id_token}}{{end}} +{{with .Params.access_token}}{{end}} +{{with .Params.token_type}}{{end}} +{{with .Params.expires_in}}{{end}} +
+ + \ No newline at end of file diff --git a/pkg/op/keys.go b/pkg/op/keys.go index 239ecbd..97e400b 100644 --- a/pkg/op/keys.go +++ b/pkg/op/keys.go @@ -4,9 +4,9 @@ import ( "context" "net/http" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" - httphelper "github.com/zitadel/oidc/v2/pkg/http" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" ) type KeyProvider interface { @@ -20,6 +20,10 @@ func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) { } func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { + ctx, span := tracer.Start(r.Context(), "Keys") + r = r.WithContext(ctx) + defer span.End() + keySet, err := k.KeySet(r.Context()) if err != nil { httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError) diff --git a/pkg/op/keys_test.go b/pkg/op/keys_test.go index 2e56b78..9c80878 100644 --- a/pkg/op/keys_test.go +++ b/pkg/op/keys_test.go @@ -7,13 +7,13 @@ import ( "net/http/httptest" "testing" + jose "github.com/go-jose/go-jose/v4" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "gopkg.in/square/go-jose.v2" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" - "github.com/zitadel/oidc/v2/pkg/op/mock" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op/mock" ) func TestKeys(t *testing.T) { diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go index cc913ee..56b28e0 100644 --- a/pkg/op/mock/authorizer.mock.go +++ b/pkg/op/mock/authorizer.mock.go @@ -1,16 +1,17 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Authorizer) +// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: Authorizer) // Package mock is a generated GoMock package. package mock import ( context "context" + slog "log/slog" reflect "reflect" + http "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" gomock "github.com/golang/mock/gomock" - http "github.com/zitadel/oidc/v2/pkg/http" - op "github.com/zitadel/oidc/v2/pkg/op" ) // MockAuthorizer is a mock of Authorizer interface. @@ -79,10 +80,10 @@ func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call { } // IDTokenHintVerifier mocks base method. -func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) op.IDTokenHintVerifier { +func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0) - ret0, _ := ret[0].(op.IDTokenHintVerifier) + ret0, _ := ret[0].(*op.IDTokenHintVerifier) return ret0 } @@ -92,6 +93,20 @@ func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0) } +// Logger mocks base method. +func (m *MockAuthorizer) Logger() *slog.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logger") + ret0, _ := ret[0].(*slog.Logger) + return ret0 +} + +// Logger indicates an expected call of Logger. +func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger)) +} + // RequestObjectSupported mocks base method. func (m *MockAuthorizer) RequestObjectSupported() bool { m.ctrl.T.Helper() diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go index 3f1d525..73c4154 100644 --- a/pkg/op/mock/authorizer.mock.impl.go +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -4,12 +4,12 @@ import ( "context" "testing" + jose "github.com/go-jose/go-jose/v4" "github.com/golang/mock/gomock" - "github.com/gorilla/schema" - "gopkg.in/square/go-jose.v2" + "github.com/zitadel/schema" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) func NewAuthorizer(t *testing.T) op.Authorizer { @@ -49,7 +49,7 @@ func ExpectEncoder(a op.Authorizer) { func ExpectVerifier(a op.Authorizer, t *testing.T) { mockA := a.(*MockAuthorizer) mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn( - func() op.IDTokenHintVerifier { + func() *op.IDTokenHintVerifier { return op.NewIDTokenHintVerifier("", nil) }) } diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go index 36df84a..e2a5e85 100644 --- a/pkg/op/mock/client.go +++ b/pkg/op/mock/client.go @@ -5,8 +5,8 @@ import ( "github.com/golang/mock/gomock" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) func NewClient(t *testing.T) op.Client { diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go index e3d19fb..93eca67 100644 --- a/pkg/op/mock/client.mock.go +++ b/pkg/op/mock/client.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Client) +// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: Client) // Package mock is a generated GoMock package. package mock @@ -8,9 +8,9 @@ import ( reflect "reflect" time "time" + oidc "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" gomock "github.com/golang/mock/gomock" - oidc "github.com/zitadel/oidc/v2/pkg/oidc" - op "github.com/zitadel/oidc/v2/pkg/op" ) // MockClient is a mock of Client interface. diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go index fe7d4da..bf51035 100644 --- a/pkg/op/mock/configuration.mock.go +++ b/pkg/op/mock/configuration.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Configuration) +// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: Configuration) // Package mock is a generated GoMock package. package mock @@ -8,8 +8,8 @@ import ( http "net/http" reflect "reflect" + op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" gomock "github.com/golang/mock/gomock" - op "github.com/zitadel/oidc/v2/pkg/op" language "golang.org/x/text/language" ) @@ -65,10 +65,10 @@ func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gom } // AuthorizationEndpoint mocks base method. -func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -78,6 +78,48 @@ func (mr *MockConfigurationMockRecorder) AuthorizationEndpoint() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).AuthorizationEndpoint)) } +// BackChannelLogoutSessionSupported mocks base method. +func (m *MockConfiguration) BackChannelLogoutSessionSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BackChannelLogoutSessionSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// BackChannelLogoutSessionSupported indicates an expected call of BackChannelLogoutSessionSupported. +func (mr *MockConfigurationMockRecorder) BackChannelLogoutSessionSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSessionSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSessionSupported)) +} + +// BackChannelLogoutSupported mocks base method. +func (m *MockConfiguration) BackChannelLogoutSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BackChannelLogoutSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// BackChannelLogoutSupported indicates an expected call of BackChannelLogoutSupported. +func (mr *MockConfigurationMockRecorder) BackChannelLogoutSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSupported)) +} + +// CheckSessionIframe mocks base method. +func (m *MockConfiguration) CheckSessionIframe() *op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckSessionIframe") + ret0, _ := ret[0].(*op.Endpoint) + return ret0 +} + +// CheckSessionIframe indicates an expected call of CheckSessionIframe. +func (mr *MockConfigurationMockRecorder) CheckSessionIframe() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckSessionIframe", reflect.TypeOf((*MockConfiguration)(nil).CheckSessionIframe)) +} + // CodeMethodS256Supported mocks base method. func (m *MockConfiguration) CodeMethodS256Supported() bool { m.ctrl.T.Helper() @@ -107,10 +149,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call { } // DeviceAuthorizationEndpoint mocks base method. -func (m *MockConfiguration) DeviceAuthorizationEndpoint() op.Endpoint { +func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -121,10 +163,10 @@ func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.C } // EndSessionEndpoint mocks base method. -func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint { +func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EndSessionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -233,10 +275,10 @@ func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSup } // IntrospectionEndpoint mocks base method. -func (m *MockConfiguration) IntrospectionEndpoint() op.Endpoint { +func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IntrospectionEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -275,10 +317,10 @@ func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *go } // KeysEndpoint mocks base method. -func (m *MockConfiguration) KeysEndpoint() op.Endpoint { +func (m *MockConfiguration) KeysEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "KeysEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -331,10 +373,10 @@ func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSuppor } // RevocationEndpoint mocks base method. -func (m *MockConfiguration) RevocationEndpoint() op.Endpoint { +func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RevocationEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -373,10 +415,10 @@ func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call { } // TokenEndpoint mocks base method. -func (m *MockConfiguration) TokenEndpoint() op.Endpoint { +func (m *MockConfiguration) TokenEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TokenEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } @@ -401,10 +443,10 @@ func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported } // UserinfoEndpoint mocks base method. -func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { +func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UserinfoEndpoint") - ret0, _ := ret[0].(op.Endpoint) + ret0, _ := ret[0].(*op.Endpoint) return ret0 } diff --git a/pkg/op/mock/discovery.mock.go b/pkg/op/mock/discovery.mock.go index 0c78d52..c85f91b 100644 --- a/pkg/op/mock/discovery.mock.go +++ b/pkg/op/mock/discovery.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: DiscoverStorage) +// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: DiscoverStorage) // Package mock is a generated GoMock package. package mock @@ -8,8 +8,8 @@ import ( context "context" reflect "reflect" + jose "github.com/go-jose/go-jose/v4" gomock "github.com/golang/mock/gomock" - jose "gopkg.in/square/go-jose.v2" ) // MockDiscoverStorage is a mock of DiscoverStorage interface. diff --git a/pkg/op/mock/generate.go b/pkg/op/mock/generate.go index ca288d2..3d58ab7 100644 --- a/pkg/op/mock/generate.go +++ b/pkg/op/mock/generate.go @@ -1,10 +1,11 @@ package mock //go:generate go install github.com/golang/mock/mockgen@v1.6.0 -//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v2/pkg/op Storage -//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v2/pkg/op Authorizer -//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v2/pkg/op Client -//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v2/pkg/op Configuration -//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v2/pkg/op DiscoverStorage -//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v2/pkg/op SigningKey,Key -//go:generate mockgen -package mock -destination ./key.mock.go github.com/zitadel/oidc/v2/pkg/op KeyProvider +//go:generate mockgen -package mock -destination ./storage.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op Storage +//go:generate mockgen -package mock -destination ./authorizer.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op Authorizer +//go:generate mockgen -package mock -destination ./client.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op Client +//go:generate mockgen -package mock -destination ./glob.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op HasRedirectGlobs +//go:generate mockgen -package mock -destination ./configuration.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op Configuration +//go:generate mockgen -package mock -destination ./discovery.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op DiscoverStorage +//go:generate mockgen -package mock -destination ./signer.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op SigningKey,Key +//go:generate mockgen -package mock -destination ./key.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op KeyProvider diff --git a/pkg/op/mock/glob.go b/pkg/op/mock/glob.go new file mode 100644 index 0000000..8149c8f --- /dev/null +++ b/pkg/op/mock/glob.go @@ -0,0 +1,24 @@ +package mock + +import ( + "testing" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + gomock "github.com/golang/mock/gomock" +) + +func NewHasRedirectGlobs(t *testing.T) op.HasRedirectGlobs { + return NewMockHasRedirectGlobs(gomock.NewController(t)) +} + +func NewHasRedirectGlobsWithConfig(t *testing.T, uri []string, appType op.ApplicationType, responseTypes []oidc.ResponseType, devMode bool) op.HasRedirectGlobs { + c := NewHasRedirectGlobs(t) + m := c.(*MockHasRedirectGlobs) + m.EXPECT().RedirectURIs().AnyTimes().Return(uri) + m.EXPECT().RedirectURIGlobs().AnyTimes().Return(uri) + m.EXPECT().ApplicationType().AnyTimes().Return(appType) + m.EXPECT().ResponseTypes().AnyTimes().Return(responseTypes) + m.EXPECT().DevMode().AnyTimes().Return(devMode) + return c +} diff --git a/pkg/op/mock/glob.mock.go b/pkg/op/mock/glob.mock.go new file mode 100644 index 0000000..ebdc333 --- /dev/null +++ b/pkg/op/mock/glob.mock.go @@ -0,0 +1,289 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: HasRedirectGlobs) + +// Package mock is a generated GoMock package. +package mock + +import ( + reflect "reflect" + time "time" + + oidc "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + gomock "github.com/golang/mock/gomock" +) + +// MockHasRedirectGlobs is a mock of HasRedirectGlobs interface. +type MockHasRedirectGlobs struct { + ctrl *gomock.Controller + recorder *MockHasRedirectGlobsMockRecorder +} + +// MockHasRedirectGlobsMockRecorder is the mock recorder for MockHasRedirectGlobs. +type MockHasRedirectGlobsMockRecorder struct { + mock *MockHasRedirectGlobs +} + +// NewMockHasRedirectGlobs creates a new mock instance. +func NewMockHasRedirectGlobs(ctrl *gomock.Controller) *MockHasRedirectGlobs { + mock := &MockHasRedirectGlobs{ctrl: ctrl} + mock.recorder = &MockHasRedirectGlobsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockHasRedirectGlobs) EXPECT() *MockHasRedirectGlobsMockRecorder { + return m.recorder +} + +// AccessTokenType mocks base method. +func (m *MockHasRedirectGlobs) AccessTokenType() op.AccessTokenType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AccessTokenType") + ret0, _ := ret[0].(op.AccessTokenType) + return ret0 +} + +// AccessTokenType indicates an expected call of AccessTokenType. +func (mr *MockHasRedirectGlobsMockRecorder) AccessTokenType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockHasRedirectGlobs)(nil).AccessTokenType)) +} + +// ApplicationType mocks base method. +func (m *MockHasRedirectGlobs) ApplicationType() op.ApplicationType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplicationType") + ret0, _ := ret[0].(op.ApplicationType) + return ret0 +} + +// ApplicationType indicates an expected call of ApplicationType. +func (mr *MockHasRedirectGlobsMockRecorder) ApplicationType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ApplicationType)) +} + +// AuthMethod mocks base method. +func (m *MockHasRedirectGlobs) AuthMethod() oidc.AuthMethod { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AuthMethod") + ret0, _ := ret[0].(oidc.AuthMethod) + return ret0 +} + +// AuthMethod indicates an expected call of AuthMethod. +func (mr *MockHasRedirectGlobsMockRecorder) AuthMethod() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethod", reflect.TypeOf((*MockHasRedirectGlobs)(nil).AuthMethod)) +} + +// ClockSkew mocks base method. +func (m *MockHasRedirectGlobs) ClockSkew() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClockSkew") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// ClockSkew indicates an expected call of ClockSkew. +func (mr *MockHasRedirectGlobsMockRecorder) ClockSkew() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClockSkew", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ClockSkew)) +} + +// DevMode mocks base method. +func (m *MockHasRedirectGlobs) DevMode() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DevMode") + ret0, _ := ret[0].(bool) + return ret0 +} + +// DevMode indicates an expected call of DevMode. +func (mr *MockHasRedirectGlobsMockRecorder) DevMode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DevMode", reflect.TypeOf((*MockHasRedirectGlobs)(nil).DevMode)) +} + +// GetID mocks base method. +func (m *MockHasRedirectGlobs) GetID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetID indicates an expected call of GetID. +func (mr *MockHasRedirectGlobsMockRecorder) GetID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockHasRedirectGlobs)(nil).GetID)) +} + +// GrantTypes mocks base method. +func (m *MockHasRedirectGlobs) GrantTypes() []oidc.GrantType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GrantTypes") + ret0, _ := ret[0].([]oidc.GrantType) + return ret0 +} + +// GrantTypes indicates an expected call of GrantTypes. +func (mr *MockHasRedirectGlobsMockRecorder) GrantTypes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).GrantTypes)) +} + +// IDTokenLifetime mocks base method. +func (m *MockHasRedirectGlobs) IDTokenLifetime() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IDTokenLifetime") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// IDTokenLifetime indicates an expected call of IDTokenLifetime. +func (mr *MockHasRedirectGlobsMockRecorder) IDTokenLifetime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IDTokenLifetime)) +} + +// IDTokenUserinfoClaimsAssertion mocks base method. +func (m *MockHasRedirectGlobs) IDTokenUserinfoClaimsAssertion() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IDTokenUserinfoClaimsAssertion") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IDTokenUserinfoClaimsAssertion indicates an expected call of IDTokenUserinfoClaimsAssertion. +func (mr *MockHasRedirectGlobsMockRecorder) IDTokenUserinfoClaimsAssertion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenUserinfoClaimsAssertion", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IDTokenUserinfoClaimsAssertion)) +} + +// IsScopeAllowed mocks base method. +func (m *MockHasRedirectGlobs) IsScopeAllowed(arg0 string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsScopeAllowed", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsScopeAllowed indicates an expected call of IsScopeAllowed. +func (mr *MockHasRedirectGlobsMockRecorder) IsScopeAllowed(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsScopeAllowed", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IsScopeAllowed), arg0) +} + +// LoginURL mocks base method. +func (m *MockHasRedirectGlobs) LoginURL(arg0 string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoginURL", arg0) + ret0, _ := ret[0].(string) + return ret0 +} + +// LoginURL indicates an expected call of LoginURL. +func (mr *MockHasRedirectGlobsMockRecorder) LoginURL(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockHasRedirectGlobs)(nil).LoginURL), arg0) +} + +// PostLogoutRedirectURIGlobs mocks base method. +func (m *MockHasRedirectGlobs) PostLogoutRedirectURIGlobs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PostLogoutRedirectURIGlobs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// PostLogoutRedirectURIGlobs indicates an expected call of PostLogoutRedirectURIGlobs. +func (mr *MockHasRedirectGlobsMockRecorder) PostLogoutRedirectURIGlobs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIGlobs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).PostLogoutRedirectURIGlobs)) +} + +// PostLogoutRedirectURIs mocks base method. +func (m *MockHasRedirectGlobs) PostLogoutRedirectURIs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PostLogoutRedirectURIs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// PostLogoutRedirectURIs indicates an expected call of PostLogoutRedirectURIs. +func (mr *MockHasRedirectGlobsMockRecorder) PostLogoutRedirectURIs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).PostLogoutRedirectURIs)) +} + +// RedirectURIGlobs mocks base method. +func (m *MockHasRedirectGlobs) RedirectURIGlobs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RedirectURIGlobs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// RedirectURIGlobs indicates an expected call of RedirectURIGlobs. +func (mr *MockHasRedirectGlobsMockRecorder) RedirectURIGlobs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIGlobs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RedirectURIGlobs)) +} + +// RedirectURIs mocks base method. +func (m *MockHasRedirectGlobs) RedirectURIs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RedirectURIs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// RedirectURIs indicates an expected call of RedirectURIs. +func (mr *MockHasRedirectGlobsMockRecorder) RedirectURIs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RedirectURIs)) +} + +// ResponseTypes mocks base method. +func (m *MockHasRedirectGlobs) ResponseTypes() []oidc.ResponseType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResponseTypes") + ret0, _ := ret[0].([]oidc.ResponseType) + return ret0 +} + +// ResponseTypes indicates an expected call of ResponseTypes. +func (mr *MockHasRedirectGlobsMockRecorder) ResponseTypes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseTypes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ResponseTypes)) +} + +// RestrictAdditionalAccessTokenScopes mocks base method. +func (m *MockHasRedirectGlobs) RestrictAdditionalAccessTokenScopes() func([]string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RestrictAdditionalAccessTokenScopes") + ret0, _ := ret[0].(func([]string) []string) + return ret0 +} + +// RestrictAdditionalAccessTokenScopes indicates an expected call of RestrictAdditionalAccessTokenScopes. +func (mr *MockHasRedirectGlobsMockRecorder) RestrictAdditionalAccessTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalAccessTokenScopes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RestrictAdditionalAccessTokenScopes)) +} + +// RestrictAdditionalIdTokenScopes mocks base method. +func (m *MockHasRedirectGlobs) RestrictAdditionalIdTokenScopes() func([]string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RestrictAdditionalIdTokenScopes") + ret0, _ := ret[0].(func([]string) []string) + return ret0 +} + +// RestrictAdditionalIdTokenScopes indicates an expected call of RestrictAdditionalIdTokenScopes. +func (mr *MockHasRedirectGlobsMockRecorder) RestrictAdditionalIdTokenScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalIdTokenScopes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RestrictAdditionalIdTokenScopes)) +} diff --git a/pkg/op/mock/key.mock.go b/pkg/op/mock/key.mock.go index 8831651..d9ee857 100644 --- a/pkg/op/mock/key.mock.go +++ b/pkg/op/mock/key.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: KeyProvider) +// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: KeyProvider) // Package mock is a generated GoMock package. package mock @@ -8,8 +8,8 @@ import ( context "context" reflect "reflect" + op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" gomock "github.com/golang/mock/gomock" - op "github.com/zitadel/oidc/v2/pkg/op" ) // MockKeyProvider is a mock of KeyProvider interface. diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go index 78c0efe..751ce60 100644 --- a/pkg/op/mock/signer.mock.go +++ b/pkg/op/mock/signer.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: SigningKey,Key) +// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: SigningKey,Key) // Package mock is a generated GoMock package. package mock @@ -7,8 +7,8 @@ package mock import ( reflect "reflect" + jose "github.com/go-jose/go-jose/v4" gomock "github.com/golang/mock/gomock" - jose "gopkg.in/square/go-jose.v2" ) // MockSigningKey is a mock of SigningKey interface. diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go index 85afb2a..0df9830 100644 --- a/pkg/op/mock/storage.mock.go +++ b/pkg/op/mock/storage.mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/oidc/v2/pkg/op (interfaces: Storage) +// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: Storage) // Package mock is a generated GoMock package. package mock @@ -9,10 +9,10 @@ import ( reflect "reflect" time "time" + oidc "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + jose "github.com/go-jose/go-jose/v4" gomock "github.com/golang/mock/gomock" - oidc "github.com/zitadel/oidc/v2/pkg/oidc" - op "github.com/zitadel/oidc/v2/pkg/op" - jose "gopkg.in/square/go-jose.v2" ) // MockStorage is a mock of Storage interface. diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go index 9269f89..96e08a9 100644 --- a/pkg/op/mock/storage.mock.impl.go +++ b/pkg/op/mock/storage.mock.impl.go @@ -8,8 +8,8 @@ import ( "github.com/golang/mock/gomock" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) func NewStorage(t *testing.T) op.Storage { diff --git a/pkg/op/op.go b/pkg/op/op.go index ecb753e..76c2c89 100644 --- a/pkg/op/op.go +++ b/pkg/op/op.go @@ -3,17 +3,19 @@ package op import ( "context" "fmt" + "log/slog" "net/http" "time" - "github.com/gorilla/mux" - "github.com/gorilla/schema" + "github.com/go-chi/chi/v5" + jose "github.com/go-jose/go-jose/v4" "github.com/rs/cors" + "github.com/zitadel/schema" + "go.opentelemetry.io/otel" "golang.org/x/text/language" - "gopkg.in/square/go-jose.v2" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) const ( @@ -31,7 +33,7 @@ const ( ) var ( - DefaultEndpoints = &endpoints{ + DefaultEndpoints = &Endpoints{ Authorization: NewEndpoint(defaultAuthorizationEndpoint), Token: NewEndpoint(defaultTokenEndpoint), Introspection: NewEndpoint(defaultIntrospectEndpoint), @@ -42,6 +44,33 @@ var ( DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint), } + DefaultSupportedClaims = []string{ + "sub", + "aud", + "exp", + "iat", + "iss", + "auth_time", + "nonce", + "acr", + "amr", + "c_hash", + "at_hash", + "act", + "scopes", + "client_id", + "azp", + "preferred_username", + "name", + "family_name", + "given_name", + "locale", + "email", + "email_verified", + "phone_number", + "phone_number_verified", + } + defaultCORSOptions = cors.Options{ AllowCredentials: true, AllowedHeaders: []string{ @@ -67,30 +96,46 @@ var ( } ) +var tracer = otel.Tracer("github.com/zitadel/oidc/pkg/op") + type OpenIDProvider interface { + http.Handler Configuration Storage() Storage Decoder() httphelper.Decoder Encoder() httphelper.Encoder - IDTokenHintVerifier(context.Context) IDTokenHintVerifier - AccessTokenVerifier(context.Context) AccessTokenVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier Crypto() Crypto DefaultLogoutRedirectURI() string Probes() []ProbesFn + Logger() *slog.Logger + + // Deprecated: Provider now implements http.Handler directly. HttpHandler() http.Handler } type HttpInterceptor func(http.Handler) http.Handler -func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) *mux.Router { - router := mux.NewRouter() - router.Use(cors.New(defaultCORSOptions).Handler) +type corsOptioner interface { + CORSOptions() *cors.Options +} + +func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router { + router := chi.NewRouter() + if co, ok := o.(corsOptioner); ok { + if opts := co.CORSOptions(); opts != nil { + router.Use(cors.New(*opts).Handler) + } + } else { + router.Use(cors.New(defaultCORSOptions).Handler) + } router.Use(intercept(o.IssuerFromRequest, interceptors...)) router.HandleFunc(healthEndpoint, healthHandler) router.HandleFunc(readinessEndpoint, readyHandler(o.Probes())) router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage())) router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o)) - router.NewRoute().Path(authCallbackPath(o)).Queries("id", "{id}").HandlerFunc(authorizeCallbackHandler(o)) + router.HandleFunc(authCallbackPath(o), AuthorizeCallbackHandler(o)) router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o)) router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o)) router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o)) @@ -113,27 +158,32 @@ func authCallbackPath(o OpenIDProvider) string { } type Config struct { - CryptoKey [32]byte - DefaultLogoutRedirectURI string - CodeMethodS256 bool - AuthMethodPost bool - AuthMethodPrivateKeyJWT bool - GrantTypeRefreshToken bool - RequestObjectSupported bool - SupportedUILocales []language.Tag - DeviceAuthorization DeviceAuthorizationConfig + CryptoKey [32]byte + DefaultLogoutRedirectURI string + CodeMethodS256 bool + AuthMethodPost bool + AuthMethodPrivateKeyJWT bool + GrantTypeRefreshToken bool + RequestObjectSupported bool + SupportedUILocales []language.Tag + SupportedClaims []string + SupportedScopes []string + DeviceAuthorization DeviceAuthorizationConfig + BackChannelLogoutSupported bool + BackChannelLogoutSessionSupported bool } -type endpoints struct { - Authorization Endpoint - Token Endpoint - Introspection Endpoint - Userinfo Endpoint - Revocation Endpoint - EndSession Endpoint - CheckSessionIframe Endpoint - JwksURI Endpoint - DeviceAuthorization Endpoint +// Endpoints defines endpoint routes. +type Endpoints struct { + Authorization *Endpoint + Token *Endpoint + Introspection *Endpoint + Userinfo *Endpoint + Revocation *Endpoint + EndSession *Endpoint + CheckSessionIframe *Endpoint + JwksURI *Endpoint + DeviceAuthorization *Endpoint } // NewOpenIDProvider creates a provider. The provider provides (with HttpHandler()) @@ -157,20 +207,62 @@ type endpoints struct { // Successful logins should mark the request as authorized and redirect back to to // op.AuthCallbackURL(provider) which is probably /callback. On the redirect back // to the AuthCallbackURL, the request id should be passed as the "id" parameter. +// +// Deprecated: use [NewProvider] with an issuer function direct. func NewOpenIDProvider(issuer string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) { - return newProvider(config, storage, StaticIssuer(issuer), opOpts...) + return NewProvider(config, storage, StaticIssuer(issuer), opOpts...) } +// NewForwardedOpenIDProvider tries to establishes the issuer from the request Host. +// +// Deprecated: use [NewProvider] with an issuer function direct. func NewDynamicOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) { - return newProvider(config, storage, IssuerFromHost(path), opOpts...) + return NewProvider(config, storage, IssuerFromHost(path), opOpts...) } -func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) { +// NewForwardedOpenIDProvider tries to establish the Issuer from a Forwarded request header, if it is set. +// See [IssuerFromForwardedOrHost] for details. +// +// Deprecated: use [NewProvider] with an issuer function direct. +func NewForwardedOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) { + return NewProvider(config, storage, IssuerFromForwardedOrHost(path), opOpts...) +} + +// NewProvider creates a provider with a router on it's embedded http.Handler. +// Issuer is a function that must return the issuer on every request. +// Typically [StaticIssuer], [IssuerFromHost] or [IssuerFromForwardedOrHost] can be used. +// +// The router handles a suite of endpoints (some paths can be overridden): +// +// /healthz +// /ready +// /.well-known/openid-configuration +// /oauth/token +// /oauth/introspect +// /callback +// /authorize +// /userinfo +// /revoke +// /end_session +// /keys +// /device_authorization +// +// This does not include login. Login is handled with a redirect that includes the +// request ID. The redirect for logins is specified per-client by Client.LoginURL(). +// Successful logins should mark the request as authorized and redirect back to to +// op.AuthCallbackURL(provider) which is probably /callback. On the redirect back +// to the AuthCallbackURL, the request id should be passed as the "id" parameter. +func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) { + keySet := &OpenIDKeySet{storage} o := &Provider{ - config: config, - storage: storage, - endpoints: DefaultEndpoints, - timer: make(<-chan time.Time), + config: config, + storage: storage, + accessTokenKeySet: keySet, + idTokenHinKeySet: keySet, + endpoints: DefaultEndpoints, + timer: make(<-chan time.Time), + corsOpts: &defaultCORSOptions, + logger: slog.Default(), } for _, optFunc := range opOpts { @@ -183,37 +275,32 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR if err != nil { return nil, err } - - o.httpHandler = CreateRouter(o, o.interceptors...) - + o.Handler = CreateRouter(o, o.interceptors...) o.decoder = schema.NewDecoder() o.decoder.IgnoreUnknownKeys(true) - o.encoder = oidc.NewEncoder() - o.crypto = NewAESCrypto(config.CryptoKey) - - // Avoid potential race conditions by calling these early - _ = o.openIDKeySet() // sets keySet - return o, nil } type Provider struct { + http.Handler config *Config issuer IssuerFromRequest insecure bool - endpoints *endpoints + endpoints *Endpoints storage Storage - keySet *openIDKeySet + accessTokenKeySet oidc.KeySet + idTokenHinKeySet oidc.KeySet crypto Crypto - httpHandler http.Handler decoder *schema.Decoder encoder *schema.Encoder interceptors []HttpInterceptor timer <-chan time.Time accessTokenVerifierOpts []AccessTokenVerifierOpt idTokenHintVerifierOpts []IDTokenHintVerifierOpt + corsOpts *cors.Options + logger *slog.Logger } func (o *Provider) IssuerFromRequest(r *http.Request) string { @@ -224,35 +311,39 @@ func (o *Provider) Insecure() bool { return o.insecure } -func (o *Provider) AuthorizationEndpoint() Endpoint { +func (o *Provider) AuthorizationEndpoint() *Endpoint { return o.endpoints.Authorization } -func (o *Provider) TokenEndpoint() Endpoint { +func (o *Provider) TokenEndpoint() *Endpoint { return o.endpoints.Token } -func (o *Provider) IntrospectionEndpoint() Endpoint { +func (o *Provider) IntrospectionEndpoint() *Endpoint { return o.endpoints.Introspection } -func (o *Provider) UserinfoEndpoint() Endpoint { +func (o *Provider) UserinfoEndpoint() *Endpoint { return o.endpoints.Userinfo } -func (o *Provider) RevocationEndpoint() Endpoint { +func (o *Provider) RevocationEndpoint() *Endpoint { return o.endpoints.Revocation } -func (o *Provider) EndSessionEndpoint() Endpoint { +func (o *Provider) EndSessionEndpoint() *Endpoint { return o.endpoints.EndSession } -func (o *Provider) DeviceAuthorizationEndpoint() Endpoint { +func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint { return o.endpoints.DeviceAuthorization } -func (o *Provider) KeysEndpoint() Endpoint { +func (o *Provider) CheckSessionIframe() *Endpoint { + return o.endpoints.CheckSessionIframe +} + +func (o *Provider) KeysEndpoint() *Endpoint { return o.endpoints.JwksURI } @@ -327,6 +418,14 @@ func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig { return o.config.DeviceAuthorization } +func (o *Provider) BackChannelLogoutSupported() bool { + return o.config.BackChannelLogoutSupported +} + +func (o *Provider) BackChannelLogoutSessionSupported() bool { + return o.config.BackChannelLogoutSessionSupported +} + func (o *Provider) Storage() Storage { return o.storage } @@ -339,23 +438,16 @@ func (o *Provider) Encoder() httphelper.Encoder { return o.encoder } -func (o *Provider) IDTokenHintVerifier(ctx context.Context) IDTokenHintVerifier { - return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...) +func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier { + return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.idTokenHinKeySet, o.idTokenHintVerifierOpts...) } -func (o *Provider) JWTProfileVerifier(ctx context.Context) JWTProfileVerifier { +func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier { return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second) } -func (o *Provider) AccessTokenVerifier(ctx context.Context) AccessTokenVerifier { - return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...) -} - -func (o *Provider) openIDKeySet() oidc.KeySet { - if o.keySet == nil { - o.keySet = &openIDKeySet{o.Storage()} - } - return o.keySet +func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier { + return NewAccessTokenVerifier(IssuerFromContext(ctx), o.accessTokenKeySet, o.accessTokenVerifierOpts...) } func (o *Provider) Crypto() Crypto { @@ -372,17 +464,26 @@ func (o *Provider) Probes() []ProbesFn { } } -func (o *Provider) HttpHandler() http.Handler { - return o.httpHandler +func (o *Provider) CORSOptions() *cors.Options { + return o.corsOpts } -type openIDKeySet struct { +func (o *Provider) Logger() *slog.Logger { + return o.logger +} + +// Deprecated: Provider now implements http.Handler directly. +func (o *Provider) HttpHandler() http.Handler { + return o +} + +type OpenIDKeySet struct { Storage } // VerifySignature implements the oidc.KeySet interface // providing an implementation for the keys stored in the OP Storage interface -func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { +func (o *OpenIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { keySet, err := o.Storage.KeySet(ctx) if err != nil { return nil, fmt.Errorf("error fetching keys: %w", err) @@ -406,7 +507,7 @@ func WithAllowInsecure() Option { } } -func WithCustomAuthEndpoint(endpoint Endpoint) Option { +func WithCustomAuthEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -416,7 +517,7 @@ func WithCustomAuthEndpoint(endpoint Endpoint) Option { } } -func WithCustomTokenEndpoint(endpoint Endpoint) Option { +func WithCustomTokenEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -426,7 +527,7 @@ func WithCustomTokenEndpoint(endpoint Endpoint) Option { } } -func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { +func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -436,7 +537,7 @@ func WithCustomIntrospectionEndpoint(endpoint Endpoint) Option { } } -func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { +func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -446,7 +547,7 @@ func WithCustomUserinfoEndpoint(endpoint Endpoint) Option { } } -func WithCustomRevocationEndpoint(endpoint Endpoint) Option { +func WithCustomRevocationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -456,7 +557,7 @@ func WithCustomRevocationEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { +func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -466,7 +567,7 @@ func WithCustomEndSessionEndpoint(endpoint Endpoint) Option { } } -func WithCustomKeysEndpoint(endpoint Endpoint) Option { +func WithCustomKeysEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { if err := endpoint.Validate(); err != nil { return err @@ -476,8 +577,26 @@ func WithCustomKeysEndpoint(endpoint Endpoint) Option { } } -func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys Endpoint) Option { +func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option { return func(o *Provider) error { + if err := endpoint.Validate(); err != nil { + return err + } + o.endpoints.DeviceAuthorization = endpoint + return nil + } +} + +// WithCustomEndpoints sets multiple endpoints at once. +// Non of the endpoints may be nil, or an error will +// be returned when the Option used by the Provider. +func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option { + return func(o *Provider) error { + for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} { + if err := e.Validate(); err != nil { + return err + } + } o.endpoints.Authorization = auth o.endpoints.Token = token o.endpoints.Userinfo = userInfo @@ -495,6 +614,15 @@ func WithHttpInterceptors(interceptors ...HttpInterceptor) Option { } } +// WithAccessTokenKeySet allows passing a KeySet with public keys for Access Token verification. +// The default KeySet uses the [Storage] interface +func WithAccessTokenKeySet(keySet oidc.KeySet) Option { + return func(o *Provider) error { + o.accessTokenKeySet = keySet + return nil + } +} + func WithAccessTokenVerifierOpts(opts ...AccessTokenVerifierOpt) Option { return func(o *Provider) error { o.accessTokenVerifierOpts = opts @@ -502,6 +630,15 @@ func WithAccessTokenVerifierOpts(opts ...AccessTokenVerifierOpt) Option { } } +// WithIDTokenHintKeySet allows passing a KeySet with public keys for ID Token Hint verification. +// The default KeySet uses the [Storage] interface. +func WithIDTokenHintKeySet(keySet oidc.KeySet) Option { + return func(o *Provider) error { + o.idTokenHinKeySet = keySet + return nil + } +} + func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { return func(o *Provider) error { o.idTokenHintVerifierOpts = opts @@ -509,12 +646,27 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option { } } +func WithCORSOptions(opts *cors.Options) Option { + return func(o *Provider) error { + o.corsOpts = opts + return nil + } +} + +// WithLogger lets a logger other than slog.Default(). +func WithLogger(logger *slog.Logger) Option { + return func(o *Provider) error { + o.logger = logger + return nil + } +} + func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler { issuerInterceptor := NewIssuerInterceptor(i) return func(handler http.Handler) http.Handler { for i := len(interceptors) - 1; i >= 0; i-- { handler = interceptors[i](handler) } - return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler)) + return issuerInterceptor.Handler(handler) } } diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go index ba3570b..e1ac0bd 100644 --- a/pkg/op/op_test.go +++ b/pkg/op/op_test.go @@ -11,24 +11,18 @@ import ( "testing" "time" + "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/oidc/v2/example/server/storage" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" "golang.org/x/text/language" ) -var testProvider op.OpenIDProvider - -const ( - testIssuer = "https://localhost:9998/" - pathLoggedOut = "/logged-out" -) - -func init() { - config := &op.Config{ +var ( + testProvider op.OpenIDProvider + testConfig = &op.Config{ CryptoKey: sha256.Sum256([]byte("test")), DefaultLogoutRedirectURI: pathLoggedOut, CodeMethodS256: true, @@ -36,28 +30,45 @@ func init() { AuthMethodPrivateKeyJWT: true, GrantTypeRefreshToken: true, RequestObjectSupported: true, + SupportedClaims: op.DefaultSupportedClaims, SupportedUILocales: []language.Tag{language.English}, DeviceAuthorization: op.DeviceAuthorizationConfig{ Lifetime: 5 * time.Minute, PollInterval: 5 * time.Second, - UserFormURL: testIssuer + "device", + UserFormPath: "/device", UserCode: op.UserCodeBase20, }, } +) +const ( + testIssuer = "https://localhost:9998/" + pathLoggedOut = "/logged-out" +) + +func init() { storage.RegisterClients( storage.NativeClient("native"), storage.WebClient("web", "secret", "https://example.com"), + storage.DeviceClient("device", "secret"), storage.WebClient("api", "secret"), ) - var err error - testProvider, err = op.NewOpenIDProvider(testIssuer, config, - storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(), + testProvider = newTestProvider(testConfig) +} + +func newTestProvider(config *op.Config) op.OpenIDProvider { + storage := storage.NewStorage(storage.NewUserStore(testIssuer)) + keySet := &op.OpenIDKeySet{storage} + provider, err := op.NewOpenIDProvider(testIssuer, config, storage, + op.WithAllowInsecure(), + op.WithAccessTokenKeySet(keySet), + op.WithIDTokenHintKeySet(keySet), ) if err != nil { panic(err) } + return provider } type routesTestStorage interface { @@ -151,7 +162,7 @@ func TestRoutes(t *testing.T) { values: map[string]string{ "client_id": client.GetID(), "redirect_uri": "https://example.com", - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "response_type": string(oidc.ResponseTypeCode), }, wantCode: http.StatusFound, @@ -170,7 +181,7 @@ func TestRoutes(t *testing.T) { }, }, { - // This call will fail. A successfull test is already + // This call will fail. A successful test is already // part of client/integration_test.go name: "code exchange", method: http.MethodGet, @@ -188,7 +199,7 @@ func TestRoutes(t *testing.T) { path: testProvider.TokenEndpoint().Relative(), values: map[string]string{ "grant_type": string(oidc.GrantTypeBearer), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "assertion": jwtToken, }, wantCode: http.StatusBadRequest, @@ -201,7 +212,7 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"web", "secret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeTokenExchange), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), "subject_token": jwtToken, "subject_token_type": string(oidc.AccessTokenType), }, @@ -218,13 +229,13 @@ func TestRoutes(t *testing.T) { basicAuth: &basicAuth{"sid1", "verysecret"}, values: map[string]string{ "grant_type": string(oidc.GrantTypeClientCredentials), - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, - contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`}, + contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299,"scope":"openid offline_access"}`}, }, { - // This call will fail. A successfull test is already + // This call will fail. A successful test is already // part of device_test.go name: "device token", method: http.MethodPost, @@ -331,9 +342,9 @@ func TestRoutes(t *testing.T) { name: "device authorization", method: http.MethodGet, path: testProvider.DeviceAuthorizationEndpoint().Relative(), - basicAuth: &basicAuth{"web", "secret"}, + basicAuth: &basicAuth{"device", "secret"}, values: map[string]string{ - "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), }, wantCode: http.StatusOK, contains: []string{ @@ -365,7 +376,7 @@ func TestRoutes(t *testing.T) { } rec := httptest.NewRecorder() - testProvider.HttpHandler().ServeHTTP(rec, req) + testProvider.ServeHTTP(rec, req) resp := rec.Result() require.NoError(t, err) @@ -390,3 +401,54 @@ func TestRoutes(t *testing.T) { }) } } + +func TestWithCustomEndpoints(t *testing.T) { + type args struct { + auth *op.Endpoint + token *op.Endpoint + userInfo *op.Endpoint + revocation *op.Endpoint + endSession *op.Endpoint + keys *op.Endpoint + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "all nil", + args: args{}, + wantErr: op.ErrNilEndpoint, + }, + { + name: "all set", + args: args{ + auth: op.NewEndpoint("/authorize"), + token: op.NewEndpoint("/oauth/token"), + userInfo: op.NewEndpoint("/userinfo"), + revocation: op.NewEndpoint("/revoke"), + endSession: op.NewEndpoint("/end_session"), + keys: op.NewEndpoint("/keys"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := op.NewOpenIDProvider(testIssuer, testConfig, + storage.NewStorage(storage.NewUserStore(testIssuer)), + op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys), + ) + require.ErrorIs(t, err, tt.wantErr) + if tt.wantErr != nil { + return + } + assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint()) + assert.Equal(t, tt.args.token, provider.TokenEndpoint()) + assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint()) + assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint()) + assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint()) + assert.Equal(t, tt.args.keys, provider.KeysEndpoint()) + }) + } +} diff --git a/pkg/op/probes.go b/pkg/op/probes.go index a56c92b..fa713da 100644 --- a/pkg/op/probes.go +++ b/pkg/op/probes.go @@ -5,7 +5,7 @@ import ( "errors" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" ) type ProbesFn func(context.Context) error @@ -41,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn { } func ok(w http.ResponseWriter) { - httphelper.MarshalJSON(w, status{"ok"}) + httphelper.MarshalJSON(w, Status{"ok"}) } -type status struct { +type Status struct { Status string `json:"status,omitempty"` } diff --git a/pkg/op/server.go b/pkg/op/server.go new file mode 100644 index 0000000..d45b734 --- /dev/null +++ b/pkg/op/server.go @@ -0,0 +1,350 @@ +package op + +import ( + "context" + "net/http" + "net/url" + + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/muhlemmer/gu" +) + +// Server describes the interface that needs to be implemented to serve +// OpenID Connect and Oauth2 standard requests. +// +// Methods are called after the HTTP route is resolved and +// the request body is parsed into the Request's Data field. +// When a method is called, it can be assumed that required fields, +// as described in their relevant standard, are validated already. +// The Response Data field may be of any type to allow flexibility +// to extend responses with custom fields. There are however requirements +// in the standards regarding the response models. Where applicable +// the method documentation gives a recommended type which can be used +// directly or extended upon. +// +// The addition of new methods is not considered a breaking change +// as defined by semver rules. +// Implementations MUST embed [UnimplementedServer] to maintain +// forward compatibility. +// +// EXPERIMENTAL: may change until v4 +type Server interface { + // Health returns a status of "ok" once the Server is listening. + // The recommended Response Data type is [Status]. + Health(context.Context, *Request[struct{}]) (*Response, error) + + // Ready returns a status of "ok" once all dependencies, + // such as database storage, are ready. + // An error can be returned to explain what is not ready. + // The recommended Response Data type is [Status]. + Ready(context.Context, *Request[struct{}]) (*Response, error) + + // Discovery returns the OpenID Provider Configuration Information for this server. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig + // The recommended Response Data type is [oidc.DiscoveryConfiguration]. + Discovery(context.Context, *Request[struct{}]) (*Response, error) + + // Keys serves the JWK set which the client can use verify signatures from the op. + // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key. + // The recommended Response Data type is [jose.JSONWebKeySet]. + Keys(context.Context, *Request[struct{}]) (*Response, error) + + // VerifyAuthRequest verifies the Auth Request and + // adds the Client to the request. + // + // When the `request` field is populated with a + // "Request Object" JWT, it needs to be Validated + // and its claims overwrite any fields in the AuthRequest. + // If the implementation does not support "Request Object", + // it MUST return an [oidc.ErrRequestNotSupported]. + // https://openid.net/specs/openid-connect-core-1_0.html#RequestObject + VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) + + // Authorize initiates the authorization flow and redirects to a login page. + // See the various https://openid.net/specs/openid-connect-core-1_0.html + // authorize endpoint sections (one for each type of flow). + Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error) + + // DeviceAuthorization initiates the device authorization flow. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.1 + // The recommended Response Data type is [oidc.DeviceAuthorizationResponse]. + DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) + + // VerifyClient is called on most oauth/token handlers to authenticate, + // using either a secret (POST, Basic) or assertion (JWT). + // If no secrets are provided, the client must be public. + // This method is called before each method that takes a + // [ClientRequest] argument. + VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error) + + // CodeExchange returns Tokens after an authorization code + // is obtained in a successful Authorize flow. + // It is called by the Token endpoint handler when + // grant_type has the value authorization_code + // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + // The recommended Response Data type is [oidc.AccessTokenResponse]. + CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) + + // RefreshToken returns new Tokens after verifying a Refresh token. + // It is called by the Token endpoint handler when + // grant_type has the value refresh_token + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens + // The recommended Response Data type is [oidc.AccessTokenResponse]. + RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) + + // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer + // https://datatracker.ietf.org/doc/html/rfc7523#section-2.1 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error) + + // TokenExchange handles the OAuth 2.0 token exchange grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange + // https://datatracker.ietf.org/doc/html/rfc8693 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) + + // ClientCredentialsExchange handles the OAuth 2.0 client credentials grant + // It is called by the Token endpoint handler when + // grant_type has the value client_credentials + // https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 + // The recommended Response Data type is [oidc.AccessTokenResponse]. + ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) + + // DeviceToken handles the OAuth 2.0 Device Authorization Grant + // It is called by the Token endpoint handler when + // grant_type has the value urn:ietf:params:oauth:grant-type:device_code. + // It is typically called in a polling fashion and appropriate errors + // should be returned to signal authorization_pending or access_denied etc. + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.4, + // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5. + // The recommended Response Data type is [oidc.AccessTokenResponse]. + DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) + + // Introspect handles the OAuth 2.0 Token Introspection endpoint. + // https://datatracker.ietf.org/doc/html/rfc7662 + // The recommended Response Data type is [oidc.IntrospectionResponse]. + Introspect(context.Context, *Request[IntrospectionRequest]) (*Response, error) + + // UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User. + // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo + // The recommended Response Data type is [oidc.UserInfo]. + UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error) + + // Revocation handles token revocation using an access or refresh token. + // https://datatracker.ietf.org/doc/html/rfc7009 + // There are no response requirements. Data may remain empty. + Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error) + + // EndSession handles the OpenID Connect RP-Initiated Logout. + // https://openid.net/specs/openid-connect-rpinitiated-1_0.html + // There are no response requirements. Data may remain empty. + EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error) + + // mustImpl forces implementations to embed the UnimplementedServer for forward + // compatibility with the interface. + mustImpl() +} + +// Request contains the [http.Request] informational fields +// and parsed Data from the request body (POST) or URL parameters (GET). +// Data can be assumed to be validated according to the applicable +// standard for the specific endpoints. +// +// EXPERIMENTAL: may change until v4 +type Request[T any] struct { + Method string + URL *url.URL + Header http.Header + Form url.Values + PostForm url.Values + Data *T +} + +func (r *Request[_]) path() string { + return r.URL.Path +} + +func newRequest[T any](r *http.Request, data *T) *Request[T] { + return &Request[T]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + PostForm: r.PostForm, + Data: data, + } +} + +// ClientRequest is a Request with a verified client attached to it. +// Methods that receive this argument may assume the client was authenticated, +// or verified to be a public client. +// +// EXPERIMENTAL: may change until v4 +type ClientRequest[T any] struct { + *Request[T] + Client Client +} + +func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] { + return &ClientRequest[T]{ + Request: newRequest[T](r, data), + Client: client, + } +} + +// Response object for most [Server] methods. +// +// EXPERIMENTAL: may change until v4 +type Response struct { + // Header map will be merged with the + // header on the [http.ResponseWriter]. + Header http.Header + + // Data will be JSON marshaled to + // the response body. + // We allow any type, so that implementations + // can extend the standard types as they wish. + // However, each method will recommend which + // (base) type to use as model, in order to + // be compliant with the standards. + Data any +} + +// NewResponse creates a new response for data, +// without custom headers. +func NewResponse(data any) *Response { + return &Response{ + Header: make(http.Header), + Data: data, + } +} + +func (resp *Response) writeOut(w http.ResponseWriter) { + gu.MapMerge(resp.Header, w.Header()) + httphelper.MarshalJSON(w, resp.Data) +} + +// Redirect is a special response type which will +// initiate a [http.StatusFound] redirect. +// The Params field will be encoded and set to the +// URL's RawQuery field before building the URL. +// +// EXPERIMENTAL: may change until v4 +type Redirect struct { + // Header map will be merged with the + // header on the [http.ResponseWriter]. + Header http.Header + + URL string +} + +func NewRedirect(url string) *Redirect { + return &Redirect{ + Header: make(http.Header), + URL: url, + } +} + +func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) { + gu.MapMerge(red.Header, w.Header()) + http.Redirect(w, r, red.URL, http.StatusFound) +} + +type UnimplementedServer struct{} + +// UnimplementedStatusCode is the status code returned for methods +// that are not yet implemented. +// Note that this means methods in the sense of the Go interface, +// and not http methods covered by "501 Not Implemented". +var UnimplementedStatusCode = http.StatusNotFound + +func unimplementedError(r interface{ path() string }) StatusError { + err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path()) + return NewStatusError(err, UnimplementedStatusCode) +} + +func unimplementedGrantError(gt oidc.GrantType) StatusError { + err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt) + return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 +} + +func (UnimplementedServer) mustImpl() {} + +func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if r.Data.RequestParam != "" { + return nil, oidc.ErrRequestNotSupported() + } + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeCode) +} + +func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) +} + +func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeBearer) +} + +func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) +} + +func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) +} + +func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) +} + +func (UnimplementedServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { + return nil, unimplementedError(r) +} + +func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { + return nil, unimplementedError(r) +} diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go new file mode 100644 index 0000000..d71a354 --- /dev/null +++ b/pkg/op/server_http.go @@ -0,0 +1,524 @@ +package op + +import ( + "context" + "log/slog" + "net/http" + "net/url" + + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/go-chi/chi/v5" + "github.com/rs/cors" + "github.com/zitadel/logging" + "github.com/zitadel/schema" +) + +// RegisterServer registers an implementation of Server. +// The resulting handler takes care of routing and request parsing, +// with some basic validation of required fields. +// The routes can be customized with [WithEndpoints]. +// +// EXPERIMENTAL: may change until v4 +func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + + ws := &webServer{ + router: chi.NewRouter(), + server: server, + endpoints: endpoints, + decoder: decoder, + corsOpts: &defaultCORSOptions, + logger: slog.Default(), + } + + for _, option := range options { + option(ws) + } + + ws.createRouter() + ws.handler = ws.router + if ws.corsOpts != nil { + ws.handler = cors.New(*ws.corsOpts).Handler(ws.router) + } + return ws +} + +type ServerOption func(s *webServer) + +// WithHTTPMiddleware sets the passed middleware chain to the root of +// the Server's router. +func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption { + return func(s *webServer) { + s.router.Use(m...) + } +} + +// WithSetRouter allows customization or the Server's router. +func WithSetRouter(set func(chi.Router)) ServerOption { + return func(s *webServer) { + set(s.router) + } +} + +// WithDecoder overrides the default decoder, +// which is a [schema.Decoder] with IgnoreUnknownKeys set to true. +func WithDecoder(decoder httphelper.Decoder) ServerOption { + return func(s *webServer) { + s.decoder = decoder + } +} + +// WithServerCORSOptions sets the CORS policy for the Server's router. +func WithServerCORSOptions(opts *cors.Options) ServerOption { + return func(s *webServer) { + s.corsOpts = opts + } +} + +// WithFallbackLogger overrides the fallback logger, which +// is used when no logger was found in the context. +// Defaults to [slog.Default]. +func WithFallbackLogger(logger *slog.Logger) ServerOption { + return func(s *webServer) { + s.logger = logger + } +} + +type webServer struct { + server Server + router *chi.Mux + handler http.Handler + endpoints Endpoints + decoder httphelper.Decoder + corsOpts *cors.Options + logger *slog.Logger +} + +func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.handler.ServeHTTP(w, r) +} + +func (s *webServer) getLogger(ctx context.Context) *slog.Logger { + if logger, ok := logging.FromContext(ctx); ok { + return logger + } + return s.logger +} + +func (s *webServer) createRouter() { + s.router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health)) + s.router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready)) + s.router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery)) + + s.endpointRoute(s.endpoints.Authorization, s.authorizeHandler) + s.endpointRoute(s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler)) + s.endpointRoute(s.endpoints.Token, s.tokensHandler) + s.endpointRoute(s.endpoints.Introspection, s.introspectionHandler) + s.endpointRoute(s.endpoints.Userinfo, s.userInfoHandler) + s.endpointRoute(s.endpoints.Revocation, s.withClient(s.revocationHandler)) + s.endpointRoute(s.endpoints.EndSession, s.endSessionHandler) + s.endpointRoute(s.endpoints.JwksURI, simpleHandler(s, s.server.Keys)) +} + +func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) { + if e != nil { + traceHandler := func(w http.ResponseWriter, r *http.Request) { + ctx, span := tracer.Start(r.Context(), e.Relative()) + r = r.WithContext(ctx) + hf(w, r) + defer span.End() + } + s.router.HandleFunc(e.Relative(), traceHandler) + s.logger.Info("registered route", "endpoint", e.Relative()) + } +} + +type clientHandler func(w http.ResponseWriter, r *http.Request, client Client) + +func (s *webServer) withClient(handler clientHandler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx, span := tracer.Start(r.Context(), r.URL.Path) + defer span.End() + r = r.WithContext(ctx) + + client, err := s.verifyRequestClient(r) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" { + if !ValidateGrantType(client, grantType) { + WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context())) + return + } + } + handler(w, r, client) + } +} + +func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) { + cc, err := s.parseClientCredentials(r) + if err != nil { + return nil, err + } + return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{ + Method: r.Method, + URL: r.URL, + Header: r.Header, + Form: r.Form, + Data: cc, + }) +} + +func (s *webServer) parseClientCredentials(r *http.Request) (_ *ClientCredentials, err error) { + if err := r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + cc := new(ClientCredentials) + if err = s.decoder.Decode(cc, r.Form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + // Basic auth takes precedence, so if set it overwrites the form data. + if clientID, clientSecret, ok := r.BasicAuth(); ok { + cc.ClientID, err = url.QueryUnescape(clientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + cc.ClientSecret, err = url.QueryUnescape(clientSecret) + if err != nil { + return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err) + } + } + if cc.ClientID == "" && cc.ClientAssertion == "" { + return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided") + } + if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion { + return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType) + } + return cc, nil +} + +func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + redirect, err := s.authorize(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + redirect.writeOut(w, r) +} + +func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) { + cr, err := s.server.VerifyAuthRequest(ctx, r) + if err != nil { + return nil, err + } + authReq := cr.Data + if authReq.RedirectURI == "" { + return nil, ErrAuthReqMissingRedirectURI + } + authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge) + if err != nil { + return nil, err + } + authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes) + if err != nil { + return nil, err + } + if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil { + return nil, err + } + if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil { + return nil, err + } + return s.server.Authorize(ctx, cr) +} + +func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) + return + } + + switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType { + case oidc.GrantTypeCode: + s.withClient(s.codeExchangeHandler)(w, r) + case oidc.GrantTypeRefreshToken: + s.withClient(s.refreshTokenHandler)(w, r) + case oidc.GrantTypeClientCredentials: + s.withClient(s.clientCredentialsHandler)(w, r) + case oidc.GrantTypeBearer: + s.jwtProfileHandler(w, r) + case oidc.GrantTypeTokenExchange: + s.withClient(s.tokenExchangeHandler)(w, r) + case oidc.GrantTypeDeviceCode: + s.withClient(s.deviceTokenHandler)(w, r) + case "": + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context())) + default: + WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context())) + } +} + +func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Assertion == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Code == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context())) + return + } + if request.RedirectURI == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.RefreshToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.SubjectToken == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context())) + return + } + if request.SubjectTokenType == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context())) + return + } + if !request.SubjectTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context())) + return + } + if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context())) + return + } + if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context())) + return + } + resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) { + if client.AuthMethod() == oidc.AuthMethodNone { + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) + return + } + + request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.DeviceCode == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request) { + cc, err := s.parseClientCredentials(r) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if cc.ClientSecret == "" && cc.ClientAssertion == "" { + WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context())) + return + } + request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.Introspect(r.Context(), newRequest(r, &IntrospectionRequest{cc, request})) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if token, err := getAccessToken(r); err == nil { + request.AccessToken = token + } + if request.AccessToken == "" { + err = NewStatusError( + oidc.ErrInvalidRequest().WithDescription("access token missing"), + http.StatusUnauthorized, + ) + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.UserInfo(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) { + request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + if request.Token == "" { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context())) + return + } + resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) +} + +func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) { + request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp, err := s.server.EndSession(r.Context(), newRequest(r, request)) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w, r) +} + +func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context())) + return + } + resp, err := method(r.Context(), newRequest(r, &struct{}{})) + if err != nil { + WriteError(w, r, err, s.getLogger(r.Context())) + return + } + resp.writeOut(w) + } +} + +func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) { + dst := new(R) + if err := r.ParseForm(); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) + } + form := r.Form + if postOnly { + form = r.PostForm + } + if err := decoder.Decode(dst, form); err != nil { + return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err) + } + return dst, nil +} diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go new file mode 100644 index 0000000..02200ee --- /dev/null +++ b/pkg/op/server_http_routes_test.go @@ -0,0 +1,345 @@ +package op_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" +) + +func jwtProfile() (string, error) { + keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json") + if err != nil { + return "", err + } + signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID) + if err != nil { + return "", err + } + return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer) +} + +func TestServerRoutes(t *testing.T) { + server := op.RegisterLegacyServer(op.NewLegacyServer(testProvider, *op.DefaultEndpoints), op.AuthorizeCallbackHandler(testProvider)) + + storage := testProvider.Storage().(routesTestStorage) + ctx := op.ContextWithIssuer(context.Background(), testIssuer) + + client, err := storage.GetClientByClientID(ctx, "web") + require.NoError(t, err) + + oidcAuthReq := &oidc.AuthRequest{ + ClientID: client.GetID(), + RedirectURI: "https://example.com", + MaxAge: gu.Ptr[uint](300), + Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone}, + ResponseType: oidc.ResponseTypeCode, + } + + authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1") + require.NoError(t, err) + storage.AuthRequestDone(authReq.GetID()) + + accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") + require.NoError(t, err) + accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "") + require.NoError(t, err) + idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client) + require.NoError(t, err) + jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "") + require.NoError(t, err) + jwtProfileToken, err := jwtProfile() + require.NoError(t, err) + + oidcAuthReq.IDTokenHint = idToken + + serverURL, err := url.Parse(testIssuer) + require.NoError(t, err) + + type basicAuth struct { + username, password string + } + + tests := []struct { + name string + method string + path string + basicAuth *basicAuth + header map[string]string + values map[string]string + body map[string]string + wantCode int + headerContains map[string]string + json string // test for exact json output + contains []string // when the body output is not constant, we just check for snippets to be present in the response + }{ + { + name: "health", + method: http.MethodGet, + path: "/healthz", + wantCode: http.StatusOK, + json: `{"status":"ok"}`, + }, + { + name: "ready", + method: http.MethodGet, + path: "/ready", + wantCode: http.StatusOK, + json: `{"status":"ok"}`, + }, + { + name: "discovery", + method: http.MethodGet, + path: oidc.DiscoveryEndpoint, + wantCode: http.StatusOK, + json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`, + }, + { + name: "authorization", + method: http.MethodGet, + path: testProvider.AuthorizationEndpoint().Relative(), + values: map[string]string{ + "client_id": client.GetID(), + "redirect_uri": "https://example.com", + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "response_type": string(oidc.ResponseTypeCode), + }, + wantCode: http.StatusFound, + headerContains: map[string]string{"Location": "/login/username?authRequestID="}, + }, + { + // This call will fail. A successfull test is already + // part of client/integration_test.go + name: "code exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeCode), + "client_id": client.GetID(), + "client_secret": "secret", + "redirect_uri": "https://example.com", + "code": "123", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"invalid_grant", "error_description":"invalid code"}`, + }, + { + name: "JWT authorization", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeBearer), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "assertion": jwtProfileToken, + }, + wantCode: http.StatusOK, + contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299,"scope":"openid"}`}, + }, + { + name: "Token exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "grant_type": string(oidc.GrantTypeTokenExchange), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + "subject_token": jwtToken, + "subject_token_type": string(oidc.AccessTokenType), + }, + wantCode: http.StatusOK, + contains: []string{ + `{"access_token":"`, + `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`, + }, + }, + { + name: "Client credentials exchange", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"sid1", "verysecret"}, + values: map[string]string{ + "grant_type": string(oidc.GrantTypeClientCredentials), + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + }, + wantCode: http.StatusOK, + contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299,"scope":"openid offline_access"}`}, + }, + { + // This call will fail. A successful test is already + // part of device_test.go + name: "device token", + method: http.MethodPost, + path: testProvider.TokenEndpoint().Relative(), + basicAuth: &basicAuth{"device", "secret"}, + header: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + body: map[string]string{ + "grant_type": string(oidc.GrantTypeDeviceCode), + "device_code": "123", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"access_denied","error_description":"The authorization request was denied."}`, + }, + { + name: "missing grant type", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + wantCode: http.StatusBadRequest, + json: `{"error":"invalid_request","error_description":"grant_type missing"}`, + }, + { + name: "unsupported grant type", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": "foo", + }, + wantCode: http.StatusBadRequest, + json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`, + }, + { + name: "introspection", + method: http.MethodGet, + path: testProvider.IntrospectionEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "token": accessToken, + }, + wantCode: http.StatusOK, + json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`, + }, + { + name: "user info", + method: http.MethodGet, + path: testProvider.UserinfoEndpoint().Relative(), + header: map[string]string{ + "authorization": "Bearer " + accessToken, + }, + wantCode: http.StatusOK, + json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`, + }, + { + name: "refresh token", + method: http.MethodGet, + path: testProvider.TokenEndpoint().Relative(), + values: map[string]string{ + "grant_type": string(oidc.GrantTypeRefreshToken), + "refresh_token": refreshToken, + "client_id": client.GetID(), + "client_secret": "secret", + }, + wantCode: http.StatusOK, + contains: []string{ + `{"access_token":"`, + `","token_type":"Bearer","refresh_token":"`, + `","expires_in":299,"id_token":"`, + }, + }, + { + name: "revoke", + method: http.MethodGet, + path: testProvider.RevocationEndpoint().Relative(), + basicAuth: &basicAuth{"web", "secret"}, + values: map[string]string{ + "token": accessTokenRevoke, + }, + wantCode: http.StatusOK, + }, + { + name: "end session", + method: http.MethodGet, + path: testProvider.EndSessionEndpoint().Relative(), + values: map[string]string{ + "id_token_hint": idToken, + "client_id": "web", + }, + wantCode: http.StatusFound, + headerContains: map[string]string{"Location": "/logged-out"}, + contains: []string{`Found.`}, + }, + { + name: "keys", + method: http.MethodGet, + path: testProvider.KeysEndpoint().Relative(), + wantCode: http.StatusOK, + contains: []string{ + `{"keys":[{"use":"sig","kty":"RSA","kid":"`, + `","alg":"RS256","n":"`, `","e":"AQAB"}]}`, + }, + }, + { + name: "device authorization", + method: http.MethodGet, + path: testProvider.DeviceAuthorizationEndpoint().Relative(), + basicAuth: &basicAuth{"device", "secret"}, + values: map[string]string{ + "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(), + }, + wantCode: http.StatusOK, + contains: []string{ + `{"device_code":"`, `","user_code":"`, + `","verification_uri":"https://localhost:9998/device"`, + `"verification_uri_complete":"https://localhost:9998/device?user_code=`, + `","expires_in":300,"interval":5}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := gu.PtrCopy(serverURL) + u.Path = tt.path + if tt.values != nil { + u.RawQuery = mapAsValues(tt.values) + } + var body io.Reader + if tt.body != nil { + body = strings.NewReader(mapAsValues(tt.body)) + } + + req := httptest.NewRequest(tt.method, u.String(), body) + for k, v := range tt.header { + req.Header.Set(k, v) + } + if tt.basicAuth != nil { + req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password) + } + + rec := httptest.NewRecorder() + server.ServeHTTP(rec, req) + + resp := rec.Result() + require.NoError(t, err) + assert.Equal(t, tt.wantCode, resp.StatusCode) + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + respBodyString := string(respBody) + t.Log(respBodyString) + t.Log(resp.Header) + + if tt.json != "" { + assert.JSONEq(t, tt.json, respBodyString) + } + for _, c := range tt.contains { + assert.Contains(t, respBodyString, c) + } + for k, v := range tt.headerContains { + assert.Contains(t, resp.Header.Get(k), v) + } + }) + } +} diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go new file mode 100644 index 0000000..75d02ca --- /dev/null +++ b/pkg/op/server_http_test.go @@ -0,0 +1,1328 @@ +package op + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/schema" +) + +func TestRegisterServer(t *testing.T) { + server := UnimplementedServer{} + endpoints := Endpoints{ + Authorization: &Endpoint{ + path: "/auth", + }, + } + decoder := schema.NewDecoder() + logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + + h := RegisterServer(server, endpoints, + WithDecoder(decoder), + WithFallbackLogger(logger), + ) + got := h.(*webServer) + assert.Equal(t, got.server, server) + assert.Equal(t, got.endpoints, endpoints) + assert.Equal(t, got.decoder, decoder) + assert.Equal(t, got.logger, logger) +} + +type testClient struct { + id string + appType ApplicationType + authMethod oidc.AuthMethod + accessTokenType AccessTokenType + responseTypes []oidc.ResponseType + grantTypes []oidc.GrantType + devMode bool +} + +type clientType string + +const ( + clientTypeWeb clientType = "web" + clientTypeNative clientType = "native" + clientTypeUserAgent clientType = "useragent" +) + +func newClient(kind clientType) *testClient { + client := &testClient{ + id: string(kind), + } + + switch kind { + case clientTypeWeb: + client.appType = ApplicationTypeWeb + client.authMethod = oidc.AuthMethodBasic + client.accessTokenType = AccessTokenTypeBearer + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} + case clientTypeNative: + client.appType = ApplicationTypeNative + client.authMethod = oidc.AuthMethodNone + client.accessTokenType = AccessTokenTypeBearer + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode} + case clientTypeUserAgent: + client.appType = ApplicationTypeUserAgent + client.authMethod = oidc.AuthMethodBasic + client.accessTokenType = AccessTokenTypeJWT + client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken} + default: + panic(fmt.Errorf("invalid client type %s", kind)) + } + return client +} + +func (c *testClient) RedirectURIs() []string { + return []string{ + "https://registered.com/callback", + "http://registered.com/callback", + "http://localhost:9999/callback", + "custom://callback", + } +} + +func (c *testClient) PostLogoutRedirectURIs() []string { + return []string{} +} + +func (c *testClient) LoginURL(id string) string { + return "login?id=" + id +} + +func (c *testClient) ApplicationType() ApplicationType { + return c.appType +} + +func (c *testClient) AuthMethod() oidc.AuthMethod { + return c.authMethod +} + +func (c *testClient) GetID() string { + return c.id +} + +func (c *testClient) AccessTokenLifetime() time.Duration { + return 5 * time.Minute +} + +func (c *testClient) IDTokenLifetime() time.Duration { + return 5 * time.Minute +} + +func (c *testClient) AccessTokenType() AccessTokenType { + return c.accessTokenType +} + +func (c *testClient) ResponseTypes() []oidc.ResponseType { + return c.responseTypes +} + +func (c *testClient) GrantTypes() []oidc.GrantType { + return c.grantTypes +} + +func (c *testClient) DevMode() bool { + return c.devMode +} + +func (c *testClient) AllowedScopes() []string { + return nil +} + +func (c *testClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +func (c *testClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +func (c *testClient) IsScopeAllowed(scope string) bool { + return false +} + +func (c *testClient) IDTokenUserinfoClaimsAssertion() bool { + return false +} + +func (c *testClient) ClockSkew() time.Duration { + return 0 +} + +type requestVerifier struct { + UnimplementedServer + client Client +} + +func (s *requestVerifier) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + if s.client == nil { + return nil, oidc.ErrServerError() + } + return &ClientRequest[oidc.AuthRequest]{ + Request: r, + Client: s.client, + }, nil +} + +func (s *requestVerifier) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + if s.client == nil { + return nil, oidc.ErrServerError() + } + return s.client, nil +} + +var testDecoder = func() *schema.Decoder { + decoder := schema.NewDecoder() + decoder.IgnoreUnknownKeys(true) + return decoder +}() + +type webServerResult struct { + wantStatus int + wantBody string +} + +func runWebServerTest(t *testing.T, handler http.HandlerFunc, r *http.Request, want webServerResult) { + t.Helper() + if r.Method == http.MethodPost { + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + w := httptest.NewRecorder() + handler(w, r) + res := w.Result() + assert.Equal(t, want.wantStatus, res.StatusCode) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.JSONEq(t, want.wantBody, string(body)) +} + +func Test_webServer_withClient(t *testing.T) { + tests := []struct { + name string + r *http.Request + want webServerResult + }{ + { + name: "parse error", + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "invalid grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&grant_type=bad&foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unauthorized_client", "error_description":"grant_type \"bad\" not allowed"}`, + }, + }, + { + name: "no grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&foo=bar")), + want: webServerResult{ + wantStatus: http.StatusOK, + wantBody: `{"foo":"bar"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: &requestVerifier{ + client: newClient(clientTypeNative), + }, + decoder: testDecoder, + logger: slog.Default(), + } + handler := func(w http.ResponseWriter, r *http.Request, client Client) { + fmt.Fprintf(w, `{"foo":%q}`, r.FormValue("foo")) + } + runWebServerTest(t, s.withClient(handler), tt.r, tt.want) + }) + } +} + +func Test_webServer_verifyRequestClient(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want Client + wantErr error + }{ + { + name: "parse form error", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), + }, + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), + }, + { + name: "basic auth, client_id error", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) + r.SetBasicAuth(`%%%`, "secret") + return r + }(), + wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"), + }, + { + name: "basic auth, client_secret error", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) + r.SetBasicAuth("web", `%%%`) + return r + }(), + wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"), + }, + { + name: "missing client id and assertion", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + wantErr: oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided"), + }, + { + name: "wrong assertion type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_assertion=xxx&client_assertion_type=wrong")), + wantErr: oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type wrong"), + }, + { + name: "unimplemented verify client called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_id=web")), + wantErr: StatusError{ + parent: oidc.ErrServerError().WithDescription("/ not implemented on this server"), + statusCode: UnimplementedStatusCode, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + tt.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + got, err := s.verifyRequestClient(tt.r) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webServer_authorizeHandler(t *testing.T) { + type fields struct { + server Server + decoder httphelper.Decoder + } + tests := []struct { + name string + fields fields + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + fields: fields{ + server: &requestVerifier{}, + decoder: schema.NewDecoder(), + }, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "server error", + fields: fields{ + server: &requestVerifier{}, + decoder: testDecoder, + }, + r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusInternalServerError, + wantBody: `{"error":"server_error"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.fields.server, + decoder: tt.fields.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.authorizeHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_authorize(t *testing.T) { + type args struct { + ctx context.Context + r *Request[oidc.AuthRequest] + } + tests := []struct { + name string + server Server + args args + want *Redirect + wantErr error + }{ + { + name: "verify error", + server: &requestVerifier{}, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + }, + }, + }, + wantErr: oidc.ErrServerError(), + }, + { + name: "missing redirect", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + MaxAge: gu.Ptr[uint](300), + }, + }, + }, + wantErr: ErrAuthReqMissingRedirectURI, + }, + { + name: "invalid prompt", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone, oidc.PromptLogin}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value"), + }, + { + name: "missing scopes", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequest(). + WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "invalid redirect", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://example.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrInvalidRequestRedirectURI(). + WithDescription("The requested redirect_uri is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "invalid response type", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeIDToken, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " + + "If you have any questions, you may contact the administrator of the application."), + }, + { + name: "unimplemented Authorize called", + server: &requestVerifier{ + client: newClient(clientTypeWeb), + }, + args: args{ + ctx: context.Background(), + r: &Request[oidc.AuthRequest]{ + URL: &url.URL{ + Path: "/authorize", + }, + Data: &oidc.AuthRequest{ + Scopes: oidc.SpaceDelimitedArray{"openid"}, + ResponseType: oidc.ResponseTypeCode, + ClientID: "web", + RedirectURI: "https://registered.com/callback", + MaxAge: gu.Ptr[uint](300), + Prompt: []string{oidc.PromptNone}, + }, + }, + }, + wantErr: StatusError{ + parent: oidc.ErrServerError().WithDescription("/authorize not implemented on this server"), + statusCode: UnimplementedStatusCode, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.server, + decoder: testDecoder, + logger: slog.Default(), + } + got, err := s.authorize(tt.args.ctx, tt.args.r) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_webServer_deviceAuthorizationHandler(t *testing.T) { + type fields struct { + server Server + decoder httphelper.Decoder + } + tests := []struct { + name string + fields fields + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + fields: fields{ + server: &requestVerifier{}, + decoder: schema.NewDecoder(), + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented DeviceAuthorization called", + fields: fields{ + server: &requestVerifier{ + client: newClient(clientTypeNative), + }, + decoder: testDecoder, + }, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native_client")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: tt.fields.server, + decoder: tt.fields.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceAuthorizationHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokensHandler(t *testing.T) { + tests := []struct { + name string + r *http.Request + want webServerResult + }{ + { + name: "parse form error", + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "missing grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"grant_type missing"}`, + }, + }, + { + name: "invalid grant type", + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("grant_type=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"bar not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + logger: slog.Default(), + } + runWebServerTest(t, s.tokensHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_jwtProfileHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "assertion missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"assertion missing"}`, + }, + }, + { + name: "unimplemented JWTProfile called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("assertion=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:jwt-bearer not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.jwtProfileHandler, tt.r, tt.want) + }) + } +} + +func runWebServerClientTest(t *testing.T, handler func(http.ResponseWriter, *http.Request, Client), r *http.Request, client Client, want webServerResult) { + t.Helper() + runWebServerTest(t, func(client Client) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + handler(w, r, client) + } + }(client), r, want) +} + +func Test_webServer_codeExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"code missing"}`, + }, + }, + { + name: "redirect missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"redirect_uri missing"}`, + }, + }, + { + name: "unimplemented CodeExchange called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123&redirect_uri=https://example.com/callback")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"authorization_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.codeExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_refreshTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "refresh token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"refresh_token missing"}`, + }, + }, + { + name: "unimplemented RefreshToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("refresh_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"refresh_token not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.refreshTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_tokenExchangeHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "subject token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token missing"}`, + }, + }, + { + name: "subject token type missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type missing"}`, + }, + }, + { + name: "subject token type unsupported", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"subject_token_type is not supported"}`, + }, + }, + { + name: "unsupported requested token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"requested_token_type is not supported"}`, + }, + }, + { + name: "unsupported actor token type", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=foo")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"actor_token_type is not supported"}`, + }, + }, + { + name: "unimplemented TokenExchange called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=urn:ietf:params:oauth:token-type:access_token")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:token-exchange not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.tokenExchangeHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_clientCredentialsHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "unimplemented ClientCredentialsExchange called", + decoder: testDecoder, + client: newClient(clientTypeUserAgent), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"client_credentials not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.clientCredentialsHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_deviceTokenHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "device code missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"device_code missing"}`, + }, + }, + { + name: "unimplemented DeviceToken called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("device_code=xxx")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:device_code not supported"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + client := newClient(clientTypeUserAgent) + runWebServerClientTest(t, s.deviceTokenHandler, tt.r, client, tt.want) + }) + } +} + +func Test_webServer_introspectionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "public client", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Introspect called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET&token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.introspectionHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_userInfoHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "access token missing", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusUnauthorized, + wantBody: `{"error":"invalid_request", "error_description":"access token missing"}`, + }, + }, + { + name: "unimplemented UserInfo called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("access_token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "bearer", + decoder: testDecoder, + r: func() *http.Request { + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("authorization", strings.Join([]string{"Bearer", "xxx"}, " ")) + return r + }(), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.userInfoHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_revocationHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + client Client + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "token missing", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"token missing"}`, + }, + }, + { + name: "unimplemented Revocation called, confidential client", + decoder: testDecoder, + client: newClient(clientTypeWeb), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + { + name: "unimplemented Revocation called, public client", + decoder: testDecoder, + client: newClient(clientTypeNative), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerClientTest(t, s.revocationHandler, tt.r, tt.client, tt.want) + }) + } +} + +func Test_webServer_endSessionHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + r *http.Request + want webServerResult + }{ + { + name: "decoder error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`, + }, + }, + { + name: "unimplemented EndSession called", + decoder: testDecoder, + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("id_token_hint=xxx")), + want: webServerResult{ + wantStatus: UnimplementedStatusCode, + wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, s.endSessionHandler, tt.r, tt.want) + }) + } +} + +func Test_webServer_simpleHandler(t *testing.T) { + tests := []struct { + name string + decoder httphelper.Decoder + method func(context.Context, *Request[struct{}]) (*Response, error) + r *http.Request + want webServerResult + }{ + { + name: "parse error", + decoder: schema.NewDecoder(), + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusBadRequest, + wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`, + }, + }, + { + name: "method error", + decoder: schema.NewDecoder(), + method: func(ctx context.Context, r *Request[struct{}]) (*Response, error) { + return nil, io.ErrClosedPipe + }, + r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))), + want: webServerResult{ + wantStatus: http.StatusInternalServerError, + wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &webServer{ + server: UnimplementedServer{}, + decoder: tt.decoder, + logger: slog.Default(), + } + runWebServerTest(t, simpleHandler(s, tt.method), tt.r, tt.want) + }) + } +} + +func Test_decodeRequest(t *testing.T) { + type dst struct { + A string `schema:"a"` + B string `schema:"b"` + } + type args struct { + r *http.Request + postOnly bool + } + tests := []struct { + name string + args args + want *dst + wantErr error + }{ + { + name: "parse error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"), + }, + { + name: "decode error", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")), + }, + wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"), + }, + { + name: "success, get", + args: args{ + r: httptest.NewRequest(http.MethodGet, "/?a=b&b=a", nil), + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + { + name: "success, post only", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: true, + }, + want: &dst{ + A: "b", + }, + }, + { + name: "success, post mixed", + args: args{ + r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")), + postOnly: false, + }, + want: &dst{ + A: "b", + B: "a", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.r.Method == http.MethodPost { + tt.args.r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + got, err := decodeRequest[dst](schema.NewDecoder(), tt.args.r, tt.args.postOnly) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go new file mode 100644 index 0000000..06e4e93 --- /dev/null +++ b/pkg/op/server_legacy.go @@ -0,0 +1,457 @@ +package op + +import ( + "context" + "errors" + "net/http" + "time" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "github.com/go-chi/chi/v5" +) + +// ExtendedLegacyServer allows embedding [LegacyServer] in a struct, +// so that its methods can be individually overridden. +// +// EXPERIMENTAL: may change until v4 +type ExtendedLegacyServer interface { + Server + Provider() OpenIDProvider + Endpoints() Endpoints + AuthCallbackURL() func(context.Context, string) string +} + +// RegisterLegacyServer registers a [LegacyServer] or an extension thereof. +// It takes care of registering the IssuerFromRequest middleware. +// The authorizeCallbackHandler is registered on `/callback` under the authorization endpoint. +// Neither are part of the bare [Server] interface. +// +// EXPERIMENTAL: may change until v4 +func RegisterLegacyServer(s ExtendedLegacyServer, authorizeCallbackHandler http.HandlerFunc, options ...ServerOption) http.Handler { + options = append(options, + WithHTTPMiddleware(intercept(s.Provider().IssuerFromRequest)), + WithSetRouter(func(r chi.Router) { + r.HandleFunc(s.Endpoints().Authorization.Relative()+authCallbackPathSuffix, authorizeCallbackHandler) + }), + ) + return RegisterServer(s, s.Endpoints(), options...) +} + +// LegacyServer is an implementation of [Server] that +// simply wraps an [OpenIDProvider]. +// It can be used to transition from the former Provider/Storage +// interfaces to the new Server interface. +// +// EXPERIMENTAL: may change until v4 +type LegacyServer struct { + UnimplementedServer + provider OpenIDProvider + endpoints Endpoints +} + +// NewLegacyServer wraps provider in a `Server` implementation +// +// Only non-nil endpoints will be registered on the router. +// Nil endpoints are disabled. +// +// The passed endpoints is also used for the discovery config, +// and endpoints already set to the provider are ignored. +// Any `With*Endpoint()` option used on the provider is +// therefore ineffective. +// +// EXPERIMENTAL: may change until v4 +func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) *LegacyServer { + return &LegacyServer{ + provider: provider, + endpoints: endpoints, + } +} + +func (s *LegacyServer) Provider() OpenIDProvider { + return s.provider +} + +func (s *LegacyServer) Endpoints() Endpoints { + return s.endpoints +} + +// AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login +func (s *LegacyServer) AuthCallbackURL() func(context.Context, string) string { + return func(ctx context.Context, requestID string) string { + ctx, span := tracer.Start(ctx, "LegacyServer.AuthCallbackURL") + defer span.End() + + return s.endpoints.Authorization.Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID + } +} + +func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) { + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) { + for _, probe := range s.provider.Probes() { + // shouldn't we run probes in Go routines? + if err := probe(ctx); err != nil { + return nil, AsStatusError(err, http.StatusInternalServerError) + } + } + return NewResponse(Status{Status: "ok"}), nil +} + +func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Discovery") + defer span.End() + + return NewResponse( + createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints), + ), nil +} + +func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Keys") + defer span.End() + + keys, err := s.provider.Storage().KeySet(ctx) + if err != nil { + return nil, AsStatusError(err, http.StatusInternalServerError) + } + return NewResponse(jsonWebKeySet(keys)), nil +} + +var ( + ErrAuthReqMissingClientID = errors.New("auth request is missing client_id") + ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri") +) + +func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) { + ctx, span := tracer.Start(ctx, "LegacyServer.VerifyAuthRequest") + defer span.End() + + if r.Data.RequestParam != "" { + if !s.provider.RequestObjectSupported() { + return nil, oidc.ErrRequestNotSupported() + } + err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx)) + if err != nil { + return nil, err + } + } + if r.Data.ClientID == "" { + return nil, oidc.ErrInvalidRequest().WithParent(ErrAuthReqMissingClientID).WithDescription(ErrAuthReqMissingClientID.Error()) + } + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id") + } + + return &ClientRequest[oidc.AuthRequest]{ + Request: r, + Client: client, + }, nil +} + +func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Authorize") + defer span.End() + + userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx)) + if err != nil { + return nil, err + } + req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID) + if err != nil { + return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger()) + } + return NewRedirect(r.Client.LoginURL(req.GetID())), nil +} + +func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.DeviceAuthorization") + defer span.End() + + response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider) + if err != nil { + return nil, AsStatusError(err, http.StatusInternalServerError) + } + return NewResponse(response), nil +} + +func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.VerifyClient") + defer span.End() + + if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials { + storage, ok := s.provider.Storage().(ClientCredentialsStorage) + if !ok { + return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported") + } + return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret) + } + + if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { + jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported") + } + return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger) + } + client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID) + if err != nil { + return nil, oidc.ErrInvalidClient().WithParent(err) + } + + switch client.AuthMethod() { + case oidc.AuthMethodNone: + return client, nil + case oidc.AuthMethodPrivateKeyJWT: + return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client") + case oidc.AuthMethodPost: + if !s.provider.AuthMethodPostSupported() { + return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported") + } + } + + err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage()) + if err != nil { + return nil, err + } + + return client, nil +} + +func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.CodeExchange") + defer span.End() + + authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code) + if err != nil { + return nil, err + } + if r.Client.AuthMethod() == oidc.AuthMethodNone || r.Data.CodeVerifier != "" { + if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil { + return nil, err + } + } + if r.Data.RedirectURI != authReq.GetRedirectURI() { + return nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond") + } + resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "") + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.RefreshToken") + defer span.End() + + if !s.provider.GrantTypeRefreshTokenSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken) + } + request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken) + if err != nil { + return nil, err + } + if r.Client.GetID() != request.GetClientID() { + return nil, oidc.ErrInvalidGrant() + } + if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil { + return nil, err + } + resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.JWTProfile") + defer span.End() + + exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeBearer) + } + tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx)) + if err != nil { + return nil, oidc.ErrInvalidRequest().WithParent(err).WithDescription("assertion invalid") + } + + tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.TokenExchange") + defer span.End() + + if !s.provider.GrantTypeTokenExchangeSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider) + if err != nil { + return nil, err + } + resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.ClientCredentialsExchange") + defer span.End() + + storage, ok := s.provider.Storage().(ClientCredentialsStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials) + } + tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope) + if err != nil { + return nil, err + } + resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.DeviceToken") + defer span.End() + + if !s.provider.GrantTypeDeviceCodeSupported() { + return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode) + } + // use a limited context timeout shorter as the default + // poll interval of 5 seconds. + ctx, cancel := context.WithTimeout(ctx, 4*time.Second) + defer cancel() + + tokenRequest, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider) + if err != nil { + return nil, err + } + resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client) + if err != nil { + return nil, err + } + return NewResponse(resp), nil +} + +func (s *LegacyServer) authenticateResourceClient(ctx context.Context, cc *ClientCredentials) (string, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.authenticateResourceClient") + defer span.End() + + if cc.ClientAssertion != "" { + if jp, ok := s.provider.(ClientJWTProfile); ok { + return ClientJWTAuth(ctx, oidc.ClientAssertionParams{ClientAssertion: cc.ClientAssertion}, jp) + } + return "", oidc.ErrInvalidClient().WithDescription("client_assertion not supported") + } + if err := s.provider.Storage().AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil { + return "", oidc.ErrUnauthorizedClient().WithParent(err) + } + return cc.ClientID, nil +} + +func (s *LegacyServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Introspect") + defer span.End() + + clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials) + if err != nil { + return nil, err + } + response := new(oidc.IntrospectionResponse) + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token) + if !ok { + return NewResponse(response), nil + } + err = s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID) + if err != nil { + return NewResponse(response), nil + } + response.Active = true + return NewResponse(response), nil +} + +func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.UserInfo") + defer span.End() + + tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken) + if !ok { + return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized) + } + info := new(oidc.UserInfo) + err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin")) + if err != nil { + return nil, NewStatusError(err, http.StatusForbidden) + } + return NewResponse(info), nil +} + +func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.Revocation") + defer span.End() + + var subject string + doDecrypt := true + if r.Data.TokenTypeHint != "access_token" { + userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token) + if err != nil { + // An invalid refresh token means that we'll try other things (leaving doDecrypt==true) + if !errors.Is(err, ErrInvalidRefreshToken) { + return nil, RevocationError(oidc.ErrServerError().WithParent(err)) + } + } else { + r.Data.Token = tokenID + subject = userID + doDecrypt = false + } + } + if doDecrypt { + tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token) + if ok { + r.Data.Token = tokenID + subject = userID + } + } + if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil { + return nil, RevocationError(err) + } + return NewResponse(nil), nil +} + +func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) { + ctx, span := tracer.Start(ctx, "LegacyServer.EndSession") + defer span.End() + + session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider) + if err != nil { + return nil, err + } + redirect := session.RedirectURI + if fromRequest, ok := s.provider.Storage().(CanTerminateSessionFromRequest); ok { + redirect, err = fromRequest.TerminateSessionFromRequest(ctx, session) + } else { + err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID) + } + if err != nil { + return nil, err + } + return NewRedirect(redirect), nil +} diff --git a/pkg/op/server_test.go b/pkg/op/server_test.go new file mode 100644 index 0000000..0cad8fd --- /dev/null +++ b/pkg/op/server_test.go @@ -0,0 +1,5 @@ +package op + +// implementation check +var _ Server = &UnimplementedServer{} +var _ Server = &LegacyServer{} diff --git a/pkg/op/session.go b/pkg/op/session.go index c4f76f3..ac663c9 100644 --- a/pkg/op/session.go +++ b/pkg/op/session.go @@ -2,19 +2,22 @@ package op import ( "context" + "errors" + "log/slog" "net/http" "net/url" "path" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type SessionEnder interface { Decoder() httphelper.Decoder Storage() Storage - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier DefaultLogoutRedirectURI() string + Logger() *slog.Logger } func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) { @@ -24,6 +27,10 @@ func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Reque } func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { + ctx, span := tracer.Start(r.Context(), "EndSession") + defer span.End() + r = r.WithContext(ctx) + req, err := ParseEndSessionRequest(r, ender.Decoder()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -31,15 +38,20 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) { } session, err := ValidateEndSessionRequest(r.Context(), req, ender) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, ender.Logger()) return } - err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID) + redirect := session.RedirectURI + if fromRequest, ok := ender.Storage().(CanTerminateSessionFromRequest); ok { + redirect, err = fromRequest.TerminateSessionFromRequest(r.Context(), session) + } else { + err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID) + } if err != nil { - RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session")) + RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger()) return } - http.Redirect(w, r, session.RedirectURI, http.StatusFound) + http.Redirect(w, r, redirect, http.StatusFound) } func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.EndSessionRequest, error) { @@ -56,15 +68,21 @@ func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc. } func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) { + ctx, span := tracer.Start(ctx, "ValidateEndSessionRequest") + defer span.End() + session := &EndSessionRequest{ RedirectURI: ender.DefaultLogoutRedirectURI(), + LogoutHint: req.LogoutHint, + UILocales: req.UILocales, } if req.IdTokenHint != "" { - claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx)) - if err != nil { + claims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx)) + if err != nil && !errors.As(err, &IDTokenHintExpiredError{}) { return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err) } session.UserID = claims.GetSubject() + session.IDTokenHintClaims = claims if req.ClientID != "" && req.ClientID != claims.GetAuthorizedParty() { return nil, oidc.ErrInvalidRequest().WithDescription("client_id does not match azp of id_token_hint") } diff --git a/pkg/op/signer.go b/pkg/op/signer.go index 22ef8ca..5c3dd6a 100644 --- a/pkg/op/signer.go +++ b/pkg/op/signer.go @@ -3,16 +3,14 @@ package op import ( "errors" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" ) -var ( - ErrSignerCreationFailed = errors.New("signer creation failed") -) +var ErrSignerCreationFailed = errors.New("signer creation failed") type SigningKey interface { SignatureAlgorithm() jose.SignatureAlgorithm - Key() interface{} + Key() any ID() string } @@ -23,9 +21,9 @@ func SignerFromKey(key SigningKey) (jose.Signer, error) { Key: key.Key(), KeyID: key.ID(), }, - }, &jose.SignerOptions{}) + }, (&jose.SignerOptions{}).WithType("JWT")) if err != nil { - return nil, ErrSignerCreationFailed //TODO: log / wrap error? + return nil, ErrSignerCreationFailed // TODO: log / wrap error? } return signer, nil } @@ -34,5 +32,5 @@ type Key interface { ID() string Algorithm() jose.SignatureAlgorithm Use() string - Key() interface{} + Key() any } diff --git a/pkg/op/storage.go b/pkg/op/storage.go index 590c4a0..2dbd124 100644 --- a/pkg/op/storage.go +++ b/pkg/op/storage.go @@ -5,9 +5,10 @@ import ( "errors" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" + "golang.org/x/text/language" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type AuthStorage interface { @@ -62,6 +63,14 @@ type AuthStorage interface { KeySet(context.Context) ([]Key, error) } +// CanTerminateSessionFromRequest is an optional additional interface that may be implemented by +// implementors of Storage as an alternative to TerminateSession of the AuthStorage. +// It passes the complete parsed EndSessionRequest to the implementation, which allows access to additional data. +// It also allows to modify the uri, which will be used for redirection, (e.g. a UI where the user can consent to the logout) +type CanTerminateSessionFromRequest interface { + TerminateSessionFromRequest(ctx context.Context, endSessionRequest *EndSessionRequest) (string, error) +} + type ClientCredentialsStorage interface { ClientCredentials(ctx context.Context, clientID, clientSecret string) (Client, error) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (TokenRequest, error) @@ -92,7 +101,7 @@ type TokenExchangeStorage interface { // GetPrivateClaimsFromTokenExchangeRequest will be called during access token creation. // Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc. - GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) (claims map[string]interface{}, err error) + GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) (claims map[string]any, err error) // SetUserinfoFromTokenExchangeRequest will be called during id token creation. // Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc. @@ -102,8 +111,8 @@ type TokenExchangeStorage interface { // TokenExchangeTokensVerifierStorage is an optional interface used in token exchange process to verify tokens // issued by third-party applications. If interface is not implemented - only tokens issued by op will be exchanged. type TokenExchangeTokensVerifierStorage interface { - VerifyExchangeSubjectToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, subject string, tokenClaims map[string]interface{}, err error) - VerifyExchangeActorToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, actor string, tokenClaims map[string]interface{}, err error) + VerifyExchangeSubjectToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, subject string, tokenClaims map[string]any, err error) + VerifyExchangeActorToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, actor string, tokenClaims map[string]any, err error) } var ErrInvalidRefreshToken = errors.New("invalid_refresh_token") @@ -118,7 +127,7 @@ type OPStorage interface { SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error - GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]interface{}, error) + GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]any, error) GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) } @@ -136,6 +145,12 @@ type CanSetUserinfoFromRequest interface { SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request IDTokenRequest, scopes []string) error } +// CanGetPrivateClaimsFromRequest is an optional additional interface that may be implemented by +// implementors of Storage. It allows setting the jwt token claims based on the request. +type CanGetPrivateClaimsFromRequest interface { + GetPrivateClaimsFromRequest(ctx context.Context, request TokenRequest, restrictedScopes []string) (map[string]any, error) +} + // Storage is a required parameter for NewOpenIDProvider(). In addition to the // embedded interfaces below, if the passed Storage implements ClientCredentialsStorage // then the grant type "client_credentials" will be supported. In that case, the access @@ -152,22 +167,16 @@ type StorageNotFoundError interface { } type EndSessionRequest struct { - UserID string - ClientID string - RedirectURI string + UserID string + ClientID string + IDTokenHintClaims *oidc.IDTokenClaims + RedirectURI string + LogoutHint string + UILocales []language.Tag } var ErrDuplicateUserCode = errors.New("user code already exists") -type DeviceAuthorizationState struct { - ClientID string - Scopes []string - Expires time.Time - Done bool - Subject string - Denied bool -} - type DeviceAuthorizationStorage interface { // StoreDeviceAuthorizationRequest stores a new device authorization request in the database. // User code will be used by the user to complete the login flow and must be unique. @@ -182,18 +191,6 @@ type DeviceAuthorizationStorage interface { // GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database. // The method is polled untill the the authorization is eighter Completed, Expired or Denied. GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error) - - // GetDeviceAuthorizationByUserCode resturn the current state of the device authorization flow, - // identified by the user code. - GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*DeviceAuthorizationState, error) - - // CompleteDeviceAuthorization marks a device authorization entry as Completed, - // identified by userCode. The Subject is added to the state, so that - // GetDeviceAuthorizatonState can use it to create a new Access Token. - CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error - - // DenyDeviceAuthorization marks a device authorization entry as Denied. - DenyDeviceAuthorization(ctx context.Context, userCode string) error } func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) { diff --git a/pkg/op/token.go b/pkg/op/token.go index 6dfc993..2e25d05 100644 --- a/pkg/op/token.go +++ b/pkg/op/token.go @@ -2,11 +2,11 @@ package op import ( "context" + "slices" "time" - "github.com/zitadel/oidc/v2/pkg/crypto" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/strings" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type TokenCreator interface { @@ -28,6 +28,9 @@ type AccessTokenClient interface { } func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Client, creator TokenCreator, createAccessToken bool, code, refreshToken string) (*oidc.AccessTokenResponse, error) { + ctx, span := tracer.Start(ctx, "CreateTokenResponse") + defer span.End() + var accessToken, newRefreshToken string var validity time.Duration if createAccessToken { @@ -48,7 +51,10 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli if err != nil { return nil, err } - state = authRequest.GetState() + // only implicit flow requires state to be returned. + if code == "" { + state = authRequest.GetState() + } } exp := uint64(validity.Seconds()) @@ -59,10 +65,14 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli TokenType: oidc.BearerToken, ExpiresIn: exp, State: state, + Scope: request.GetScopes(), }, nil } func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, err error) { + ctx, span := tracer.Start(ctx, "createTokens") + defer span.End() + if needsRefreshToken(tokenRequest, client) { return storage.CreateAccessAndRefreshTokens(ctx, tokenRequest, refreshToken) } @@ -73,17 +83,22 @@ func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storag func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool { switch req := tokenRequest.(type) { case AuthRequest: - return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken) + return slices.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken) case TokenExchangeRequest: return req.GetRequestedTokenType() == oidc.RefreshTokenType case RefreshTokenRequest: return true + case *DeviceAuthorizationState: + return slices.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && ValidateGrantType(client, oidc.GrantTypeRefreshToken) default: return false } } func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client AccessTokenClient, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) { + ctx, span := tracer.Start(ctx, "CreateAccessToken") + defer span.End() + id, newRefreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client) if err != nil { return "", "", 0, err @@ -97,7 +112,9 @@ func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTok accessToken, err = CreateJWT(ctx, IssuerFromContext(ctx), tokenRequest, exp, id, client, creator.Storage()) return } + _, span = tracer.Start(ctx, "CreateBearerToken") accessToken, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto()) + span.End() return } @@ -105,13 +122,20 @@ func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) { return crypto.Encrypt(tokenID + ":" + subject) } +type TokenActorRequest interface { + GetActor() *oidc.ActorClaims +} + func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, client AccessTokenClient, storage Storage) (string, error) { + ctx, span := tracer.Start(ctx, "CreateJWT") + defer span.End() + claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id, client.GetID(), client.ClockSkew()) if client != nil { restrictedScopes := client.RestrictAdditionalAccessTokenScopes()(tokenRequest.GetScopes()) var ( - privateClaims map[string]interface{} + privateClaims map[string]any err error ) @@ -123,7 +147,11 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex tokenExchangeRequest, ) } else { - privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes)) + if fromRequest, ok := storage.(CanGetPrivateClaimsFromRequest); ok { + privateClaims, err = fromRequest.GetPrivateClaimsFromRequest(ctx, tokenRequest, removeUserinfoScopes(restrictedScopes)) + } else { + privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes)) + } } if err != nil { @@ -131,6 +159,9 @@ func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, ex } claims.Claims = privateClaims } + if actorReq, ok := tokenRequest.(TokenActorRequest); ok { + claims.Actor = actorReq.GetActor() + } signingKey, err := storage.SigningKey(ctx) if err != nil { return "", err @@ -152,6 +183,9 @@ type IDTokenRequest interface { } func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, client Client) (string, error) { + ctx, span := tracer.Start(ctx, "CreateIDToken") + defer span.End() + exp := time.Now().UTC().Add(client.ClockSkew()).Add(validity) var acr, nonce string if authRequest, ok := request.(AuthRequest); ok { @@ -159,6 +193,10 @@ func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, v nonce = authRequest.GetNonce() } claims := oidc.NewIDTokenClaims(issuer, request.GetSubject(), request.GetAudience(), exp, request.GetAuthTime(), nonce, acr, request.GetAMR(), request.GetClientID(), client.ClockSkew()) + if actorReq, ok := request.(TokenActorRequest); ok { + claims.Actor = actorReq.GetActor() + } + scopes := client.RestrictAdditionalIdTokenScopes()(request.GetScopes()) signingKey, err := storage.SigningKey(ctx) if err != nil { diff --git a/pkg/op/token_client_credentials.go b/pkg/op/token_client_credentials.go index fc31d57..ddb2fbf 100644 --- a/pkg/op/token_client_credentials.go +++ b/pkg/op/token_client_credentials.go @@ -5,27 +5,31 @@ import ( "net/http" "net/url" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) // ClientCredentialsExchange handles the OAuth 2.0 client_credentials grant, including // parsing, validating, authorizing the client and finally returning a token func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + ctx, span := tracer.Start(r.Context(), "ClientCredentialsExchange") + defer span.End() + r = r.WithContext(ctx) + request, err := ParseClientCredentialsRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateClientCredentialsRequest(r.Context(), request, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateClientCredentialsTokenResponse(r.Context(), validatedRequest, exchanger, client) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } @@ -66,6 +70,9 @@ func ParseClientCredentialsRequest(r *http.Request, decoder httphelper.Decoder) // ValidateClientCredentialsRequest validates the client_credentials request parameters including authorization check of the client // and returns a TokenRequest and Client implementation to be used in the client_credentials response, resp. creation of the corresponding access_token. func ValidateClientCredentialsRequest(ctx context.Context, request *oidc.ClientCredentialsRequest, exchanger Exchanger) (TokenRequest, Client, error) { + ctx, span := tracer.Start(ctx, "ValidateClientCredentialsRequest") + defer span.End() + storage, ok := exchanger.Storage().(ClientCredentialsStorage) if !ok { return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported") @@ -85,6 +92,9 @@ func ValidateClientCredentialsRequest(ctx context.Context, request *oidc.ClientC } func AuthorizeClientCredentialsClient(ctx context.Context, request *oidc.ClientCredentialsRequest, storage ClientCredentialsStorage) (Client, error) { + ctx, span := tracer.Start(ctx, "AuthorizeClientCredentialsClient") + defer span.End() + client, err := storage.ClientCredentials(ctx, request.ClientID, request.ClientSecret) if err != nil { return nil, oidc.ErrInvalidClient().WithParent(err) @@ -98,6 +108,9 @@ func AuthorizeClientCredentialsClient(ctx context.Context, request *oidc.ClientC } func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) { + ctx, span := tracer.Start(ctx, "CreateClientCredentialsTokenResponse") + defer span.End() + accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "") if err != nil { return nil, err @@ -107,5 +120,6 @@ func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest Toke AccessToken: accessToken, TokenType: oidc.BearerToken, ExpiresIn: uint64(validity.Seconds()), + Scope: tokenRequest.GetScopes(), }, nil } diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go index 565a477..155aa43 100644 --- a/pkg/op/token_code.go +++ b/pkg/op/token_code.go @@ -4,29 +4,33 @@ import ( "context" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) // CodeExchange handles the OAuth 2.0 authorization_code grant, including // parsing, validating, authorizing the client and finally exchanging the code for tokens func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + ctx, span := tracer.Start(r.Context(), "CodeExchange") + defer span.End() + r = r.WithContext(ctx) + tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } if tokenReq.Code == "" { - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), exchanger.Logger()) return } authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "") if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) @@ -45,6 +49,9 @@ func ParseAccessTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc // ValidateAccessTokenRequest validates the token request parameters including authorization check of the client // and returns the previous created auth request corresponding to the auth code func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) { + ctx, span := tracer.Start(ctx, "ValidateAccessTokenRequest") + defer span.End() + authReq, client, err := AuthorizeCodeClient(ctx, tokenReq, exchanger) if err != nil { return nil, nil, err @@ -64,6 +71,9 @@ func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenR // AuthorizeCodeClient checks the authorization of the client and that the used method was the one previously registered. // It than returns the auth request corresponding to the auth code func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (request AuthRequest, client Client, err error) { + ctx, span := tracer.Start(ctx, "AuthorizeCodeClient") + defer span.End() + if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger) if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() { @@ -88,7 +98,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, if err != nil { return nil, nil, err } - err = AuthorizeCodeChallenge(tokenReq, request.GetCodeChallenge()) + err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge()) return request, client, err } if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() { @@ -104,6 +114,9 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, // AuthRequestByCode returns the AuthRequest previously created from Storage corresponding to the auth code or an error func AuthRequestByCode(ctx context.Context, storage Storage, code string) (AuthRequest, error) { + ctx, span := tracer.Start(ctx, "AuthRequestByCode") + defer span.End() + authReq, err := storage.AuthRequestByCode(ctx, code) if err != nil { return nil, oidc.ErrInvalidGrant().WithDescription("invalid code").WithParent(err) diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go index 055ff13..00af485 100644 --- a/pkg/op/token_exchange.go +++ b/pkg/op/token_exchange.go @@ -7,8 +7,8 @@ import ( "strings" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type TokenExchangeRequest interface { @@ -24,12 +24,12 @@ type TokenExchangeRequest interface { GetExchangeSubject() string GetExchangeSubjectTokenType() oidc.TokenType GetExchangeSubjectTokenIDOrToken() string - GetExchangeSubjectTokenClaims() map[string]interface{} + GetExchangeSubjectTokenClaims() map[string]any GetExchangeActor() string GetExchangeActorTokenType() oidc.TokenType GetExchangeActorTokenIDOrToken() string - GetExchangeActorTokenClaims() map[string]interface{} + GetExchangeActorTokenClaims() map[string]any SetCurrentScopes(scopes []string) SetRequestedTokenType(tt oidc.TokenType) @@ -40,12 +40,12 @@ type tokenExchangeRequest struct { exchangeSubjectTokenIDOrToken string exchangeSubjectTokenType oidc.TokenType exchangeSubject string - exchangeSubjectTokenClaims map[string]interface{} + exchangeSubjectTokenClaims map[string]any exchangeActorTokenIDOrToken string exchangeActorTokenType oidc.TokenType exchangeActor string - exchangeActorTokenClaims map[string]interface{} + exchangeActorTokenClaims map[string]any resource []string audience oidc.Audience @@ -96,7 +96,7 @@ func (r *tokenExchangeRequest) GetExchangeSubjectTokenIDOrToken() string { return r.exchangeSubjectTokenIDOrToken } -func (r *tokenExchangeRequest) GetExchangeSubjectTokenClaims() map[string]interface{} { +func (r *tokenExchangeRequest) GetExchangeSubjectTokenClaims() map[string]any { return r.exchangeSubjectTokenClaims } @@ -112,7 +112,7 @@ func (r *tokenExchangeRequest) GetExchangeActorTokenIDOrToken() string { return r.exchangeActorTokenIDOrToken } -func (r *tokenExchangeRequest) GetExchangeActorTokenClaims() map[string]interface{} { +func (r *tokenExchangeRequest) GetExchangeActorTokenClaims() map[string]any { return r.exchangeActorTokenClaims } @@ -134,19 +134,23 @@ func (r *tokenExchangeRequest) SetSubject(subject string) { // TokenExchange handles the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange") func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + ctx, span := tracer.Start(r.Context(), "TokenExchange") + defer span.End() + r = r.WithContext(ctx) + tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) @@ -189,6 +193,9 @@ func ValidateTokenExchangeRequest( clientID, clientSecret string, exchanger Exchanger, ) (TokenExchangeRequest, Client, error) { + ctx, span := tracer.Start(ctx, "ValidateTokenExchangeRequest") + defer span.End() + if oidcTokenExchangeRequest.SubjectToken == "" { return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token missing") } @@ -197,12 +204,6 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing") } - storage := exchanger.Storage() - teStorage, ok := storage.(TokenExchangeStorage) - if !ok { - return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("token_exchange grant not supported") - } - client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger) if err != nil { return nil, nil, err @@ -220,21 +221,42 @@ func ValidateTokenExchangeRequest( return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported") } + req, err := CreateTokenExchangeRequest(ctx, oidcTokenExchangeRequest, client, exchanger) + if err != nil { + return nil, nil, err + } + return req, client, nil +} + +func CreateTokenExchangeRequest( + ctx context.Context, + oidcTokenExchangeRequest *oidc.TokenExchangeRequest, + client Client, + exchanger Exchanger, +) (TokenExchangeRequest, error) { + ctx, span := tracer.Start(ctx, "CreateTokenExchangeRequest") + defer span.End() + + teStorage, ok := exchanger.Storage().(TokenExchangeStorage) + if !ok { + return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange) + } + exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid") } var ( exchangeActorTokenIDOrToken, exchangeActor string - exchangeActorTokenClaims map[string]interface{} + exchangeActorTokenClaims map[string]any ) if oidcTokenExchangeRequest.ActorToken != "" { exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger, oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true) if !ok { - return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") + return nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid") } } @@ -258,17 +280,17 @@ func ValidateTokenExchangeRequest( authTime: time.Now(), } - err = teStorage.ValidateTokenExchangeRequest(ctx, req) + err := teStorage.ValidateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } err = teStorage.CreateTokenExchangeRequest(ctx, req) if err != nil { - return nil, nil, err + return nil, err } - return req, client, nil + return req, nil } func GetTokenIDAndSubjectFromToken( @@ -277,11 +299,17 @@ func GetTokenIDAndSubjectFromToken( token string, tokenType oidc.TokenType, isActor bool, -) (tokenIDOrToken, subject string, claims map[string]interface{}, ok bool) { +) (tokenIDOrToken, subject string, claims map[string]any, ok bool) { + ctx, span := tracer.Start(ctx, "GetTokenIDAndSubjectFromToken") + defer span.End() + switch tokenType { case oidc.AccessTokenType: var accessTokenClaims *oidc.AccessTokenClaims tokenIDOrToken, subject, accessTokenClaims, ok = getTokenIDAndClaims(ctx, exchanger, token) + if !ok { + break + } claims = accessTokenClaims.Claims case oidc.RefreshTokenType: refreshTokenRequest, err := exchanger.Storage().TokenRequestByRefreshToken(ctx, token) @@ -322,6 +350,9 @@ func GetTokenIDAndSubjectFromToken( // AuthorizeTokenExchangeClient authorizes a client by validating the client_id and client_secret func AuthorizeTokenExchangeClient(ctx context.Context, clientID, clientSecret string, exchanger Exchanger) (client Client, err error) { + ctx, span := tracer.Start(ctx, "AuthorizeTokenExchangeClient") + defer span.End() + if err := AuthorizeClientIDSecret(ctx, clientID, clientSecret, exchanger.Storage()); err != nil { return nil, err } @@ -340,6 +371,8 @@ func CreateTokenExchangeResponse( client Client, creator TokenCreator, ) (_ *oidc.TokenExchangeResponse, err error) { + ctx, span := tracer.Start(ctx, "CreateTokenExchangeResponse") + defer span.End() var ( token, refreshToken, tokenType string diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go index 8582388..bb6a5a0 100644 --- a/pkg/op/token_intospection.go +++ b/pkg/op/token_intospection.go @@ -5,15 +5,15 @@ import ( "errors" "net/http" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type Introspector interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier } type IntrospectorJWTProfile interface { @@ -28,6 +28,10 @@ func introspectionHandler(introspector Introspector) func(http.ResponseWriter, * } func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspector) { + ctx, span := tracer.Start(r.Context(), "Introspect") + defer span.End() + r = r.WithContext(ctx) + response := new(oidc.IntrospectionResponse) token, clientID, err := ParseTokenIntrospectionRequest(r, introspector) if err != nil { @@ -65,3 +69,8 @@ func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) return req.Token, clientID, nil } + +type IntrospectionRequest struct { + *ClientCredentials + *oidc.IntrospectionRequest +} diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go index 23bac9a..defb937 100644 --- a/pkg/op/token_jwt_profile.go +++ b/pkg/op/token_jwt_profile.go @@ -5,36 +5,40 @@ import ( "net/http" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type JWTAuthorizationGrantExchanger interface { Exchanger - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1 func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizationGrantExchanger) { + ctx, span := tracer.Start(r.Context(), "JWTProfile") + defer span.End() + r = r.WithContext(ctx) + profileRequest, err := ParseJWTProfileGrantRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context())) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(r.Context(), tokenRequest.Issuer, profileRequest.Scope) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) @@ -56,6 +60,9 @@ func ParseJWTProfileGrantRequest(r *http.Request, decoder httphelper.Decoder) (* // CreateJWTTokenResponse creates an access_token response for a JWT Profile Grant request // by default the access_token is an opaque string, but can be specified by implementing the JWTProfileTokenStorage interface func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) { + ctx, span := tracer.Start(ctx, "CreateJWTTokenResponse") + defer span.End() + // return an opaque token as default to not break current implementations tokenType := AccessTokenTypeBearer @@ -82,6 +89,7 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea AccessToken: accessToken, TokenType: oidc.BearerToken, ExpiresIn: uint64(validity.Seconds()), + Scope: tokenRequest.GetScopes(), }, nil } diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go index 148d2a4..a87e883 100644 --- a/pkg/op/token_refresh.go +++ b/pkg/op/token_refresh.go @@ -4,11 +4,11 @@ import ( "context" "errors" "net/http" + "slices" "time" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/strings" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type RefreshTokenRequest interface { @@ -24,18 +24,22 @@ type RefreshTokenRequest interface { // RefreshTokenExchange handles the OAuth 2.0 refresh_token grant, including // parsing, validating, authorizing the client and finally exchanging the refresh_token for new tokens func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + ctx, span := tracer.Start(r.Context(), "RefreshTokenExchange") + defer span.End() + r = r.WithContext(ctx) + tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder()) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) } validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken) if err != nil { - RequestError(w, r, err) + RequestError(w, r, err, exchanger.Logger()) return } httphelper.MarshalJSON(w, resp) @@ -54,6 +58,9 @@ func ParseRefreshTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oid // ValidateRefreshTokenRequest validates the refresh_token request parameters including authorization check of the client // and returns the data representing the original auth request corresponding to the refresh_token func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (RefreshTokenRequest, Client, error) { + ctx, span := tracer.Start(ctx, "ValidateRefreshTokenRequest") + defer span.End() + if tokenReq.RefreshToken == "" { return nil, nil, oidc.ErrInvalidRequest().WithDescription("refresh_token missing") } @@ -78,7 +85,7 @@ func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTok return nil } for _, scope := range requestedScopes { - if !strings.Contains(authRequest.GetScopes(), scope) { + if !slices.Contains(authRequest.GetScopes(), scope) { return oidc.ErrInvalidScope() } } @@ -89,6 +96,9 @@ func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTok // AuthorizeRefreshClient checks the authorization of the client and that the used method was the one previously registered. // It than returns the data representing the original auth request corresponding to the refresh_token func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (request RefreshTokenRequest, client Client, err error) { + ctx, span := tracer.Start(ctx, "AuthorizeRefreshClient") + defer span.End() + if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion { jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger) if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() { @@ -131,6 +141,9 @@ func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequ // RefreshTokenRequestByRefreshToken returns the RefreshTokenRequest (data representing the original auth request) // corresponding to the refresh_token from Storage or an error func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) { + ctx, span := tracer.Start(ctx, "RefreshTokenRequestByRefreshToken") + defer span.End() + request, err := storage.TokenRequestByRefreshToken(ctx, refreshToken) if err != nil { return nil, oidc.ErrInvalidGrant().WithParent(err) diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go index b9e9805..3f5af7a 100644 --- a/pkg/op/token_request.go +++ b/pkg/op/token_request.go @@ -2,11 +2,12 @@ package op import ( "context" + "log/slog" "net/http" "net/url" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type Exchanger interface { @@ -20,18 +21,26 @@ type Exchanger interface { GrantTypeJWTAuthorizationSupported() bool GrantTypeClientCredentialsSupported() bool GrantTypeDeviceCodeSupported() bool - AccessTokenVerifier(context.Context) AccessTokenVerifier - IDTokenHintVerifier(context.Context) IDTokenHintVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier + IDTokenHintVerifier(context.Context) *IDTokenHintVerifier + Logger() *slog.Logger } func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - Exchange(w, r, exchanger) + ctx, span := tracer.Start(r.Context(), "tokenHandler") + defer span.End() + + Exchange(w, r.WithContext(ctx), exchanger) } } // Exchange performs a token exchange appropriate for the grant type func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + ctx, span := tracer.Start(r.Context(), "Exchange") + r = r.WithContext(ctx) + defer span.End() + grantType := r.FormValue("grant_type") switch grantType { case string(oidc.GrantTypeCode): @@ -63,10 +72,10 @@ func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { return } case "": - RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing")) + RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), exchanger.Logger()) return } - RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType)) + RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), exchanger.Logger()) } // AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest @@ -79,6 +88,10 @@ type AuthenticatedTokenRequest interface { // ParseAuthenticatedTokenRequest parses the client_id and client_secret from the HTTP request from either // HTTP Basic Auth header or form body and sets them into the provided authenticatedTokenRequest interface func ParseAuthenticatedTokenRequest(r *http.Request, decoder httphelper.Decoder, request AuthenticatedTokenRequest) error { + ctx, span := tracer.Start(r.Context(), "ParseAuthenticatedTokenRequest") + defer span.End() + r = r.WithContext(ctx) + err := r.ParseForm() if err != nil { return oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err) @@ -106,6 +119,9 @@ func ParseAuthenticatedTokenRequest(r *http.Request, decoder httphelper.Decoder, // AuthorizeClientIDSecret authorizes a client by validating the client_id and client_secret (Basic Auth and POST) func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage Storage) error { + ctx, span := tracer.Start(ctx, "AuthorizeClientIDSecret") + defer span.End() + err := storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret) if err != nil { return oidc.ErrInvalidClient().WithDescription("invalid client_id / client_secret").WithParent(err) @@ -115,12 +131,20 @@ func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, // AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent // code_challenge of the auth request (PKCE) -func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.CodeChallenge) error { - if tokenReq.CodeVerifier == "" { - return oidc.ErrInvalidRequest().WithDescription("code_challenge required") +func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error { + if challenge == nil { + if codeVerifier != "" { + return oidc.ErrInvalidRequest().WithDescription("code_verifier unexpectedly provided") + } + + return nil } - if !oidc.VerifyCodeChallenge(challenge, tokenReq.CodeVerifier) { - return oidc.ErrInvalidGrant().WithDescription("invalid code challenge") + + if codeVerifier == "" { + return oidc.ErrInvalidRequest().WithDescription("code_verifier required") + } + if !oidc.VerifyCodeChallenge(challenge, codeVerifier) { + return oidc.ErrInvalidGrant().WithDescription("invalid code_verifier") } return nil } @@ -128,6 +152,9 @@ func AuthorizeCodeChallenge(tokenReq *oidc.AccessTokenRequest, challenge *oidc.C // AuthorizePrivateJWTKey authorizes a client by validating the client_assertion's signature with a previously // registered public key (JWT Profile) func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchanger JWTAuthorizationGrantExchanger) (Client, error) { + ctx, span := tracer.Start(ctx, "AuthorizePrivateJWTKey") + defer span.End() + jwtReq, err := VerifyJWTAssertion(ctx, clientAssertion, exchanger.JWTProfileVerifier(ctx)) if err != nil { return nil, err diff --git a/pkg/op/token_request_test.go b/pkg/op/token_request_test.go new file mode 100644 index 0000000..d226af6 --- /dev/null +++ b/pkg/op/token_request_test.go @@ -0,0 +1,75 @@ +package op_test + +import ( + "testing" + + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "github.com/stretchr/testify/assert" +) + +func TestAuthorizeCodeChallenge(t *testing.T) { + tests := []struct { + name string + codeVerifier string + codeChallenge *oidc.CodeChallenge + want func(t *testing.T, err error) + }{ + { + name: "missing both code_verifier and code_challenge", + codeVerifier: "", + codeChallenge: nil, + want: func(t *testing.T, err error) { + assert.Nil(t, err) + }, + }, + { + name: "valid code_verifier", + codeVerifier: "Hello World!", + codeChallenge: &oidc.CodeChallenge{ + Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk", + Method: oidc.CodeChallengeMethodS256, + }, + want: func(t *testing.T, err error) { + assert.Nil(t, err) + }, + }, + { + name: "invalid code_verifier", + codeVerifier: "Hi World!", + codeChallenge: &oidc.CodeChallenge{ + Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk", + Method: oidc.CodeChallengeMethodS256, + }, + want: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "invalid code_verifier") + }, + }, + { + name: "code_verifier provided without code_challenge", + codeVerifier: "code_verifier", + codeChallenge: nil, + want: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "code_verifier unexpectedly provided") + }, + }, + { + name: "empty code_verifier", + codeVerifier: "", + codeChallenge: &oidc.CodeChallenge{ + Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk", + Method: oidc.CodeChallengeMethodS256, + }, + want: func(t *testing.T, err error) { + assert.ErrorContains(t, err, "code_verifier required") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := op.AuthorizeCodeChallenge(tt.codeVerifier, tt.codeChallenge) + + tt.want(t, err) + }) + } +} diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go index 58332c3..049ee15 100644 --- a/pkg/op/token_revocation.go +++ b/pkg/op/token_revocation.go @@ -7,22 +7,22 @@ import ( "net/url" "strings" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type Revoker interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier AuthMethodPrivateKeyJWTSupported() bool AuthMethodPostSupported() bool } type RevokerJWTProfile interface { Revoker - JWTProfileVerifier(context.Context) JWTProfileVerifier + JWTProfileVerifier(context.Context) *JWTProfileVerifier } func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) { @@ -32,6 +32,10 @@ func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) } func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) { + ctx, span := tracer.Start(r.Context(), "Revoke") + r = r.WithContext(ctx) + defer span.End() + token, tokenTypeHint, clientID, err := ParseTokenRevocationRequest(r, revoker) if err != nil { RevocationRequestError(w, r, err) @@ -68,6 +72,10 @@ func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) { } func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, tokenTypeHint, clientID string, err error) { + ctx, span := tracer.Start(r.Context(), "ParseTokenRevocationRequest") + r = r.WithContext(ctx) + defer span.End() + err = r.ParseForm() if err != nil { return "", "", "", oidc.ErrInvalidRequest().WithDescription("unable to parse request").WithParent(err) @@ -131,6 +139,11 @@ func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, token } func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { + statusErr := RevocationError(err) + httphelper.MarshalJSONWithStatus(w, statusErr.parent, statusErr.statusCode) +} + +func RevocationError(err error) StatusError { e := oidc.DefaultToServerError(err, err.Error()) status := http.StatusBadRequest switch e.ErrorType { @@ -139,10 +152,13 @@ func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) { case oidc.ServerError: status = 500 } - httphelper.MarshalJSONWithStatus(w, e, status) + return NewStatusError(e, status) } func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { + ctx, span := tracer.Start(ctx, "getTokenIDAndSubjectForRevocation") + defer span.End() + tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) if err == nil { splitToken := strings.Split(tokenIDSubject, ":") diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go index 21a0af4..ff75e72 100644 --- a/pkg/op/userinfo.go +++ b/pkg/op/userinfo.go @@ -6,15 +6,15 @@ import ( "net/http" "strings" - httphelper "github.com/zitadel/oidc/v2/pkg/http" - "github.com/zitadel/oidc/v2/pkg/oidc" + httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) type UserinfoProvider interface { Decoder() httphelper.Decoder Crypto() Crypto Storage() Storage - AccessTokenVerifier(context.Context) AccessTokenVerifier + AccessTokenVerifier(context.Context) *AccessTokenVerifier } func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) { @@ -24,6 +24,10 @@ func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter } func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoProvider) { + ctx, span := tracer.Start(r.Context(), "Userinfo") + r = r.WithContext(ctx) + defer span.End() + accessToken, err := ParseUserinfoRequest(r, userinfoProvider.Decoder()) if err != nil { http.Error(w, "access token missing", http.StatusUnauthorized) @@ -44,6 +48,10 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP } func ParseUserinfoRequest(r *http.Request, decoder httphelper.Decoder) (string, error) { + ctx, span := tracer.Start(r.Context(), "ParseUserinfoRequest") + r = r.WithContext(ctx) + defer span.End() + accessToken, err := getAccessToken(r) if err == nil { return accessToken, nil @@ -61,6 +69,10 @@ func ParseUserinfoRequest(r *http.Request, decoder httphelper.Decoder) (string, } func getAccessToken(r *http.Request) (string, error) { + ctx, span := tracer.Start(r.Context(), "getAccessToken") + r = r.WithContext(ctx) + defer span.End() + authHeader := r.Header.Get("authorization") if authHeader == "" { return "", errors.New("no auth header") @@ -73,6 +85,9 @@ func getAccessToken(r *http.Request) (string, error) { } func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) { + ctx, span := tracer.Start(ctx, "getTokenIDAndSubject") + defer span.End() + tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken) if err == nil { splitToken := strings.Split(tokenIDSubject, ":") diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go index 9a8b912..585ca54 100644 --- a/pkg/op/verifier_access_token.go +++ b/pkg/op/verifier_access_token.go @@ -2,62 +2,25 @@ package op import ( "context" - "time" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) -type AccessTokenVerifier interface { - oidc.Verifier - SupportedSignAlgs() []string - KeySet() oidc.KeySet -} +type AccessTokenVerifier oidc.Verifier -type accessTokenVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - supportedSignAlgs []string - keySet oidc.KeySet -} - -// Issuer implements oidc.Verifier interface -func (i *accessTokenVerifier) Issuer() string { - return i.issuer -} - -// MaxAgeIAT implements oidc.Verifier interface -func (i *accessTokenVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -// Offset implements oidc.Verifier interface -func (i *accessTokenVerifier) Offset() time.Duration { - return i.offset -} - -// SupportedSignAlgs implements AccessTokenVerifier interface -func (i *accessTokenVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -// KeySet implements AccessTokenVerifier interface -func (i *accessTokenVerifier) KeySet() oidc.KeySet { - return i.keySet -} - -type AccessTokenVerifierOpt func(*accessTokenVerifier) +type AccessTokenVerifierOpt func(*AccessTokenVerifier) func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt { - return func(verifier *accessTokenVerifier) { - verifier.supportedSignAlgs = algs + return func(verifier *AccessTokenVerifier) { + verifier.SupportedSignAlgs = algs } } -func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) AccessTokenVerifier { - verifier := &accessTokenVerifier{ - issuer: issuer, - keySet: keySet, +// NewAccessTokenVerifier returns a AccessTokenVerifier suitable for access token verification. +func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) *AccessTokenVerifier { + verifier := &AccessTokenVerifier{ + Issuer: issuer, + KeySet: keySet, } for _, opt := range opts { opt(verifier) @@ -66,7 +29,10 @@ func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTok } // VerifyAccessToken validates the access token (issuer, signature and expiration). -func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v AccessTokenVerifier) (claims C, err error) { +func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *AccessTokenVerifier) (claims C, err error) { + ctx, span := tracer.Start(ctx, "VerifyAccessToken") + defer span.End() + var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -78,15 +44,15 @@ func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v Acces return nilClaims, err } - if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err := oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { return nilClaims, err } diff --git a/pkg/op/verifier_access_token_example_test.go b/pkg/op/verifier_access_token_example_test.go index effdd58..b97a7fd 100644 --- a/pkg/op/verifier_access_token_example_test.go +++ b/pkg/op/verifier_access_token_example_test.go @@ -4,9 +4,9 @@ import ( "context" "fmt" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" - "github.com/zitadel/oidc/v2/pkg/op" + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" ) // MyCustomClaims extends the TokenClaims base, diff --git a/pkg/op/verifier_access_token_test.go b/pkg/op/verifier_access_token_test.go index 62c26a9..5845f9f 100644 --- a/pkg/op/verifier_access_token_test.go +++ b/pkg/op/verifier_access_token_test.go @@ -5,10 +5,10 @@ import ( "testing" "time" + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" ) func TestNewAccessTokenVerifier(t *testing.T) { @@ -20,7 +20,7 @@ func TestNewAccessTokenVerifier(t *testing.T) { tests := []struct { name string args args - want AccessTokenVerifier + want *AccessTokenVerifier }{ { name: "simple", @@ -28,9 +28,9 @@ func TestNewAccessTokenVerifier(t *testing.T) { issuer: tu.ValidIssuer, keySet: tu.KeySet{}, }, - want: &accessTokenVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, + want: &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, }, }, { @@ -42,10 +42,10 @@ func TestNewAccessTokenVerifier(t *testing.T) { WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"), }, }, - want: &accessTokenVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } @@ -58,12 +58,12 @@ func TestNewAccessTokenVerifier(t *testing.T) { } func TestVerifyAccessToken(t *testing.T) { - verifier := &accessTokenVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - keySet: tu.KeySet{}, + verifier := &AccessTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, } tests := []struct { diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go index d906075..02610aa 100644 --- a/pkg/op/verifier_id_token_hint.go +++ b/pkg/op/verifier_id_token_hint.go @@ -2,69 +2,25 @@ package op import ( "context" - "time" + "errors" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) -type IDTokenHintVerifier interface { - oidc.Verifier - SupportedSignAlgs() []string - KeySet() oidc.KeySet - ACR() oidc.ACRVerifier - MaxAge() time.Duration -} +type IDTokenHintVerifier oidc.Verifier -type idTokenHintVerifier struct { - issuer string - maxAgeIAT time.Duration - offset time.Duration - supportedSignAlgs []string - maxAge time.Duration - acr oidc.ACRVerifier - keySet oidc.KeySet -} - -func (i *idTokenHintVerifier) Issuer() string { - return i.issuer -} - -func (i *idTokenHintVerifier) MaxAgeIAT() time.Duration { - return i.maxAgeIAT -} - -func (i *idTokenHintVerifier) Offset() time.Duration { - return i.offset -} - -func (i *idTokenHintVerifier) SupportedSignAlgs() []string { - return i.supportedSignAlgs -} - -func (i *idTokenHintVerifier) KeySet() oidc.KeySet { - return i.keySet -} - -func (i *idTokenHintVerifier) ACR() oidc.ACRVerifier { - return i.acr -} - -func (i *idTokenHintVerifier) MaxAge() time.Duration { - return i.maxAge -} - -type IDTokenHintVerifierOpt func(*idTokenHintVerifier) +type IDTokenHintVerifierOpt func(*IDTokenHintVerifier) func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt { - return func(verifier *idTokenHintVerifier) { - verifier.supportedSignAlgs = algs + return func(verifier *IDTokenHintVerifier) { + verifier.SupportedSignAlgs = algs } } -func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) IDTokenHintVerifier { - verifier := &idTokenHintVerifier{ - issuer: issuer, - keySet: keySet, +func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) *IDTokenHintVerifier { + verifier := &IDTokenHintVerifier{ + Issuer: issuer, + KeySet: keySet, } for _, opt := range opts { opt(verifier) @@ -72,9 +28,27 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi return verifier } +type IDTokenHintExpiredError struct { + error +} + +func (e IDTokenHintExpiredError) Unwrap() error { + return e.error +} + +func (e IDTokenHintExpiredError) Is(err error) bool { + return errors.Is(err, e.error) +} + // VerifyIDTokenHint validates the id token according to -// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation -func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTokenHintVerifier) (claims C, err error) { +// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation. +// In case of an expired token both the Claims and first encountered expiry related error +// is returned of type [IDTokenHintExpiredError]. In that case the caller can choose to still +// trust the token for cases like logout, as signature and other verifications succeeded. +func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) { + ctx, span := tracer.Start(ctx, "VerifyIDTokenHint") + defer span.End() + var nilClaims C decrypted, err := oidc.DecryptToken(token) @@ -86,28 +60,28 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v IDTok return nilClaims, err } - if err := oidc.CheckIssuer(claims, v.Issuer()); err != nil { + if err := oidc.CheckIssuer(claims, v.Issuer); err != nil { return nilClaims, err } - if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs(), v.KeySet()); err != nil { + if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil { return nilClaims, err } - if err = oidc.CheckExpiration(claims, v.Offset()); err != nil { + if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil { return nilClaims, err } - if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT(), v.Offset()); err != nil { - return nilClaims, err + if err = oidc.CheckExpiration(claims, v.Offset); err != nil { + return claims, IDTokenHintExpiredError{err} } - if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR()); err != nil { - return nilClaims, err + if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil { + return claims, IDTokenHintExpiredError{err} } - if err = oidc.CheckAuthTime(claims, v.MaxAge()); err != nil { - return nilClaims, err + if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil { + return claims, IDTokenHintExpiredError{err} } return claims, nil } diff --git a/pkg/op/verifier_id_token_hint_test.go b/pkg/op/verifier_id_token_hint_test.go index f4d0b0c..347e33c 100644 --- a/pkg/op/verifier_id_token_hint_test.go +++ b/pkg/op/verifier_id_token_hint_test.go @@ -2,13 +2,14 @@ package op import ( "context" + "errors" "testing" "time" + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - tu "github.com/zitadel/oidc/v2/internal/testutil" - "github.com/zitadel/oidc/v2/pkg/oidc" ) func TestNewIDTokenHintVerifier(t *testing.T) { @@ -20,7 +21,7 @@ func TestNewIDTokenHintVerifier(t *testing.T) { tests := []struct { name string args args - want IDTokenHintVerifier + want *IDTokenHintVerifier }{ { name: "simple", @@ -28,9 +29,9 @@ func TestNewIDTokenHintVerifier(t *testing.T) { issuer: tu.ValidIssuer, keySet: tu.KeySet{}, }, - want: &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, + want: &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, }, }, { @@ -42,10 +43,10 @@ func TestNewIDTokenHintVerifier(t *testing.T) { WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"), }, }, - want: &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - keySet: tu.KeySet{}, - supportedSignAlgs: []string{"ABC", "DEF"}, + want: &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + KeySet: tu.KeySet{}, + SupportedSignAlgs: []string{"ABC", "DEF"}, }, }, } @@ -57,35 +58,44 @@ func TestNewIDTokenHintVerifier(t *testing.T) { } } +func Test_IDTokenHintExpiredError(t *testing.T) { + var err error = IDTokenHintExpiredError{oidc.ErrExpired} + assert.True(t, errors.Unwrap(err) == oidc.ErrExpired) + assert.ErrorIs(t, err, oidc.ErrExpired) + assert.ErrorAs(t, err, &IDTokenHintExpiredError{}) +} + func TestVerifyIDTokenHint(t *testing.T) { - verifier := &idTokenHintVerifier{ - issuer: tu.ValidIssuer, - maxAgeIAT: 2 * time.Minute, - offset: time.Second, - supportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, - maxAge: 2 * time.Minute, - acr: tu.ACRVerify, - keySet: tu.KeySet{}, + verifier := &IDTokenHintVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + KeySet: tu.KeySet{}, } tests := []struct { name string tokenClaims func() (string, *oidc.IDTokenClaims) - wantErr bool + wantClaims bool + wantErr error }{ { name: "success", tokenClaims: tu.ValidIDToken, + wantClaims: true, }, { name: "parse err", tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil }, - wantErr: true, + wantErr: oidc.ErrParse, }, { name: "invalid signature", tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil }, - wantErr: true, + wantErr: oidc.ErrSignatureUnsupportedAlg, }, { name: "wrong issuer", @@ -96,29 +106,7 @@ func TestVerifyIDTokenHint(t *testing.T) { tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", ) }, - wantErr: true, - }, - { - name: "expired", - tokenClaims: func() (string, *oidc.IDTokenClaims) { - return tu.NewIDToken( - tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, - tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce, - tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", - ) - }, - wantErr: true, - }, - { - name: "wrong IAT", - tokenClaims: func() (string, *oidc.IDTokenClaims) { - return tu.NewIDToken( - tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, - tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, - tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "", - ) - }, - wantErr: true, + wantErr: oidc.ErrIssuerInvalid, }, { name: "wrong acr", @@ -129,7 +117,31 @@ func TestVerifyIDTokenHint(t *testing.T) { "else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", ) }, - wantErr: true, + wantErr: oidc.ErrAcrInvalid, + }, + { + name: "expired", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return tu.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", + ) + }, + wantClaims: true, + wantErr: IDTokenHintExpiredError{oidc.ErrExpired}, + }, + { + name: "IAT too old", + tokenClaims: func() (string, *oidc.IDTokenClaims) { + return tu.NewIDToken( + tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience, + tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce, + tu.ValidACR, tu.ValidAMR, tu.ValidClientID, time.Hour, "", + ) + }, + wantClaims: true, + wantErr: IDTokenHintExpiredError{oidc.ErrIatToOld}, }, { name: "expired auth", @@ -140,7 +152,8 @@ func TestVerifyIDTokenHint(t *testing.T) { tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "", ) }, - wantErr: true, + wantClaims: true, + wantErr: IDTokenHintExpiredError{oidc.ErrAuthTimeToOld}, }, } for _, tt := range tests { @@ -148,14 +161,12 @@ func TestVerifyIDTokenHint(t *testing.T) { token, want := tt.tokenClaims() got, err := VerifyIDTokenHint[*oidc.IDTokenClaims](context.Background(), token, verifier) - if tt.wantErr { - assert.Error(t, err) - assert.Nil(t, got) + require.ErrorIs(t, err, tt.wantErr) + if tt.wantClaims { + assert.Equal(t, got, want, "claims") return } - require.NoError(t, err) - require.NotNil(t, got) - assert.Equal(t, got, want) + assert.Nil(t, got, "claims") }) } } diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go index 4d83c59..85bfb14 100644 --- a/pkg/op/verifier_jwt_profile.go +++ b/pkg/op/verifier_jwt_profile.go @@ -6,33 +6,41 @@ import ( "fmt" "time" - "gopkg.in/square/go-jose.v2" + jose "github.com/go-jose/go-jose/v4" - "github.com/zitadel/oidc/v2/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" ) -type JWTProfileVerifier interface { +// JWTProfileVerfiier extends oidc.Verifier with +// a jwtProfileKeyStorage and a function to check +// the subject in a token. +type JWTProfileVerifier struct { oidc.Verifier - Storage() jwtProfileKeyStorage - CheckSubject(request *oidc.JWTTokenRequest) error -} - -type jwtProfileVerifier struct { - storage jwtProfileKeyStorage - subjectCheck func(request *oidc.JWTTokenRequest) error - issuer string - maxAgeIAT time.Duration - offset time.Duration + Storage JWTProfileKeyStorage + keySet oidc.KeySet + CheckSubject func(request *oidc.JWTTokenRequest) error } // NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication) -func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) JWTProfileVerifier { - j := &jwtProfileVerifier{ - storage: storage, - subjectCheck: SubjectIsIssuer, - issuer: issuer, - maxAgeIAT: maxAgeIAT, - offset: offset, +func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier { + return newJWTProfileVerifier(storage, nil, issuer, maxAgeIAT, offset, opts...) +} + +// NewJWTProfileVerifierKeySet creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication) +func NewJWTProfileVerifierKeySet(keySet oidc.KeySet, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier { + return newJWTProfileVerifier(nil, keySet, issuer, maxAgeIAT, offset, opts...) +} + +func newJWTProfileVerifier(storage JWTProfileKeyStorage, keySet oidc.KeySet, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier { + j := &JWTProfileVerifier{ + Verifier: oidc.Verifier{ + Issuer: issuer, + MaxAgeIAT: maxAgeIAT, + Offset: offset, + }, + Storage: storage, + keySet: keySet, + CheckSubject: SubjectIsIssuer, } for _, opt := range opts { @@ -42,53 +50,38 @@ func NewJWTProfileVerifier(storage jwtProfileKeyStorage, issuer string, maxAgeIA return j } -type JWTProfileVerifierOption func(*jwtProfileVerifier) +type JWTProfileVerifierOption func(*JWTProfileVerifier) +// SubjectCheck sets a custom function to check the subject. +// Defaults to SubjectIsIssuer() func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption { - return func(verifier *jwtProfileVerifier) { - verifier.subjectCheck = check + return func(verifier *JWTProfileVerifier) { + verifier.CheckSubject = check } } -func (v *jwtProfileVerifier) Issuer() string { - return v.issuer -} - -func (v *jwtProfileVerifier) Storage() jwtProfileKeyStorage { - return v.storage -} - -func (v *jwtProfileVerifier) MaxAgeIAT() time.Duration { - return v.maxAgeIAT -} - -func (v *jwtProfileVerifier) Offset() time.Duration { - return v.offset -} - -func (v *jwtProfileVerifier) CheckSubject(request *oidc.JWTTokenRequest) error { - return v.subjectCheck(request) -} - // VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication) // // checks audience, exp, iat, signature and that issuer and sub are the same -func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { +func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVerifier) (*oidc.JWTTokenRequest, error) { + ctx, span := tracer.Start(ctx, "VerifyJWTAssertion") + defer span.End() + request := new(oidc.JWTTokenRequest) payload, err := oidc.ParseToken(assertion, request) if err != nil { return nil, err } - if err = oidc.CheckAudience(request, v.Issuer()); err != nil { + if err = oidc.CheckAudience(request, v.Issuer); err != nil { return nil, err } - if err = oidc.CheckExpiration(request, v.Offset()); err != nil { + if err = oidc.CheckExpiration(request, v.Offset); err != nil { return nil, err } - if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT(), v.Offset()); err != nil { + if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT, v.Offset); err != nil { return nil, err } @@ -96,17 +89,21 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v JWTProfileVerif return nil, err } - keySet := &jwtProfileKeySet{storage: v.Storage(), clientID: request.Issuer} + keySet := v.keySet + if keySet == nil { + keySet = &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer} + } if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil { return nil, err } return request, nil } -type jwtProfileKeyStorage interface { +type JWTProfileKeyStorage interface { GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) } +// SubjectIsIssuer func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { if request.Issuer != request.Subject { return errors.New("delegation not allowed, issuer and sub must be identical") @@ -115,12 +112,15 @@ func SubjectIsIssuer(request *oidc.JWTTokenRequest) error { } type jwtProfileKeySet struct { - storage jwtProfileKeyStorage + storage JWTProfileKeyStorage clientID string } // VerifySignature implements oidc.KeySet by getting the public key from Storage implementation func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) { + ctx, span := tracer.Start(ctx, "VerifySignature") + defer span.End() + keyID, _ := oidc.GetKeyIDAndAlg(jws) key, err := k.storage.GetKeyByIDAndClientID(ctx, keyID, k.clientID) if err != nil { diff --git a/pkg/op/verifier_jwt_profile_test.go b/pkg/op/verifier_jwt_profile_test.go new file mode 100644 index 0000000..2068678 --- /dev/null +++ b/pkg/op/verifier_jwt_profile_test.go @@ -0,0 +1,117 @@ +package op_test + +import ( + "context" + "testing" + "time" + + tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc" + "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewJWTProfileVerifier(t *testing.T) { + want := &op.JWTProfileVerifier{ + Verifier: oidc.Verifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: time.Minute, + Offset: time.Second, + }, + Storage: tu.JWTProfileKeyStorage{}, + } + got := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, time.Second, op.SubjectCheck(func(request *oidc.JWTTokenRequest) error { + return oidc.ErrSubjectMissing + })) + assert.Equal(t, want.Verifier, got.Verifier) + assert.Equal(t, want.Storage, got.Storage) + assert.ErrorIs(t, got.CheckSubject(nil), oidc.ErrSubjectMissing) +} + +func TestVerifyJWTAssertion(t *testing.T) { + errCtx, cancel := context.WithCancel(context.Background()) + cancel() + + verifier := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, 0) + tests := []struct { + name string + ctx context.Context + newToken func() (string, *oidc.JWTTokenRequest) + wantErr bool + }{ + { + name: "parse error", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { return "!", nil }, + wantErr: true, + }, + { + name: "wrong audience", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{"wrong"}, + time.Now(), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "expired", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer}, + time.Now(), time.Now().Add(-time.Hour), + ) + }, + wantErr: true, + }, + { + name: "invalid iat", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer}, + time.Now().Add(time.Hour), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "invalid subject", + ctx: context.Background(), + newToken: func() (string, *oidc.JWTTokenRequest) { + return tu.NewJWTProfileAssertion( + tu.ValidClientID, "wrong", []string{tu.ValidIssuer}, + time.Now(), tu.ValidExpiration, + ) + }, + wantErr: true, + }, + { + name: "check signature fail", + ctx: errCtx, + newToken: tu.ValidJWTProfileAssertion, + wantErr: true, + }, + { + name: "ok", + ctx: context.Background(), + newToken: tu.ValidJWTProfileAssertion, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertion, want := tt.newToken() + got, err := op.VerifyJWTAssertion(tt.ctx, assertion, verifier) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, want, got) + }) + } +} diff --git a/pkg/strings/strings.go b/pkg/strings/strings.go index af48cf3..b8f43a1 100644 --- a/pkg/strings/strings.go +++ b/pkg/strings/strings.go @@ -1,10 +1,9 @@ package strings +import "slices" + +// Deprecated: Use standard library [slices.Contains] instead. func Contains(list []string, needle string) bool { - for _, item := range list { - if item == needle { - return true - } - } - return false + // TODO(v4): remove package. + return slices.Contains(list, needle) }