diff --git a/.codecov/codecov.yml b/.codecov/codecov.yml
index 8104953..0f17a71 100644
--- a/.codecov/codecov.yml
+++ b/.codecov/codecov.yml
@@ -1,4 +1,5 @@
codecov:
+ branch: main
notify:
require_ci_to_pass: yes
coverage:
@@ -19,4 +20,7 @@ parsers:
comment:
layout: "header, diff"
behavior: default
- require_changes: no
\ No newline at end of file
+ require_changes: no
+ignore:
+ - "example"
+ - "**/mock"
diff --git a/.forgejo.bak/ISSUE_TEMPLATE/bug_report.yaml b/.forgejo.bak/ISSUE_TEMPLATE/bug_report.yaml
new file mode 100644
index 0000000..d024341
--- /dev/null
+++ b/.forgejo.bak/ISSUE_TEMPLATE/bug_report.yaml
@@ -0,0 +1,57 @@
+name: Bug Report
+description: "Create a bug report to help us improve ZITADEL. Click [here](https://github.com/zitadel/zitadel/blob/main/CONTRIBUTING.md#product-management) to see how we process your issue."
+title: "[Bug]: "
+labels: ["bug"]
+type: Bug
+body:
+ - type: markdown
+ attributes:
+ value: |
+ Thanks for taking the time to fill out this bug report!
+ - type: checkboxes
+ id: preflight
+ attributes:
+ label: Preflight Checklist
+ options:
+ - label:
+ I could not find a solution in the documentation, the existing issues or discussions
+ required: true
+ - label:
+ I have joined the [ZITADEL chat](https://zitadel.com/chat)
+ - type: input
+ id: version
+ attributes:
+ label: Version
+ description: Which version of the OIDC library are you using.
+ - type: textarea
+ id: impact
+ attributes:
+ label: Describe the problem caused by this bug
+ description: A clear and concise description of the problem you have and what the bug is.
+ validations:
+ required: true
+ - type: textarea
+ id: reproduce
+ attributes:
+ label: To reproduce
+ description: Steps to reproduce the behaviour
+ placeholder: |
+ Steps to reproduce the behavior:
+ validations:
+ required: true
+ - type: textarea
+ id: screenshots
+ attributes:
+ label: Screenshots
+ description: If applicable, add screenshots to help explain your problem.
+ - type: textarea
+ id: expected
+ attributes:
+ label: Expected behavior
+ description: A clear and concise description of what you expected to happen.
+ placeholder: As a [type of user], I want [some goal] so that [some reason].
+ - type: textarea
+ id: additional
+ attributes:
+ label: Additional Context
+ description: Please add any other infos that could be useful.
diff --git a/.forgejo.bak/ISSUE_TEMPLATE/config.yml b/.forgejo.bak/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000..a49eab2
--- /dev/null
+++ b/.forgejo.bak/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1 @@
+blank_issues_enabled: true
\ No newline at end of file
diff --git a/.forgejo.bak/ISSUE_TEMPLATE/docs.yaml b/.forgejo.bak/ISSUE_TEMPLATE/docs.yaml
new file mode 100644
index 0000000..d3f82b9
--- /dev/null
+++ b/.forgejo.bak/ISSUE_TEMPLATE/docs.yaml
@@ -0,0 +1,31 @@
+name: đ Documentation
+description: Create an issue for missing or wrong documentation.
+labels: ["docs"]
+type: task
+body:
+ - type: markdown
+ attributes:
+ value: |
+ Thanks for taking the time to fill out this issue.
+ - type: checkboxes
+ id: preflight
+ attributes:
+ label: Preflight Checklist
+ options:
+ - label:
+ I could not find a solution in the existing issues, docs, nor discussions
+ required: true
+ - label:
+ I have joined the [ZITADEL chat](https://zitadel.com/chat)
+ - type: textarea
+ id: docs
+ attributes:
+ label: Describe the docs your are missing or that are wrong
+ placeholder: As a [type of user], I want [some goal] so that [some reason].
+ validations:
+ required: true
+ - type: textarea
+ id: additional
+ attributes:
+ label: Additional Context
+ description: Please add any other infos that could be useful.
diff --git a/.forgejo.bak/ISSUE_TEMPLATE/enhancement.yaml b/.forgejo.bak/ISSUE_TEMPLATE/enhancement.yaml
new file mode 100644
index 0000000..ef2103e
--- /dev/null
+++ b/.forgejo.bak/ISSUE_TEMPLATE/enhancement.yaml
@@ -0,0 +1,55 @@
+name: đ ī¸ Improvement
+description: "Create an new issue for an improvment in ZITADEL"
+labels: ["enhancement"]
+type: enhancement
+body:
+ - type: markdown
+ attributes:
+ value: |
+ Thanks for taking the time to fill out this proposal / feature reqeust
+ - type: checkboxes
+ id: preflight
+ attributes:
+ label: Preflight Checklist
+ options:
+ - label:
+ I could not find a solution in the existing issues, docs, nor discussions
+ required: true
+ - label:
+ I have joined the [ZITADEL chat](https://zitadel.com/chat)
+ - type: textarea
+ id: problem
+ attributes:
+ label: Describe your problem
+ description: Please describe your problem this improvement is supposed to solve.
+ placeholder: Describe the problem you have
+ validations:
+ required: true
+ - type: textarea
+ id: solution
+ attributes:
+ label: Describe your ideal solution
+ description: Which solution do you propose?
+ placeholder: As a [type of user], I want [some goal] so that [some reason].
+ validations:
+ required: true
+ - type: input
+ id: version
+ attributes:
+ label: Version
+ description: Which version of the OIDC Library are you using.
+ - type: dropdown
+ id: environment
+ attributes:
+ label: Environment
+ description: How do you use ZITADEL?
+ options:
+ - ZITADEL Cloud
+ - Self-hosted
+ validations:
+ required: true
+ - type: textarea
+ id: additional
+ attributes:
+ label: Additional Context
+ description: Please add any other infos that could be useful.
diff --git a/.forgejo.bak/dependabot.yml b/.forgejo.bak/dependabot.yml
new file mode 100644
index 0000000..1efdcf8
--- /dev/null
+++ b/.forgejo.bak/dependabot.yml
@@ -0,0 +1,25 @@
+version: 2
+updates:
+- package-ecosystem: gomod
+ directory: "/"
+ schedule:
+ interval: daily
+ time: '04:00'
+ open-pull-requests-limit: 10
+ commit-message:
+ prefix: chore
+ include: scope
+- package-ecosystem: gomod
+ target-branch: "2.12.x"
+ directory: "/"
+ schedule:
+ interval: daily
+ time: '04:00'
+ open-pull-requests-limit: 10
+ commit-message:
+ prefix: chore
+ include: scope
+- package-ecosystem: "github-actions"
+ directory: "/"
+ schedule:
+ interval: weekly
\ No newline at end of file
diff --git a/.forgejo.bak/pull_request_template.md b/.forgejo.bak/pull_request_template.md
new file mode 100644
index 0000000..6c4ae58
--- /dev/null
+++ b/.forgejo.bak/pull_request_template.md
@@ -0,0 +1,16 @@
+### Definition of Ready
+
+- [ ] I am happy with the code
+- [ ] Short description of the feature/issue is added in the pr description
+- [ ] PR is linked to the corresponding user story
+- [ ] Acceptance criteria are met
+- [ ] All open todos and follow ups are defined in a new ticket and justified
+- [ ] Deviations from the acceptance criteria and design are agreed with the PO and documented.
+- [ ] No debug or dead code
+- [ ] My code has no repetitions
+- [ ] Critical parts are tested automatically
+- [ ] Where possible E2E tests are implemented
+- [ ] Documentation/examples are up-to-date
+- [ ] All non-functional requirements are met
+- [ ] Functionality of the acceptance criteria is checked manually on the dev system.
+
diff --git a/.github/workflows/codeql-analysis.yml b/.forgejo.bak/workflows/codeql-analysis.yml
similarity index 86%
rename from .github/workflows/codeql-analysis.yml
rename to .forgejo.bak/workflows/codeql-analysis.yml
index 0101ea5..27fa244 100644
--- a/.github/workflows/codeql-analysis.yml
+++ b/.forgejo.bak/workflows/codeql-analysis.yml
@@ -2,10 +2,10 @@ name: "Code scanning - action"
on:
push:
- branches: [master, ]
+ branches: [main,next]
pull_request:
# The branches below must be a subset of the branches above
- branches: [master]
+ branches: [main,next]
schedule:
- cron: '0 11 * * 0'
@@ -16,7 +16,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
with:
# We must fetch at least the immediate parents so that if this is
# a pull request then we can checkout the head.
@@ -29,7 +29,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
- uses: github/codeql-action/init@v1
+ uses: github/codeql-action/init@v3
# Override language selection by uncommenting this and choosing your languages
with:
languages: go
@@ -37,7 +37,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
- uses: github/codeql-action/autobuild@v1
+ uses: github/codeql-action/autobuild@v3
# âšī¸ Command-line programs to run using the OS shell.
# đ https://git.io/JvXDl
@@ -51,4 +51,4 @@ jobs:
# make release
- name: Perform CodeQL Analysis
- uses: github/codeql-action/analyze@v1
+ uses: github/codeql-action/analyze@v3
diff --git a/.forgejo.bak/workflows/issue.yml b/.forgejo.bak/workflows/issue.yml
new file mode 100644
index 0000000..480c339
--- /dev/null
+++ b/.forgejo.bak/workflows/issue.yml
@@ -0,0 +1,43 @@
+name: Add new issues to product management project
+
+on:
+ issues:
+ types:
+ - opened
+ pull_request_target:
+ types:
+ - opened
+
+jobs:
+ add-to-project:
+ name: Add issue and community pr to project
+ runs-on: ubuntu-latest
+ steps:
+ - name: add issue
+ uses: actions/add-to-project@v1.0.2
+ if: ${{ github.event_name == 'issues' }}
+ with:
+ # You can target a repository in a different organization
+ # to the issue
+ project-url: https://github.com/orgs/zitadel/projects/2
+ github-token: ${{ secrets.ADD_TO_PROJECT_PAT }}
+ - uses: tspascoal/get-user-teams-membership@v3
+ id: checkUserMember
+ if: github.actor != 'dependabot[bot]'
+ with:
+ username: ${{ github.actor }}
+ GITHUB_TOKEN: ${{ secrets.ADD_TO_PROJECT_PAT }}
+ - name: add pr
+ uses: actions/add-to-project@v1.0.2
+ if: ${{ github.event_name == 'pull_request_target' && github.actor != 'dependabot[bot]' && !contains(steps.checkUserMember.outputs.teams, 'engineers')}}
+ with:
+ # You can target a repository in a different organization
+ # to the issue
+ project-url: https://github.com/orgs/zitadel/projects/2
+ github-token: ${{ secrets.ADD_TO_PROJECT_PAT }}
+ - uses: actions-ecosystem/action-add-labels@v1.1.3
+ if: ${{ github.event_name == 'pull_request_target' && github.actor != 'dependabot[bot]' && !contains(steps.checkUserMember.outputs.teams, 'staff')}}
+ with:
+ github_token: ${{ secrets.ADD_TO_PROJECT_PAT }}
+ labels: |
+ os-contribution
diff --git a/.forgejo.bak/workflows/release.yml b/.forgejo.bak/workflows/release.yml
new file mode 100644
index 0000000..00063e4
--- /dev/null
+++ b/.forgejo.bak/workflows/release.yml
@@ -0,0 +1,49 @@
+name: Release
+on:
+ push:
+ branches:
+ - "2.11.x"
+ - main
+ - next
+ tags-ignore:
+ - '**'
+ pull_request:
+ branches:
+ - '**'
+ workflow_dispatch:
+
+jobs:
+ test:
+ runs-on: ubuntu-24.04
+ strategy:
+ fail-fast: false
+ matrix:
+ go: ['1.23', '1.24']
+ name: Go ${{ matrix.go }} test
+ steps:
+ - uses: actions/checkout@v4
+ - name: Setup go
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ matrix.go }}
+ - run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/...
+ - uses: codecov/codecov-action@v5.4.3
+ with:
+ file: ./profile.cov
+ name: codecov-go
+ release:
+ runs-on: ubuntu-24.04
+ needs: [test]
+ if: ${{ github.event_name == 'workflow_dispatch' || github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next' }}
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ steps:
+ - name: Source checkout
+ uses: actions/checkout@v4
+ - name: Semantic Release
+ uses: cycjimmy/semantic-release-action@v4
+ with:
+ dry_run: false
+ semantic_version: 18.0.1
+ extra_plugins: |
+ @semantic-release/exec@6.0.3
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
deleted file mode 100644
index 3d24fe9..0000000
--- a/.github/dependabot.yml
+++ /dev/null
@@ -1,11 +0,0 @@
-version: 2
-updates:
-- package-ecosystem: gomod
- directory: "/"
- schedule:
- interval: daily
- time: '04:00'
- open-pull-requests-limit: 10
- commit-message:
- prefix: chore
- include: scope
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
deleted file mode 100644
index 42c3ab0..0000000
--- a/.github/workflows/release.yml
+++ /dev/null
@@ -1,33 +0,0 @@
-name: Release
-on: push
-jobs:
- test:
- runs-on: ubuntu-18.04
- strategy:
- matrix:
- go: ['1.11', '1.12', '1.13', '1.14']
- name: Go ${{ matrix.go }} test
- steps:
- - uses: actions/checkout@v2
- - name: Setup go
- uses: actions/setup-go@v2-beta
- with:
- go-version: ${{ matrix.go }}
- - run: go test -race -v -coverprofile=profile.cov ./pkg/...
- - uses: codecov/codecov-action@v1
- with:
- file: ./profile.cov
- name: codecov-go
- release:
- runs-on: ubuntu-18.04
- needs: [test]
- env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- steps:
- - name: Source checkout
- uses: actions/checkout@v1
- with:
- fetch-depth: 1
- - name: Create Version
- uses: caos/semantic-release@v0.2.4
-
diff --git a/.releaserc.js b/.releaserc.js
index d9c7f99..c87b1d1 100644
--- a/.releaserc.js
+++ b/.releaserc.js
@@ -1,8 +1,12 @@
module.exports = {
- branch: 'master',
+ branches: [
+ {name: "2.11.x"},
+ {name: "main"},
+ {name: "next", prerelease: true},
+ ],
plugins: [
"@semantic-release/commit-analyzer",
"@semantic-release/release-notes-generator",
"@semantic-release/github"
]
- };
\ No newline at end of file
+};
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000..b107ae4
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,128 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+ overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+ advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+ address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+abuse@zitadel.ch.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..8861b9c
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,40 @@
+# How to contribute to the OIDC SDK for Go
+
+## Did you find a bug?
+
+Please file an issue [here](https://github.com/zitadel/oidc/issues/new?assignees=&labels=bug&template=bug_report.md&title=).
+
+Bugs are evaluated every day as soon as possible.
+
+## Enhancement
+
+Do you miss a feature? Please file an issue [here](https://github.com/zitadel/oidc/issues/new?assignees=&labels=enhancement&template=feature_request.md&title=)
+
+Enhancements are discussed and evaluated every Wednesday by the ZITADEL core team.
+
+## Grab an Issues
+
+We add the label "good first issue" for problems we think are a good starting point to contribute to the OIDC SDK.
+
+* [Issues for first time contributors](https://github.com/zitadel/oidc/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22)
+* [All issues](https://github.com/zitadel/oidc/issues)
+
+### Make a PR
+
+If you like to contribute fork the OIDC repository. After you implemented the new feature create a PullRequest in the OIDC reposiotry.
+
+Make sure you use semantic release:
+
+* feat: New Feature
+* fix: Bug Fix
+* docs: Documentation
+
+## Want to use the library?
+
+Checkout the [examples folder](example) for different client and server implementations.
+
+Or checkout how we use it ourselves in our OpenSource Identity and Access Management [ZITADEL](https://github.com/zitadel/zitadel).
+
+## **Did you find a security flaw?**
+
+* Please read [Security Policy](SECURITY.md).
\ No newline at end of file
diff --git a/NOTICE b/NOTICE
new file mode 100644
index 0000000..a5f5f7a
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1 @@
+Copyright The zitadel/oidc Contributors
diff --git a/README.md b/README.md
index c1d7919..bc346f5 100644
--- a/README.md
+++ b/README.md
@@ -1,58 +1,192 @@
# OpenID Connect SDK (client and server) for Go
[](https://github.com/semantic-release/semantic-release)
-[](https://github.com/caos/oidc/actions)
-[](https://github.com/caos/oidc/blob/master/LICENSE)
-[](https://github.com/caos/oidc/releases)
-[](https://goreportcard.com/report/github.com/caos/oidc)
-[](https://codecov.io/gh/caos/oidc)
+[](https://github.com/zitadel/oidc/actions)
+[](https://pkg.go.dev/github.com/zitadel/oidc/v3)
+[](https://github.com/zitadel/oidc/blob/master/LICENSE)
+[](https://github.com/zitadel/oidc/releases)
+[](https://goreportcard.com/report/github.com/zitadel/oidc/v3)
+[](https://codecov.io/gh/zitadel/oidc)
-> This project is in alpha state. It can AND will continue breaking until version 1.0.0 is released
+[](https://openid.net/certification/)
## What Is It
-This project is a easy to use client and server implementation for the `OIDC` (Open ID Connect) standard written for `Go`.
+This project is an easy-to-use client (RP) and server (OP) implementation for the `OIDC` (OpenID Connect) standard written for `Go`.
+
+The RP is certified for the [basic](https://www.certification.openid.net/plan-detail.html?public=true&plan=uoprP0OO8Z4Qo) and [config](https://www.certification.openid.net/plan-detail.html?public=true&plan=AYSdLbzmWbu9X) profile.
Whenever possible we tried to reuse / extend existing packages like `OAuth2 for Go`.
+## Basic Overview
+
+The most important packages of the library:
+
+
+/pkg
+ /client clients using the OP for retrieving, exchanging and verifying tokens
+ /rp definition and implementation of an OIDC Relying Party (client)
+ /rs definition and implementation of an OAuth Resource Server (API)
+ /op definition and implementation of an OIDC OpenID Provider (server)
+ /oidc definitions shared by clients and server
+
+/example
+ /client/api example of an api / resource server implementation using token introspection
+ /client/app web app / RP demonstrating authorization code flow using various authentication methods (code, PKCE, JWT profile)
+ /client/github example of the extended OAuth2 library, providing an HTTP client with a reuse token source
+ /client/service demonstration of JWT Profile Authorization Grant
+ /server examples of an OpenID Provider implementations (including dynamic) with some very basic login UI
+
+
+### Semver
+
+This package uses [semver](https://semver.org/) for [releases](https://github.com/zitadel/oidc/releases). Major releases ship breaking changes. Starting with the `v2` to `v3` increment we provide an [upgrade guide](UPGRADING.md) to ease migration to a newer version.
+
## How To Use It
-TBD
+Check the `/example` folder where example code for different scenarios is located.
+
+```bash
+# start oidc op server
+# oidc discovery http://localhost:9998/.well-known/openid-configuration
+go run github.com/zitadel/oidc/v3/example/server
+# start oidc web client (in a new terminal)
+CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://localhost:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app
+```
+
+- open http://localhost:9999/login in your browser
+- you will be redirected to op server and the login UI
+- login with user `test-user@localhost` and password `verysecure`
+- the OP will redirect you to the client app, which displays the user info
+
+for the dynamic issuer, just start it with:
+
+```bash
+go run github.com/zitadel/oidc/v3/example/server/dynamic
+```
+
+the oidc web client above will still work, but if you add `oidc.local` (pointing to 127.0.0.1) in your hosts file you can also start it with:
+
+```bash
+CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid profile" PORT=9999 go run github.com/zitadel/oidc/v3/example/client/app
+```
+
+> Note: Usernames are suffixed with the hostname (`test-user@localhost` or `test-user@oidc.local`)
+
+### Server configuration
+
+Example server allows extra configuration using environment variables and could be used for end to
+end testing of your services.
+
+| Name | Format | Description |
+| ------------ | -------------------------------- | ------------------------------------- |
+| PORT | Number between 1 and 65535 | OIDC listen port |
+| REDIRECT_URI | Comma-separated URIs | List of allowed redirect URIs |
+| USERS_FILE | Path to json in local filesystem | Users with their data and credentials |
+
+Here is json equivalent for one of the default users
+
+```json
+{
+ "id2": {
+ "ID": "id2",
+ "Username": "test-user2",
+ "Password": "verysecure",
+ "FirstName": "Test",
+ "LastName": "User2",
+ "Email": "test-user2@zitadel.ch",
+ "EmailVerified": true,
+ "Phone": "",
+ "PhoneVerified": false,
+ "PreferredLanguage": "DE",
+ "IsAdmin": false
+ }
+}
+```
## Features
-| | Code Flow | Implicit Flow | Hybrid Flow | Discovery | PKCE | Token Exchange | mTLS |
-|----------------|-----------|---------------|-------------|-----------|------|----------------|---------|
-| Relaying Party | yes | yes | not yet | yes | yes | partial | not yet |
-| Origin Party | yes | yes | not yet | yes | yes | not yet | not yet |
+| | Relying party | OpenID Provider | Specification |
+| -------------------- | ------------- | --------------- | -------------------------------------------- |
+| Code Flow | yes | yes | OpenID Connect Core 1.0, [Section 3.1][1] |
+| Implicit Flow | no[^1] | yes | OpenID Connect Core 1.0, [Section 3.2][2] |
+| Hybrid Flow | no | not yet | OpenID Connect Core 1.0, [Section 3.3][3] |
+| Client Credentials | yes | yes | OpenID Connect Core 1.0, [Section 9][4] |
+| Refresh Token | yes | yes | OpenID Connect Core 1.0, [Section 12][5] |
+| Discovery | yes | yes | OpenID Connect [Discovery][6] 1.0 |
+| JWT Profile | yes | yes | [RFC 7523][7] |
+| PKCE | yes | yes | [RFC 7636][8] |
+| Token Exchange | yes | yes | [RFC 8693][9] |
+| Device Authorization | yes | yes | [RFC 8628][10] |
+| mTLS | not yet | not yet | [RFC 8705][11] |
+| Back-Channel Logout | not yet | yes | OpenID Connect [Back-Channel Logout][12] 1.0 |
+
+[1]: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth "3.1. Authentication using the Authorization Code Flow"
+[2]: https://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth "3.2. Authentication using the Implicit Flow"
+[3]: https://openid.net/specs/openid-connect-core-1_0.html#HybridFlowAuth "3.3. Authentication using the Hybrid Flow"
+[4]: https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication "9. Client Authentication"
+[5]: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens "12. Using Refresh Tokens"
+[6]: https://openid.net/specs/openid-connect-discovery-1_0.html "OpenID Connect Discovery 1.0 incorporating errata set 1"
+[7]: https://www.rfc-editor.org/rfc/rfc7523.html "JSON Web Token (JWT) Profile for OAuth 2.0 Client Authentication and Authorization Grants"
+[8]: https://www.rfc-editor.org/rfc/rfc7636.html "Proof Key for Code Exchange by OAuth Public Clients"
+[9]: https://www.rfc-editor.org/rfc/rfc8693.html "OAuth 2.0 Token Exchange"
+[10]: https://www.rfc-editor.org/rfc/rfc8628.html "OAuth 2.0 Device Authorization Grant"
+[11]: https://www.rfc-editor.org/rfc/rfc8705.html "OAuth 2.0 Mutual-TLS Client Authentication and Certificate-Bound Access Tokens"
+[12]: https://openid.net/specs/openid-connect-backchannel-1_0.html "OpenID Connect Back-Channel Logout 1.0 incorporating errata set 1"
+
+## Contributors
+
+
+
+
+
+Made with [contrib.rocks](https://contrib.rocks).
### Resources
-For your convinience you can find the relevant standards linked below.
+For your convenience you can find the relevant guides linked below.
- [OpenID Connect Core 1.0 incorporating errata set 1](https://openid.net/specs/openid-connect-core-1_0.html)
-- [Proof Key for Code Exchange by OAuth Public Clients](https://tools.ietf.org/html/rfc7636)
-- [OAuth 2.0 Token Exchange](https://tools.ietf.org/html/draft-ietf-oauth-token-exchange-19)
-- [OAuth 2.0 Mutual-TLS Client Authentication and Certificate-Bound Access Tokens](https://tools.ietf.org/html/draft-ietf-oauth-mtls-17)
+- [OIDC/OAuth Flow in Zitadel (using this library)](https://zitadel.com/docs/guides/integrate/login-users)
## Supported Go Versions
+For security reasons, we only support and recommend the use of one of the latest two Go versions (:white_check_mark:).
+Versions that also build are marked with :warning:.
+
| Version | Supported |
-|---------|--------------------|
-| <1.11 | :x: |
-| 1.11 | :white_check_mark: |
-| 1.12 | :white_check_mark: |
-| 1.13 | :white_check_mark: |
-| 1.14 | :white_check_mark: |
+| ------- | ------------------ |
+| <1.23 | :x: |
+| 1.23 | :white_check_mark: |
+| 1.24 | :white_check_mark: |
## Why another library
-As of 2020 there are not a lot of `OIDC` librarys in `Go` which can handle server and client implementations. CAOS is strongly commited to the general field of IAM (Identity and Access Management) and as such, we need solid frameworks to implement services.
+As of 2020 there are not a lot of `OIDC` library's in `Go` which can handle server and client implementations. ZITADEL is strongly committed to the general field of IAM (Identity and Access Management) and as such, we need solid frameworks to implement services.
+
+### Goals
+
+- [Certify this library as OP](https://openid.net/certification/#OPs)
+
+### Other Go OpenID Connect libraries
+
+[https://github.com/coreos/go-oidc](https://github.com/coreos/go-oidc)
+
+The `go-oidc` does only support `RP` and is not feasible to use as `OP` that's why we could not rely on `go-oidc`
+
+[https://github.com/ory/fosite](https://github.com/ory/fosite)
+
+We did not choose `fosite` because it implements `OAuth 2.0` on its own and does not rely on the golang provided package. Nonetheless this is a great project.
## License
-The full functionality of this library is and stays open source and free to use for everyone. Visit our [website](https://caos.ch) and get in touch.
+The full functionality of this library is and stays open source and free to use for everyone. Visit
+our [website](https://zitadel.com) and get in touch.
-See the exact licensing terms [here](./LICENSE)
+See the exact licensing terms [here](LICENSE)
-Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "
+AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
+language governing permissions and limitations under the License.
+
+[^1]: https://github.com/zitadel/oidc/issues/135#issuecomment-950563892
diff --git a/SECURITY.md b/SECURITY.md
index 6fe2daa..a32b842 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -1,42 +1,20 @@
# Security Policy
-At CAOS we are extremely grateful for security aware people that disclose vulnerabilities to us and the open source community. All reports will be investigated by our team.
+Please refer to the security policy [on zitadel/zitadel](https://github.com/zitadel/zitadel/blob/main/SECURITY.md) which is applicable for all open source repositories of our organization.
## Supported Versions
-After the initial Release the following version support will apply
+We currently support the following version of the OIDC framework:
-| Version | Supported |
-| ------- | ------------------ |
-| 1.x.x | :white_check_mark: (not yet available) |
-| 0.x.x | :x: |
+| Version | Supported | Branch | Details |
+| -------- | ------------------ | ----------- | ------------------------------------ |
+| 0.x.x | :x: | | not maintained |
+| <2.11 | :x: | | not maintained |
+| 2.11.x | :lock: :warning: | [2.11.x][1] | security only, [community effort][2] |
+| 3.x.x | :heavy_check_mark: | [main][3] | supported |
+| 4.0.0-xx | :white_check_mark: | [next][4] | [development branch] |
-## Reporting a vulnerability
-
-To file a incident, please disclose by email to security@caos.ch with the security details.
-
-At the moment GPG encryption is no yet supported, however you may sign your message at will.
-
-### When should I report a vulnerability
-
-* You think you discovered a ...
- * ... potential security vulnerability in the SDK
- * ... vulnerability in another project that this SDK bases on
-* For projects with their own vulnerability reporting and disclosure process, please report it directly there
-
-### When should I NOT report a vulnerability
-
-* You need help applying security related updates
-* Your issue is not security related
-
-## Security Vulnerability Response
-
-TBD
-
-## Public Disclosure
-
-All accepted and mitigated vulnerabilitys will be published on the [Github Security Page](https://github.com/caos/oidc/security/advisories)
-
-### Timing
-
-We think it is crucial to publish advisories `ASAP` as mitigations are ready. But due to the unknown nature of the discloures the time frame can range from 7 to 90 days.
+[1]: https://github.com/zitadel/oidc/tree/2.11.x
+[2]: https://github.com/zitadel/oidc/discussions/458
+[3]: https://github.com/zitadel/oidc/tree/main
+[4]: https://github.com/zitadel/oidc/tree/next
diff --git a/UPGRADING.md b/UPGRADING.md
new file mode 100644
index 0000000..6b5a41d
--- /dev/null
+++ b/UPGRADING.md
@@ -0,0 +1,370 @@
+# Upgrading
+
+All commands are executed from the root of the project that imports oidc packages.
+`sed` commands are created with **GNU sed** in mind and might need alternate syntax
+on non-GNU systems, such as MacOS.
+Alternatively, GNU sed can be installed on such systems. (`coreutils` package?).
+
+## V2 to V3
+
+**TL;DR** at the [bottom](#full-script) of this chapter is a full `sed` script
+containing all automatic steps at once.
+
+
+As first steps we will:
+1. Download the latest v3 module;
+2. Replace imports in all Go files;
+3. Tidy the module file;
+
+```bash
+go get -u github.com/zitadel/oidc/v3
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/github\.com\/zitadel\/oidc\/v2/github.com\/zitadel\/oidc\/v3/g'
+go mod tidy
+```
+
+### global
+
+#### go-jose package
+
+`gopkg.in/square/go-jose.v2` import has been changed to `github.com/go-jose/go-jose/v3`.
+That means that the imported types are also changed and imports need to be adapted.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/gopkg.in\/square\/go-jose\.v2/github.com\/go-jose\/go-jose\/v3/g'
+go mod tidy
+```
+
+### op
+
+```go
+import "github.com/zitadel/oidc/v3/pkg/op"
+```
+
+#### Logger
+
+This version of OIDC adds logging to the framework. For this we use the new Go standard library `log/slog`. (Until v3.12.0 we used `x/exp/slog`).
+Mostly OIDC will use error level logs where it's returning an error through a HTTP handler. OIDC errors that are user facing don't carry much context, also for security reasons. With logging we are now able to print the error context, so that developers can more easily find the source of their issues. Previously we just discarded such context.
+
+Most users of the OP package with the storage interface will not experience breaking changes. However if you use `RequestError()` directly in your code, you now need to give it a `Logger` as final argument.
+
+The `OpenIDProvider` and sub-interfaces like `Authorizer` and `Exchanger` got a `Logger()` method to return the configured logger. This logger is in turn used by `AuthRequestError()`. You configure the logger with the `WithLogger()` for the `Provider`. By default the `slog.Default()` is used.
+
+We also provide a new optional interface: [`LogAuthRequest`](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/op#LogAuthRequest). If an `AuthRequest` implements this interface, it is completely passed into the logger after an error. Its `LogValue()` will be used by `slog` to print desired fields. This allows omitting sensitive fields you wish not no print. If the interface is not implemented, no `AuthRequest` details will ever be printed.
+
+#### Server interface
+
+We've added a new [`Server`](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/op#Server) interface. This interface is experimental and subject to change. See [issue 440](https://github.com/zitadel/oidc/issues/440) for the motivation and discussion around this new interface.
+Usage of the new interface is not required, but may be used for advanced scenarios when working with the `Storage` interface isn't the optimal solution for your app (like we experienced in [Zitadel](https://github.com/zitadel/zitadel)).
+
+#### AuthRequestError
+
+`AuthRequestError` now takes the complete `Authorizer` as final argument, instead of only the encoder.
+This is to facilitate the use of the `Logger` as described above.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/\bAuthRequestError(w, r, authReq, err, authorizer.Encoder())/AuthRequestError(w, r, authReq, err, authorizer)/g'
+```
+
+Note: the sed regex might not find all uses if the local variables of the passed arguments use different names.
+
+#### AccessTokenVerifier
+
+`AccessTokenVerifier` interface has become a struct type. `NewAccessTokenVerifier` now returns a pointer to `AccessTokenVerifier`.
+Variable and struct fields declarations need to be changed from `op.AccessTokenVerifier` to `*op.AccessTokenVerifier`.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/\bop\.AccessTokenVerifier\b/*op.AccessTokenVerifier/g'
+```
+
+#### JWTProfileVerifier
+
+`JWTProfileVerifier` interface has become a struct type. `NewJWTProfileVerifier` now returns a pointer to `JWTProfileVerifier`.
+Variable and struct fields declarations need to be changed from `op.JWTProfileVerifier` to `*op.JWTProfileVerifier`.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/\bop\.JWTProfileVerifier\b/*op.JWTProfileVerifier/g'
+```
+
+#### IDTokenHintVerifier
+
+`IDTokenHintVerifier` interface has become a struct type. `NewIDTokenHintVerifier` now returns a pointer to `IDTokenHintVerifier`.
+Variable and struct fields declarations need to be changed from `op.IDTokenHintVerifier` to `*op.IDTokenHintVerifier`.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/\bop\.IDTokenHintVerifier\b/*op.IDTokenHintVerifier/g'
+```
+
+#### ParseRequestObject
+
+`ParseRequestObject` no longer returns `*oidc.AuthRequest` as it already operates on the pointer for the passed `authReq` argument. As such the argument and the return value were the same pointer. Callers can just use the original `*oidc.AuthRequest` now.
+
+#### Endpoint Configuration
+
+`Endpoint`s returned from `Configuration` interface methods are now pointers. Usually, `op.Provider` is the main implementation of the `Configuration` interface. However, if a custom implementation is used, you should be able to update it using the following:
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/AuthorizationEndpoint() Endpoint/AuthorizationEndpoint() *Endpoint/g' \
+ -e 's/TokenEndpoint() Endpoint/TokenEndpoint() *Endpoint/g' \
+ -e 's/IntrospectionEndpoint() Endpoint/IntrospectionEndpoint() *Endpoint/g' \
+ -e 's/UserinfoEndpoint() Endpoint/UserinfoEndpoint() *Endpoint/g' \
+ -e 's/RevocationEndpoint() Endpoint/RevocationEndpoint() *Endpoint/g' \
+ -e 's/EndSessionEndpoint() Endpoint/EndSessionEndpoint() *Endpoint/g' \
+ -e 's/KeysEndpoint() Endpoint/KeysEndpoint() *Endpoint/g' \
+ -e 's/DeviceAuthorizationEndpoint() Endpoint/DeviceAuthorizationEndpoint() *Endpoint/g'
+```
+
+#### CreateDiscoveryConfig
+
+`CreateDiscoveryConfig` now takes a context as first argument. The following adds `context.TODO()` to the function:
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/op\.CreateDiscoveryConfig(/op.CreateDiscoveryConfig(context.TODO(), /g'
+```
+
+It now takes the issuer out of the context using the [`IssuerFromContext`](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/op#IssuerFromContext) functionality,
+instead of the `config.IssuerFromRequest()` method.
+
+#### CreateRouter
+
+`CreateRouter` now returns a `chi.Router` instead of `*mux.Router`.
+Usually this function is called when the Provider is constructed and not by package consumers.
+However if your project does call this function directly, manual update of the code is required.
+
+#### DeviceAuthorizationStorage
+
+`DeviceAuthorizationStorage` dropped the following methods:
+
+- `GetDeviceAuthorizationByUserCode`
+- `CompleteDeviceAuthorization`
+- `DenyDeviceAuthorization`
+
+These methods proved not to be required from a library point of view.
+Implementations of a device authorization flow may take care of these calls in a way they see fit.
+
+#### AuthorizeCodeChallenge
+
+The `AuthorizeCodeChallenge` function now only takes the `CodeVerifier` argument, instead of the complete `*oidc.AccessTokenRequest`.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/op\.AuthorizeCodeChallenge(tokenReq/op.AuthorizeCodeChallenge(tokenReq.CodeVerifier/g'
+```
+
+### client
+
+```go
+import "github.com/zitadel/oidc/v3/pkg/client"
+```
+
+#### Context
+
+All client calls now take a context as first argument. The following adds `context.TODO()` to all the affected functions:
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/client\.Discover(/client.Discover(context.TODO(), /g' \
+ -e 's/client\.CallTokenEndpoint(/client.CallTokenEndpoint(context.TODO(), /g' \
+ -e 's/client\.CallEndSessionEndpoint(/client.CallEndSessionEndpoint(context.TODO(), /g' \
+ -e 's/client\.CallRevokeEndpoint(/client.CallRevokeEndpoint(context.TODO(), /g' \
+ -e 's/client\.CallTokenExchangeEndpoint(/client.CallTokenExchangeEndpoint(context.TODO(), /g' \
+ -e 's/client\.CallDeviceAuthorizationEndpoint(/client.CallDeviceAuthorizationEndpoint(context.TODO(), /g' \
+ -e 's/client\.JWTProfileExchange(/client.JWTProfileExchange(context.TODO(), /g'
+```
+
+#### keyFile type
+
+The `keyFile` struct type is now exported a `KeyFile` and returned by the `ConfigFromKeyFile` and `ConfigFromKeyFileData`. No changes are needed on the caller's side.
+
+### client/profile
+
+The package now defines a new interface `TokenSource` which compliments the `oauth2.TokenSource` with a `TokenCtx` method, so that a context can be explicitly added on each call. Users can migrate to the new method when they whish.
+
+`NewJWTProfileTokenSource` now takes a context as first argument, so do the related `NewJWTProfileTokenSourceFromKeyFile` and `NewJWTProfileTokenSourceFromKeyFileData`. The context is used for the Discovery request.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/profile\.NewJWTProfileTokenSource(/profile.NewJWTProfileTokenSource(context.TODO(), /g' \
+ -e 's/profile\.NewJWTProfileTokenSourceFromKeyFileData(/profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), /g' \
+ -e 's/profile\.NewJWTProfileTokenSourceFromKeyFile(/profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), /g'
+```
+
+
+### client/rp
+
+```go
+import "github.com/zitadel/oidc/v3/pkg/client/rs"
+```
+
+#### Discover
+
+The `Discover` function has been removed. Use `client.Discover` instead.
+
+#### Context
+
+Most `rp` functions now require a context as first argument. The following adds `context.TODO()` to the function that have no additional changes. Functions with more complex changes are documented below.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/rp\.NewRelyingPartyOIDC(/rp.NewRelyingPartyOIDC(context.TODO(), /g' \
+ -e 's/rp\.EndSession(/rp.EndSession(context.TODO(), /g' \
+ -e 's/rp\.RevokeToken(/rp.RevokeToken(context.TODO(), /g' \
+ -e 's/rp\.DeviceAuthorization(/rp.DeviceAuthorization(context.TODO(), /g'
+```
+
+Remember to replace `context.TODO()` with a context that is applicable for your app, where possible.
+
+#### RefreshAccessToken
+
+1. Renamed to `RefreshTokens`;
+2. A context must be passed;
+3. An `*oidc.Tokens` object is now returned, which included an ID Token if it was returned by the server;
+4. The function is now generic and requires a type argument for the `IDTokenClaims` implementation inside the returned `oidc.Tokens` object;
+
+For most use cases `*oidc.IDTokenClaims` can be used as type argument. A custom implementation of `oidc.IDClaims` can be used if type-safe access to custom claims is required.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/rp\.RefreshAccessToken(/rp.RefreshTokens[*oidc.IDTokenClaims](context.TODO(), /g'
+```
+
+Users that called `tokens.Extra("id_token").(string)` and a subsequent `VerifyTokens` to get the claims, no longer need to do this. The ID token is verified (when present) by `RefreshTokens` already.
+
+
+#### Userinfo
+
+1. A context must be passed as first argument;
+2. The function is now generic and requires a type argument for the returned user info object;
+
+For most use cases `*oidc.UserInfo` can be used a type argument. A [custom implementation](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/client/rp#example-Userinfo-Custom) of `rp.SubjectGetter` can be used if type-safe access to custom claims is required.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/rp\.Userinfo(/rp.Userinfo[*oidc.UserInfo](context.TODO(), /g'
+```
+
+#### UserinfoCallback
+
+`UserinfoCallback` has an additional type argument fot the `UserInfo` object. Typically the type argument can be inferred by the compiler, by the function that is passed. The actual code update cannot be done by a simple `sed` script and depends on how the caller implemented the function.
+
+
+#### IDTokenVerifier
+
+`IDTokenVerifier` interface has become a struct type. `NewIDTokenVerifier` now returns a pointer to `IDTokenVerifier`.
+Variable and struct fields declarations need to be changed from `rp.IDTokenVerifier` to `*rp.AccessTokenVerifier`.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/\brp\.IDTokenVerifier\b/*rp.IDTokenVerifier/g'
+```
+
+### client/rs
+
+```go
+import "github.com/zitadel/oidc/v3/pkg/client/rs"
+```
+
+#### NewResourceServer
+
+The `NewResourceServerClientCredentials` and `NewResourceServerJWTProfile` constructor functions now take a context as first argument.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/rs\.NewResourceServerClientCredentials(/rs.NewResourceServerClientCredentials(context.TODO(), /g' \
+ -e 's/rs\.NewResourceServerJWTProfile(/rs.NewResourceServerJWTProfile(context.TODO(), /g'
+```
+
+#### Introspect
+
+`Introspect` is now generic and requires a type argument for the returned introspection response. For most use cases `*oidc.IntrospectionResponse` can be used as type argument. Any other response type if type-safe access to [custom claims](https://pkg.go.dev/github.com/zitadel/oidc/v3/pkg/client/rs#example-Introspect-Custom) is required.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/rs\.Introspect(/rs.Introspect[*oidc.IntrospectionResponse](/g'
+```
+
+### client/tokenexchange
+
+The `TokenExchanger` constructor functions `NewTokenExchanger` and `NewTokenExchangerClientCredentials` now take a context as first argument.
+As well as the `ExchangeToken` function.
+
+```bash
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/tokenexchange\.NewTokenExchanger(/tokenexchange.NewTokenExchanger(context.TODO(), /g' \
+ -e 's/tokenexchange\.NewTokenExchangerClientCredentials(/tokenexchange.NewTokenExchangerClientCredentials(context.TODO(), /g' \
+ -e 's/tokenexchange\.ExchangeToken(/tokenexchange.ExchangeToken(context.TODO(), /g'
+```
+
+### oidc
+
+#### SpaceDelimitedArray
+
+The `SpaceDelimitedArray` type's `Encode()` function has been renamed to `String()` so it implements the `fmt.Stringer` interface. If the `Encode` method was called by a package consumer, it should be changed manually.
+
+#### Verifier
+
+The `Verifier` interface as been changed into a struct type. The struct type is aliased in the `op` and `rp` packages for the specific token use cases. See the relevant section above.
+
+### Full script
+
+For the courageous this is the full `sed` script which combines all the steps described above.
+It should migrate most of the code in a repository to a more-or-less compilable state,
+using defaults such as `context.TODO()` where possible.
+
+Warnings:
+- Again, this is written for **GNU sed** not the posix variant.
+- Assumes imports that use the package names, not aliases.
+- Do this on a project with version control (eg Git), that allows you to rollback if things went wrong.
+- The script has been tested on the [ZITADEL](https://github.com/zitadel/zitadel) project, but we do not use all affected symbols. Parts of the script are mere guesswork.
+
+```bash
+go get -u github.com/zitadel/oidc/v3
+find . -type f -name '*.go' | xargs sed -i \
+ -e 's/github\.com\/zitadel\/oidc\/v2/github.com\/zitadel\/oidc\/v3/g' \
+ -e 's/gopkg.in\/square\/go-jose\.v2/github.com\/go-jose\/go-jose\/v3/g' \
+ -e 's/\bAuthRequestError(w, r, authReq, err, authorizer.Encoder())/AuthRequestError(w, r, authReq, err, authorizer)/g' \
+ -e 's/\bop\.AccessTokenVerifier\b/*op.AccessTokenVerifier/g' \
+ -e 's/\bop\.JWTProfileVerifier\b/*op.JWTProfileVerifier/g' \
+ -e 's/\bop\.IDTokenHintVerifier\b/*op.IDTokenHintVerifier/g' \
+ -e 's/AuthorizationEndpoint() Endpoint/AuthorizationEndpoint() *Endpoint/g' \
+ -e 's/TokenEndpoint() Endpoint/TokenEndpoint() *Endpoint/g' \
+ -e 's/IntrospectionEndpoint() Endpoint/IntrospectionEndpoint() *Endpoint/g' \
+ -e 's/UserinfoEndpoint() Endpoint/UserinfoEndpoint() *Endpoint/g' \
+ -e 's/RevocationEndpoint() Endpoint/RevocationEndpoint() *Endpoint/g' \
+ -e 's/EndSessionEndpoint() Endpoint/EndSessionEndpoint() *Endpoint/g' \
+ -e 's/KeysEndpoint() Endpoint/KeysEndpoint() *Endpoint/g' \
+ -e 's/DeviceAuthorizationEndpoint() Endpoint/DeviceAuthorizationEndpoint() *Endpoint/g' \
+ -e 's/op\.CreateDiscoveryConfig(/op.CreateDiscoveryConfig(context.TODO(), /g' \
+ -e 's/op\.AuthorizeCodeChallenge(tokenReq/op.AuthorizeCodeChallenge(tokenReq.CodeVerifier/g' \
+ -e 's/client\.Discover(/client.Discover(context.TODO(), /g' \
+ -e 's/client\.CallTokenEndpoint(/client.CallTokenEndpoint(context.TODO(), /g' \
+ -e 's/client\.CallEndSessionEndpoint(/client.CallEndSessionEndpoint(context.TODO(), /g' \
+ -e 's/client\.CallRevokeEndpoint(/client.CallRevokeEndpoint(context.TODO(), /g' \
+ -e 's/client\.CallTokenExchangeEndpoint(/client.CallTokenExchangeEndpoint(context.TODO(), /g' \
+ -e 's/client\.CallDeviceAuthorizationEndpoint(/client.CallDeviceAuthorizationEndpoint(context.TODO(), /g' \
+ -e 's/client\.JWTProfileExchange(/client.JWTProfileExchange(context.TODO(), /g' \
+ -e 's/profile\.NewJWTProfileTokenSource(/profile.NewJWTProfileTokenSource(context.TODO(), /g' \
+ -e 's/profile\.NewJWTProfileTokenSourceFromKeyFileData(/profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), /g' \
+ -e 's/profile\.NewJWTProfileTokenSourceFromKeyFile(/profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), /g' \
+ -e 's/rp\.NewRelyingPartyOIDC(/rp.NewRelyingPartyOIDC(context.TODO(), /g' \
+ -e 's/rp\.EndSession(/rp.EndSession(context.TODO(), /g' \
+ -e 's/rp\.RevokeToken(/rp.RevokeToken(context.TODO(), /g' \
+ -e 's/rp\.DeviceAuthorization(/rp.DeviceAuthorization(context.TODO(), /g' \
+ -e 's/rp\.RefreshAccessToken(/rp.RefreshTokens[*oidc.IDTokenClaims](context.TODO(), /g' \
+ -e 's/rp\.Userinfo(/rp.Userinfo[*oidc.UserInfo](context.TODO(), /g' \
+ -e 's/\brp\.IDTokenVerifier\b/*rp.IDTokenVerifier/g' \
+ -e 's/rs\.NewResourceServerClientCredentials(/rs.NewResourceServerClientCredentials(context.TODO(), /g' \
+ -e 's/rs\.NewResourceServerJWTProfile(/rs.NewResourceServerJWTProfile(context.TODO(), /g' \
+ -e 's/rs\.Introspect(/rs.Introspect[*oidc.IntrospectionResponse](/g' \
+ -e 's/tokenexchange\.NewTokenExchanger(/tokenexchange.NewTokenExchanger(context.TODO(), /g' \
+ -e 's/tokenexchange\.NewTokenExchangerClientCredentials(/tokenexchange.NewTokenExchangerClientCredentials(context.TODO(), /g' \
+ -e 's/tokenexchange\.ExchangeToken(/tokenexchange.ExchangeToken(context.TODO(), /g'
+go mod tidy
+```
\ No newline at end of file
diff --git a/example/client/api/api.go b/example/client/api/api.go
index 6e1b0bd..69f9466 100644
--- a/example/client/api/api.go
+++ b/example/client/api/api.go
@@ -1,90 +1,104 @@
package main
-// import (
-// "encoding/json"
-// "fmt"
-// "log"
-// "net/http"
-// "os"
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+ "strings"
+ "time"
-// "github.com/caos/oidc/pkg/oidc"
-// "github.com/caos/oidc/pkg/oidc/rp"
-// "github.com/caos/utils/logging"
-// )
+ "github.com/go-chi/chi/v5"
+ "github.com/sirupsen/logrus"
-// const (
-// publicURL string = "/public"
-// protectedURL string = "/protected"
-// protectedExchangeURL string = "/protected/exchange"
-// )
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rs"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+const (
+ publicURL string = "/public"
+ protectedURL string = "/protected"
+ protectedClaimURL string = "/protected/{claim}/{value}"
+)
func main() {
- // clientID := os.Getenv("CLIENT_ID")
- // clientSecret := os.Getenv("CLIENT_SECRET")
- // issuer := os.Getenv("ISSUER")
- // port := os.Getenv("PORT")
+ keyPath := os.Getenv("KEY")
+ port := os.Getenv("PORT")
+ issuer := os.Getenv("ISSUER")
- // // ctx := context.Background()
+ provider, err := rs.NewResourceServerFromKeyFile(context.TODO(), issuer, keyPath)
+ if err != nil {
+ logrus.Fatalf("error creating provider %s", err.Error())
+ }
- // providerConfig := &oidc.ProviderConfig{
- // ClientID: clientID,
- // ClientSecret: clientSecret,
- // Issuer: issuer,
- // }
- // provider, err := rp.NewDefaultProvider(providerConfig)
- // logging.Log("APP-nx6PeF").OnError(err).Panic("error creating provider")
+ router := chi.NewRouter()
- // http.HandleFunc(publicURL, func(w http.ResponseWriter, r *http.Request) {
- // w.Write([]byte("OK"))
- // })
+ // public url accessible without any authorization
+ // will print `OK` and current timestamp
+ router.HandleFunc(publicURL, func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("OK " + time.Now().String()))
+ })
- // http.HandleFunc(protectedURL, func(w http.ResponseWriter, r *http.Request) {
- // ok, token := checkToken(w, r)
- // if !ok {
- // return
- // }
- // resp, err := provider.Introspect(r.Context(), token)
- // if err != nil {
- // http.Error(w, err.Error(), http.StatusForbidden)
- // return
- // }
- // data, err := json.Marshal(resp)
- // if err != nil {
- // http.Error(w, err.Error(), http.StatusInternalServerError)
- // return
- // }
- // w.Write(data)
- // })
+ // protected url which needs an active token
+ // will print the result of the introspection endpoint on success
+ router.HandleFunc(protectedURL, func(w http.ResponseWriter, r *http.Request) {
+ ok, token := checkToken(w, r)
+ if !ok {
+ return
+ }
+ resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusForbidden)
+ return
+ }
+ data, err := json.Marshal(resp)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ w.Write(data)
+ })
- // http.HandleFunc(protectedExchangeURL, func(w http.ResponseWriter, r *http.Request) {
- // ok, token := checkToken(w, r)
- // if !ok {
- // return
- // }
- // tokens, err := provider.DelegationTokenExchange(r.Context(), token, oidc.WithResource([]string{"Test"}))
- // if err != nil {
- // http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
- // return
- // }
+ // protected url which needs an active token and checks if the response of the introspect endpoint
+ // contains a requested claim with the required (string) value
+ // e.g. /protected/username/livio@zitadel.example
+ router.HandleFunc(protectedClaimURL, func(w http.ResponseWriter, r *http.Request) {
+ ok, token := checkToken(w, r)
+ if !ok {
+ return
+ }
+ resp, err := rs.Introspect[*oidc.IntrospectionResponse](r.Context(), provider, token)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusForbidden)
+ return
+ }
+ requestedClaim := chi.URLParam(r, "claim")
+ requestedValue := chi.URLParam(r, "value")
- // data, err := json.Marshal(tokens)
- // if err != nil {
- // http.Error(w, err.Error(), http.StatusInternalServerError)
- // return
- // }
- // w.Write(data)
- // })
+ value, ok := resp.Claims[requestedClaim].(string)
+ if !ok || value == "" || value != requestedValue {
+ http.Error(w, "claim does not match", http.StatusForbidden)
+ return
+ }
+ w.Write([]byte("authorized with value " + value))
+ })
- // lis := fmt.Sprintf("127.0.0.1:%s", port)
- // log.Printf("listening on http://%s/", lis)
- // log.Fatal(http.ListenAndServe(lis, nil))
- // }
-
- // func checkToken(w http.ResponseWriter, r *http.Request) (bool, string) {
- // token := r.Header.Get("authorization")
- // if token == "" {
- // http.Error(w, "Auth header missing", http.StatusUnauthorized)
- // return false, ""
- // }
- // return true, token
+ lis := fmt.Sprintf("127.0.0.1:%s", port)
+ log.Printf("listening on http://%s/", lis)
+ log.Fatal(http.ListenAndServe(lis, router))
+}
+
+func checkToken(w http.ResponseWriter, r *http.Request) (bool, string) {
+ auth := r.Header.Get("authorization")
+ if auth == "" {
+ http.Error(w, "auth header missing", http.StatusUnauthorized)
+ return false, ""
+ }
+ if !strings.HasPrefix(auth, oidc.PrefixBearer) {
+ http.Error(w, "invalid header", http.StatusUnauthorized)
+ return false, ""
+ }
+ return true, strings.TrimPrefix(auth, oidc.PrefixBearer)
}
diff --git a/example/client/app/app.go b/example/client/app/app.go
index f1b99d7..90b1969 100644
--- a/example/client/app/app.go
+++ b/example/client/app/app.go
@@ -4,93 +4,178 @@ import (
"context"
"encoding/json"
"fmt"
+ "log/slog"
"net/http"
"os"
-
- "github.com/sirupsen/logrus"
+ "strings"
+ "sync/atomic"
+ "time"
"github.com/google/uuid"
+ "github.com/sirupsen/logrus"
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/rp"
- "github.com/caos/oidc/pkg/utils"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/zitadel/logging"
)
var (
- callbackPath string = "/auth/callback"
- key []byte = []byte("test1234test1234")
+ callbackPath = "/auth/callback"
+ key = []byte("test1234test1234")
)
func main() {
clientID := os.Getenv("CLIENT_ID")
clientSecret := os.Getenv("CLIENT_SECRET")
+ keyPath := os.Getenv("KEY_PATH")
issuer := os.Getenv("ISSUER")
port := os.Getenv("PORT")
+ scopes := strings.Split(os.Getenv("SCOPES"), " ")
+ responseMode := os.Getenv("RESPONSE_MODE")
- ctx := context.Background()
+ redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
+ cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
- rpConfig := &rp.Config{
- ClientID: clientID,
- ClientSecret: clientSecret,
- Issuer: issuer,
- CallbackURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath),
- Scopes: []string{"openid", "profile", "email"},
+ logger := slog.New(
+ slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
+ AddSource: true,
+ Level: slog.LevelDebug,
+ }),
+ )
+ client := &http.Client{
+ Timeout: time.Minute,
}
- cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
- provider, err := rp.NewDefaultRP(rpConfig, rp.WithCookieHandler(cookieHandler)) //rp.WithPKCE(cookieHandler)) //,
+ // enable outgoing request logging
+ logging.EnableHTTPClient(client,
+ logging.WithClientGroup("client"),
+ )
+
+ options := []rp.Option{
+ rp.WithCookieHandler(cookieHandler),
+ rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
+ rp.WithHTTPClient(client),
+ rp.WithLogger(logger),
+ rp.WithSigningAlgsFromDiscovery(),
+ }
+ if clientSecret == "" {
+ options = append(options, rp.WithPKCE(cookieHandler))
+ }
+ if keyPath != "" {
+ options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
+ }
+
+ // One can add a logger to the context,
+ // pre-defining log attributes as required.
+ ctx := logging.ToContext(context.TODO(), logger)
+ provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, redirectURI, scopes, options...)
if err != nil {
logrus.Fatalf("error creating provider %s", err.Error())
}
- // state := "foobar"
- state := uuid.New().String()
+ // generate some state (representing the state of the user in your application,
+ // e.g. the page where he was before sending him to login
+ state := func() string {
+ return uuid.New().String()
+ }
- http.Handle("/login", provider.AuthURLHandler(state))
- // http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
- // http.Redirect(w, r, provider.AuthURL(state), http.StatusFound)
- // })
+ urlOptions := []rp.URLParamOpt{
+ rp.WithPromptURLParam("Welcome back!"),
+ }
- // http.HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) {
- // tokens, err := provider.CodeExchange(ctx, r.URL.Query().Get("code"))
- // if err != nil {
- // http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
- // return
- // }
- // data, err := json.Marshal(tokens)
- // if err != nil {
- // http.Error(w, err.Error(), http.StatusInternalServerError)
- // return
- // }
- // w.Write(data)
- // })
+ if responseMode != "" {
+ urlOptions = append(urlOptions, rp.WithResponseModeURLParam(oidc.ResponseMode(responseMode)))
+ }
- marshal := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) {
- _ = state
- data, err := json.Marshal(tokens)
+ // register the AuthURLHandler at your preferred path.
+ // the AuthURLHandler creates the auth request and redirects the user to the auth server.
+ // including state handling with secure cookie and the possibility to use PKCE.
+ // Prompts can optionally be set to inform the server of
+ // any messages that need to be prompted back to the user.
+ http.Handle("/login", rp.AuthURLHandler(
+ state,
+ provider,
+ urlOptions...,
+ ))
+
+ // for demonstration purposes the returned userinfo response is written as JSON object onto response
+ marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
+ fmt.Println("access token", tokens.AccessToken)
+ fmt.Println("refresh token", tokens.RefreshToken)
+ fmt.Println("id token", tokens.IDToken)
+
+ data, err := json.Marshal(info)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
+ w.Header().Set("content-type", "application/json")
w.Write(data)
}
- http.Handle(callbackPath, provider.CodeExchangeHandler(marshal))
+ // you could also just take the access_token and id_token without calling the userinfo endpoint:
+ //
+ // marshalToken := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
+ // data, err := json.Marshal(tokens)
+ // if err != nil {
+ // http.Error(w, err.Error(), http.StatusInternalServerError)
+ // return
+ // }
+ // w.Write(data)
+ //}
- http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
- tokens, err := provider.ClientCredentials(ctx, "scope")
- if err != nil {
- http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
- return
- }
+ // you can also try token exchange flow
+ //
+ // requestTokenExchange := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {
+ // data := make(url.Values)
+ // data.Set("grant_type", string(oidc.GrantTypeTokenExchange))
+ // data.Set("requested_token_type", string(oidc.IDTokenType))
+ // data.Set("subject_token", tokens.RefreshToken)
+ // data.Set("subject_token_type", string(oidc.RefreshTokenType))
+ // data.Add("scope", "profile custom_scope:impersonate:id2")
+
+ // client := &http.Client{}
+ // r2, _ := http.NewRequest(http.MethodPost, issuer+"/oauth/token", strings.NewReader(data.Encode()))
+ // // r2.Header.Add("Authorization", "Basic "+"d2ViOnNlY3JldA==")
+ // r2.Header.Add("Content-Type", "application/x-www-form-urlencoded")
+ // r2.SetBasicAuth("web", "secret")
+
+ // resp, _ := client.Do(r2)
+ // fmt.Println(resp.Status)
+
+ // b, _ := io.ReadAll(resp.Body)
+ // resp.Body.Close()
+
+ // w.Write(b)
+ // }
+
+ // register the CodeExchangeHandler at the callbackPath
+ // the CodeExchangeHandler handles the auth response, creates the token request and calls the callback function
+ // with the returned tokens from the token endpoint
+ // in this example the callback function itself is wrapped by the UserinfoCallback which
+ // will call the Userinfo endpoint, check the sub and pass the info into the callback function
+ http.Handle(callbackPath, rp.CodeExchangeHandler(rp.UserinfoCallback(marshalUserinfo), provider))
+
+ // if you would use the callback without calling the userinfo endpoint, simply switch the callback handler for:
+ //
+ // http.Handle(callbackPath, rp.CodeExchangeHandler(marshalToken, provider))
+
+ // simple counter for request IDs
+ var counter atomic.Int64
+ // enable incomming request logging
+ mw := logging.Middleware(
+ logging.WithLogger(logger),
+ logging.WithGroup("server"),
+ logging.WithIDFunc(func() slog.Attr {
+ return slog.Int64("id", counter.Add(1))
+ }),
+ )
- data, err := json.Marshal(tokens)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
- w.Write(data)
- })
lis := fmt.Sprintf("127.0.0.1:%s", port)
- logrus.Infof("listening on http://%s/", lis)
- logrus.Fatal(http.ListenAndServe("127.0.0.1:5556", nil))
+ logger.Info("server listening, press ctrl+c to stop", "addr", lis)
+ err = http.ListenAndServe(lis, mw(http.DefaultServeMux))
+ if err != http.ErrServerClosed {
+ logger.Error("server terminated", "error", err)
+ os.Exit(1)
+ }
}
diff --git a/example/client/device/device.go b/example/client/device/device.go
new file mode 100644
index 0000000..33bc570
--- /dev/null
+++ b/example/client/device/device.go
@@ -0,0 +1,95 @@
+// Command device is an example Oauth2 Device Authorization Grant app.
+// It creates a new Device Authorization request on the Issuer and then polls for tokens.
+// The user is then prompted to visit a URL and enter the user code.
+// Or, the complete URL can be used instead to omit manual entry.
+// In practice then can be a "magic link" in the form or a QR.
+//
+// The following environment variables are used for configuration:
+//
+// ISSUER: URL to the OP, required.
+// CLIENT_ID: ID of the application, required.
+// CLIENT_SECRET: Secret to authenticate the app using basic auth. Only required if the OP expects this type of authentication.
+// KEY_PATH: Path to a private key file, used to for JWT authentication of the App. Only required if the OP expects this type of authentication.
+// SCOPES: Scopes of the Authentication Request. Optional.
+//
+// Basic usage:
+//
+// cd example/client/device
+// export ISSUER="http://localhost:9000" CLIENT_ID="246048465824634593@demo"
+//
+// Get an Access Token:
+//
+// SCOPES="email profile" go run .
+//
+// Get an Access Token and ID Token:
+//
+// SCOPES="email profile openid" go run .
+//
+// Get an Access Token and Refresh Token
+//
+// SCOPES="email profile offline_access" go run .
+//
+// Get Access, Refresh and ID Tokens:
+//
+// SCOPES="email profile offline_access openid" go run .
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/signal"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/sirupsen/logrus"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+)
+
+var (
+ key = []byte("test1234test1234")
+)
+
+func main() {
+ ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
+ defer stop()
+
+ clientID := os.Getenv("CLIENT_ID")
+ clientSecret := os.Getenv("CLIENT_SECRET")
+ keyPath := os.Getenv("KEY_PATH")
+ issuer := os.Getenv("ISSUER")
+ scopes := strings.Split(os.Getenv("SCOPES"), " ")
+
+ cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
+
+ var options []rp.Option
+ if clientSecret == "" {
+ options = append(options, rp.WithPKCE(cookieHandler))
+ }
+ if keyPath != "" {
+ options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
+ }
+
+ provider, err := rp.NewRelyingPartyOIDC(ctx, issuer, clientID, clientSecret, "", scopes, options...)
+ if err != nil {
+ logrus.Fatalf("error creating provider %s", err.Error())
+ }
+
+ logrus.Info("starting device authorization flow")
+ resp, err := rp.DeviceAuthorization(ctx, scopes, provider, nil)
+ if err != nil {
+ logrus.Fatal(err)
+ }
+ logrus.Info("resp", resp)
+ fmt.Printf("\nPlease browse to %s and enter code %s\n", resp.VerificationURI, resp.UserCode)
+
+ logrus.Info("start polling")
+ token, err := rp.DeviceAccessToken(ctx, resp.DeviceCode, time.Duration(resp.Interval)*time.Second, provider)
+ if err != nil {
+ logrus.Fatal(err)
+ }
+ logrus.Infof("successfully obtained token: %#v", token)
+}
diff --git a/example/client/github/github.go b/example/client/github/github.go
index 4afa2fb..f6c536b 100644
--- a/example/client/github/github.go
+++ b/example/client/github/github.go
@@ -3,16 +3,22 @@ package main
import (
"context"
"fmt"
- "github.com/caos/oidc/pkg/cli"
- "github.com/caos/oidc/pkg/rp"
- "github.com/google/go-github/v31/github"
- githubOAuth "golang.org/x/oauth2/github"
"os"
+
+ "github.com/google/go-github/v31/github"
+ "github.com/google/uuid"
+ "golang.org/x/oauth2"
+ githubOAuth "golang.org/x/oauth2/github"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp/cli"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
)
var (
- callbackPath string = "/orbctl/github/callback"
- key []byte = []byte("test1234test1234")
+ callbackPath = "/orbctl/github/callback"
+ key = []byte("test1234test1234")
)
func main() {
@@ -20,24 +26,32 @@ func main() {
clientSecret := os.Getenv("CLIENT_SECRET")
port := os.Getenv("PORT")
- rpConfig := &rp.Config{
+ rpConfig := &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
- CallbackURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath),
+ RedirectURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath),
Scopes: []string{"repo", "repo_deployment"},
- Endpoints: githubOAuth.Endpoint,
+ Endpoint: githubOAuth.Endpoint,
}
- oauth2Client := cli.CodeFlowForClient(rpConfig, key, callbackPath, port)
-
- client := github.NewClient(oauth2Client)
-
ctx := context.Background()
- _, _, err := client.Users.Get(ctx, "")
+ cookieHandler := http.NewCookieHandler(key, key, http.WithUnsecure())
+ relyingParty, err := rp.NewRelyingPartyOAuth(rpConfig, rp.WithCookieHandler(cookieHandler))
if err != nil {
- fmt.Println("OAuth flow failed")
- } else {
-
- fmt.Println("OAuth flow success")
+ fmt.Printf("error creating relaying party: %v", err)
+ return
}
+ state := func() string {
+ return uuid.New().String()
+ }
+ token := cli.CodeFlow[*oidc.IDTokenClaims](ctx, relyingParty, callbackPath, port, state)
+
+ client := github.NewClient(relyingParty.OAuthConfig().Client(ctx, token.Token))
+
+ _, _, err = client.Users.Get(ctx, "")
+ if err != nil {
+ fmt.Printf("error %v", err)
+ return
+ }
+ fmt.Println("call succeeded")
}
diff --git a/example/client/service/service.go b/example/client/service/service.go
new file mode 100644
index 0000000..a88ab2f
--- /dev/null
+++ b/example/client/service/service.go
@@ -0,0 +1,177 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "html/template"
+ "io"
+ "net/http"
+ "os"
+ "strings"
+
+ "github.com/sirupsen/logrus"
+ "golang.org/x/oauth2"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/profile"
+)
+
+var client = http.DefaultClient
+
+func main() {
+ keyPath := os.Getenv("KEY_PATH")
+ issuer := os.Getenv("ISSUER")
+ port := os.Getenv("PORT")
+ scopes := strings.Split(os.Getenv("SCOPES"), " ")
+
+ if keyPath != "" {
+ ts, err := profile.NewJWTProfileTokenSourceFromKeyFile(context.TODO(), issuer, keyPath, scopes)
+ if err != nil {
+ logrus.Fatalf("error creating token source %s", err.Error())
+ }
+ client = oauth2.NewClient(context.Background(), ts)
+ }
+
+ http.HandleFunc("/jwt-profile", func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == "GET" {
+ tpl := `
+
+
+
+
+ Login
+
+
+
+
+ `
+ t, err := template.New("login").Parse(tpl)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ err = t.Execute(w, nil)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+ } else {
+ err := r.ParseMultipartForm(4 << 10)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ file, _, err := r.FormFile("key")
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ defer file.Close()
+
+ key, err := io.ReadAll(file)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ ts, err := profile.NewJWTProfileTokenSourceFromKeyFileData(context.TODO(), issuer, key, scopes)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ client = oauth2.NewClient(context.Background(), ts)
+ token, err := ts.Token()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ data, err := json.Marshal(token)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ w.Write(data)
+ }
+ })
+
+ http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
+ tpl := `
+
+
+
+
+ Test
+
+
+
+ {{if .URL}}
+
+ Result for {{.URL}}: {{.Response}}
+
+ {{end}}
+
+ `
+ err := r.ParseForm()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ testURL := r.Form.Get("url")
+ var data struct {
+ URL string
+ Response any
+ }
+ if testURL != "" {
+ data.URL = testURL
+ data.Response, err = callExampleEndpoint(client, testURL)
+ if err != nil {
+ data.Response = err
+ }
+ }
+ t, err := template.New("login").Parse(tpl)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ err = t.Execute(w, data)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+ })
+ lis := fmt.Sprintf("127.0.0.1:%s", port)
+ logrus.Infof("listening on http://%s/", lis)
+ logrus.Fatal(http.ListenAndServe("127.0.0.1:"+port, nil))
+}
+
+func callExampleEndpoint(client *http.Client, testURL string) (any, error) {
+ req, err := http.NewRequest("GET", testURL, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("http status not ok: %s %s", resp.Status, body)
+ }
+
+ if strings.HasPrefix(resp.Header.Get("content-type"), "text/plain") {
+ return string(body), nil
+ }
+ return body, err
+}
diff --git a/example/doc.go b/example/doc.go
index f7ec372..fd4f038 100644
--- a/example/doc.go
+++ b/example/doc.go
@@ -1 +1,10 @@
+/*
+Package example contains some example of the various use of this library:
+
+/api example of an api / resource server implementation using token introspection
+/app web app / RP demonstrating authorization code flow using various authentication methods (code, PKCE, JWT profile)
+/github example of the extended OAuth2 library, providing an HTTP client with a reuse token source
+/service demonstration of JWT Profile Authorization Grant
+/server examples of an OpenID Provider implementations (including dynamic) with some very basic
+*/
package example
diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go
deleted file mode 100644
index 5fb823b..0000000
--- a/example/internal/mock/storage.go
+++ /dev/null
@@ -1,274 +0,0 @@
-package mock
-
-import (
- "context"
- "crypto/rand"
- "crypto/rsa"
- "errors"
- "time"
-
- "gopkg.in/square/go-jose.v2"
-
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/op"
-)
-
-type AuthStorage struct {
- key *rsa.PrivateKey
-}
-
-func NewAuthStorage() op.Storage {
- reader := rand.Reader
- bitSize := 2048
- key, err := rsa.GenerateKey(reader, bitSize)
- if err != nil {
- panic(err)
- }
- return &AuthStorage{
- key: key,
- }
-}
-
-type AuthRequest struct {
- ID string
- ResponseType oidc.ResponseType
- RedirectURI string
- Nonce string
- ClientID string
- CodeChallenge *oidc.CodeChallenge
-}
-
-func (a *AuthRequest) GetACR() string {
- return ""
-}
-
-func (a *AuthRequest) GetAMR() []string {
- return []string{
- "password",
- }
-}
-
-func (a *AuthRequest) GetAudience() []string {
- return []string{
- a.ClientID,
- }
-}
-
-func (a *AuthRequest) GetAuthTime() time.Time {
- return time.Now().UTC()
-}
-
-func (a *AuthRequest) GetClientID() string {
- return a.ClientID
-}
-
-func (a *AuthRequest) GetCode() string {
- return "code"
-}
-
-func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
- return a.CodeChallenge
-}
-
-func (a *AuthRequest) GetID() string {
- return a.ID
-}
-
-func (a *AuthRequest) GetNonce() string {
- return a.Nonce
-}
-
-func (a *AuthRequest) GetRedirectURI() string {
- return a.RedirectURI
- // return "http://localhost:5556/auth/callback"
-}
-
-func (a *AuthRequest) GetResponseType() oidc.ResponseType {
- return a.ResponseType
-}
-
-func (a *AuthRequest) GetScopes() []string {
- return []string{
- "openid",
- "profile",
- "email",
- }
-}
-
-func (a *AuthRequest) GetState() string {
- return ""
-}
-
-func (a *AuthRequest) GetSubject() string {
- return "sub"
-}
-
-func (a *AuthRequest) Done() bool {
- return true
-}
-
-var (
- a = &AuthRequest{}
- t bool
- c string
-)
-
-func (s *AuthStorage) Health(ctx context.Context) error {
- return nil
-}
-
-func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest, _ string) (op.AuthRequest, error) {
- a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
- if authReq.CodeChallenge != "" {
- a.CodeChallenge = &oidc.CodeChallenge{
- Challenge: authReq.CodeChallenge,
- Method: authReq.CodeChallengeMethod,
- }
- }
- t = false
- return a, nil
-}
-func (s *AuthStorage) AuthRequestByCode(_ context.Context, code string) (op.AuthRequest, error) {
- if code != c {
- return nil, errors.New("invalid code")
- }
- return a, nil
-}
-func (s *AuthStorage) SaveAuthCode(_ context.Context, id, code string) error {
- if a.ID != id {
- return errors.New("not found")
- }
- c = code
- return nil
-}
-func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
- t = true
- return nil
-}
-func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequest, error) {
- if id != "id" || t {
- return nil, errors.New("not found")
- }
- return a, nil
-}
-func (s *AuthStorage) CreateToken(_ context.Context, authReq op.AuthRequest) (string, time.Time, error) {
- return authReq.GetID(), time.Now().UTC().Add(5 * time.Minute), nil
-}
-func (s *AuthStorage) TerminateSession(_ context.Context, userID, clientID string) error {
- return nil
-}
-func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey, _ chan<- error, _ <-chan time.Time) {
- keyCh <- jose.SigningKey{Algorithm: jose.RS256, Key: s.key}
-}
-func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) {
- return s.key, nil
-}
-func (s *AuthStorage) SaveNewKeyPair(ctx context.Context) error {
- return nil
-}
-func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) {
- pubkey := s.key.Public()
- return &jose.JSONWebKeySet{
- Keys: []jose.JSONWebKey{
- jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
- },
- }, nil
-}
-
-func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) {
- if id == "none" {
- return nil, errors.New("not found")
- }
- var appType op.ApplicationType
- var authMethod op.AuthMethod
- var accessTokenType op.AccessTokenType
- if id == "web" {
- appType = op.ApplicationTypeWeb
- authMethod = op.AuthMethodBasic
- accessTokenType = op.AccessTokenTypeBearer
- } else if id == "native" {
- appType = op.ApplicationTypeNative
- authMethod = op.AuthMethodNone
- accessTokenType = op.AccessTokenTypeBearer
- } else {
- appType = op.ApplicationTypeUserAgent
- authMethod = op.AuthMethodNone
- accessTokenType = op.AccessTokenTypeJWT
- }
- return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
-}
-
-func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
- return nil
-}
-
-func (s *AuthStorage) GetUserinfoFromToken(ctx context.Context, _ string) (*oidc.Userinfo, error) {
- return s.GetUserinfoFromScopes(ctx, "", []string{})
-}
-func (s *AuthStorage) GetUserinfoFromScopes(_ context.Context, _ string, _ []string) (*oidc.Userinfo, error) {
- return &oidc.Userinfo{
- Subject: a.GetSubject(),
- Address: &oidc.UserinfoAddress{
- StreetAddress: "Hjkhkj 789\ndsf",
- },
- UserinfoEmail: oidc.UserinfoEmail{
- Email: "test",
- EmailVerified: true,
- },
- UserinfoPhone: oidc.UserinfoPhone{
- PhoneNumber: "sadsa",
- PhoneNumberVerified: true,
- },
- UserinfoProfile: oidc.UserinfoProfile{
- UpdatedAt: time.Now(),
- },
- // Claims: map[string]interface{}{
- // "test": "test",
- // "hkjh": "",
- // },
- }, nil
-}
-
-type ConfClient struct {
- applicationType op.ApplicationType
- authMethod op.AuthMethod
- ID string
- accessTokenType op.AccessTokenType
-}
-
-func (c *ConfClient) GetID() string {
- return c.ID
-}
-func (c *ConfClient) RedirectURIs() []string {
- return []string{
- "https://registered.com/callback",
- "http://localhost:9999/callback",
- "http://localhost:5556/auth/callback",
- "custom://callback",
- "https://localhost:8443/test/a/instructions-example/callback",
- "https://op.certification.openid.net:62064/authz_cb",
- "https://op.certification.openid.net:62064/authz_post",
- }
-}
-func (c *ConfClient) PostLogoutRedirectURIs() []string {
- return []string{}
-}
-
-func (c *ConfClient) LoginURL(id string) string {
- return "login?id=" + id
-}
-
-func (c *ConfClient) ApplicationType() op.ApplicationType {
- return c.applicationType
-}
-
-func (c *ConfClient) GetAuthMethod() op.AuthMethod {
- return c.authMethod
-}
-
-func (c *ConfClient) IDTokenLifetime() time.Duration {
- return time.Duration(5 * time.Minute)
-}
-func (c *ConfClient) AccessTokenType() op.AccessTokenType {
- return c.accessTokenType
-}
diff --git a/example/server/config/config.go b/example/server/config/config.go
new file mode 100644
index 0000000..96837d4
--- /dev/null
+++ b/example/server/config/config.go
@@ -0,0 +1,40 @@
+package config
+
+import (
+ "os"
+ "strings"
+)
+
+const (
+ // default port for the http server to run
+ DefaultIssuerPort = "9998"
+)
+
+type Config struct {
+ Port string
+ RedirectURI []string
+ UsersFile string
+}
+
+// FromEnvVars loads configuration parameters from environment variables.
+// If there is no such variable defined, then use default values.
+func FromEnvVars(defaults *Config) *Config {
+ if defaults == nil {
+ defaults = &Config{}
+ }
+ cfg := &Config{
+ Port: defaults.Port,
+ RedirectURI: defaults.RedirectURI,
+ UsersFile: defaults.UsersFile,
+ }
+ if value, ok := os.LookupEnv("PORT"); ok {
+ cfg.Port = value
+ }
+ if value, ok := os.LookupEnv("USERS_FILE"); ok {
+ cfg.UsersFile = value
+ }
+ if value, ok := os.LookupEnv("REDIRECT_URI"); ok {
+ cfg.RedirectURI = strings.Split(value, ",")
+ }
+ return cfg
+}
diff --git a/example/server/config/config_test.go b/example/server/config/config_test.go
new file mode 100644
index 0000000..3b73c0b
--- /dev/null
+++ b/example/server/config/config_test.go
@@ -0,0 +1,77 @@
+package config
+
+import (
+ "fmt"
+ "os"
+ "testing"
+)
+
+func TestFromEnvVars(t *testing.T) {
+
+ for _, tc := range []struct {
+ name string
+ env map[string]string
+ defaults *Config
+ want *Config
+ }{
+ {
+ name: "no vars, no default values",
+ env: map[string]string{},
+ want: &Config{},
+ },
+ {
+ name: "no vars, only defaults",
+ env: map[string]string{},
+ defaults: &Config{
+ Port: "6666",
+ UsersFile: "/default/user/path",
+ RedirectURI: []string{"re", "direct", "uris"},
+ },
+ want: &Config{
+ Port: "6666",
+ UsersFile: "/default/user/path",
+ RedirectURI: []string{"re", "direct", "uris"},
+ },
+ },
+ {
+ name: "overriding default values",
+ env: map[string]string{
+ "PORT": "1234",
+ "USERS_FILE": "/path/to/users",
+ "REDIRECT_URI": "http://redirect/redirect",
+ },
+ defaults: &Config{
+ Port: "6666",
+ UsersFile: "/default/user/path",
+ RedirectURI: []string{"re", "direct", "uris"},
+ },
+ want: &Config{
+ Port: "1234",
+ UsersFile: "/path/to/users",
+ RedirectURI: []string{"http://redirect/redirect"},
+ },
+ },
+ {
+ name: "multiple redirect uris",
+ env: map[string]string{
+ "REDIRECT_URI": "http://host_1,http://host_2,http://host_3",
+ },
+ want: &Config{
+ RedirectURI: []string{
+ "http://host_1", "http://host_2", "http://host_3",
+ },
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ os.Clearenv()
+ for k, v := range tc.env {
+ os.Setenv(k, v)
+ }
+ cfg := FromEnvVars(tc.defaults)
+ if fmt.Sprint(cfg) != fmt.Sprint(tc.want) {
+ t.Errorf("Expected FromEnvVars()=%q, but got %q", tc.want, cfg)
+ }
+ })
+ }
+}
diff --git a/example/server/default/default.go b/example/server/default/default.go
deleted file mode 100644
index 421c7f7..0000000
--- a/example/server/default/default.go
+++ /dev/null
@@ -1,72 +0,0 @@
-package main
-
-import (
- "context"
- "crypto/sha256"
- "html/template"
- "log"
- "net/http"
-
- "github.com/gorilla/mux"
-
- "github.com/caos/oidc/example/internal/mock"
- "github.com/caos/oidc/pkg/op"
-)
-
-func main() {
- ctx := context.Background()
- port := "9998"
- config := &op.Config{
- Issuer: "http://localhost:9998/",
- CryptoKey: sha256.Sum256([]byte("test")),
- }
- storage := mock.NewAuthStorage()
- handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint(op.NewEndpoint("test")))
- if err != nil {
- log.Fatal(err)
- }
- router := handler.HttpHandler().(*mux.Router)
- router.Methods("GET").Path("/login").HandlerFunc(HandleLogin)
- router.Methods("POST").Path("/login").HandlerFunc(HandleCallback)
- server := &http.Server{
- Addr: ":" + port,
- Handler: router,
- }
- err = server.ListenAndServe()
- if err != nil {
- log.Fatal(err)
- }
- <-ctx.Done()
-}
-
-func HandleLogin(w http.ResponseWriter, r *http.Request) {
- tpl := `
-
-
-
-
- Login
-
-
-
-
- `
- t, err := template.New("login").Parse(tpl)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
- err = t.Execute(w, nil)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- }
-}
-
-func HandleCallback(w http.ResponseWriter, r *http.Request) {
- r.ParseForm()
- client := r.FormValue("client")
- http.Redirect(w, r, "/authorize/"+client, http.StatusFound)
-}
diff --git a/example/server/dynamic/login.go b/example/server/dynamic/login.go
new file mode 100644
index 0000000..05f0e34
--- /dev/null
+++ b/example/server/dynamic/login.go
@@ -0,0 +1,113 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "html/template"
+ "net/http"
+
+ "github.com/go-chi/chi/v5"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+const (
+ queryAuthRequestID = "authRequestID"
+)
+
+var (
+ loginTmpl, _ = template.New("login").Parse(`
+
+
+
+
+ Login
+
+
+
+
+ `)
+)
+
+type login struct {
+ authenticate authenticate
+ router chi.Router
+ callback func(context.Context, string) string
+}
+
+func NewLogin(authenticate authenticate, callback func(context.Context, string) string, issuerInterceptor *op.IssuerInterceptor) *login {
+ l := &login{
+ authenticate: authenticate,
+ callback: callback,
+ }
+ l.createRouter(issuerInterceptor)
+ return l
+}
+
+func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) {
+ l.router = chi.NewRouter()
+ l.router.Get("/username", l.loginHandler)
+ l.router.With(issuerInterceptor.Handler).Post("/username", l.checkLoginHandler)
+}
+
+type authenticate interface {
+ CheckUsernamePassword(ctx context.Context, username, password, id string) error
+}
+
+func (l *login) loginHandler(w http.ResponseWriter, r *http.Request) {
+ err := r.ParseForm()
+ if err != nil {
+ http.Error(w, fmt.Sprintf("cannot parse form:%s", err), http.StatusInternalServerError)
+ return
+ }
+ //the oidc package will pass the id of the auth request as query parameter
+ //we will use this id through the login process and therefore pass it to the login page
+ renderLogin(w, r.FormValue(queryAuthRequestID), nil)
+}
+
+func renderLogin(w http.ResponseWriter, id string, err error) {
+ var errMsg string
+ if err != nil {
+ errMsg = err.Error()
+ }
+ data := &struct {
+ ID string
+ Error string
+ }{
+ ID: id,
+ Error: errMsg,
+ }
+ err = loginTmpl.Execute(w, data)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+}
+
+func (l *login) checkLoginHandler(w http.ResponseWriter, r *http.Request) {
+ err := r.ParseForm()
+ if err != nil {
+ http.Error(w, fmt.Sprintf("cannot parse form:%s", err), http.StatusInternalServerError)
+ return
+ }
+ username := r.FormValue("username")
+ password := r.FormValue("password")
+ id := r.FormValue("id")
+ err = l.authenticate.CheckUsernamePassword(r.Context(), username, password, id)
+ if err != nil {
+ renderLogin(w, id, err)
+ return
+ }
+ http.Redirect(w, r, l.callback(r.Context(), id), http.StatusFound)
+}
diff --git a/example/server/dynamic/op.go b/example/server/dynamic/op.go
new file mode 100644
index 0000000..2c00e41
--- /dev/null
+++ b/example/server/dynamic/op.go
@@ -0,0 +1,138 @@
+package main
+
+import (
+ "context"
+ "crypto/sha256"
+ "fmt"
+ "log"
+ "net/http"
+
+ "github.com/go-chi/chi/v5"
+ "golang.org/x/text/language"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+const (
+ pathLoggedOut = "/logged-out"
+)
+
+var (
+ hostnames = []string{
+ "localhost", //note that calling 127.0.0.1 / ::1 won't work as the hostname does not match
+ "oidc.local", //add this to your hosts file (pointing to 127.0.0.1)
+ //feel free to add more...
+ }
+)
+
+func init() {
+ storage.RegisterClients(
+ storage.NativeClient("native"),
+ storage.WebClient("web", "secret"),
+ storage.WebClient("api", "secret"),
+ )
+}
+
+func main() {
+ ctx := context.Background()
+
+ port := "9998"
+ issuers := make([]string, len(hostnames))
+ for i, hostname := range hostnames {
+ issuers[i] = fmt.Sprintf("http://%s:%s/", hostname, port)
+ }
+
+ //the OpenID Provider requires a 32-byte key for (token) encryption
+ //be sure to create a proper crypto random key and manage it securely!
+ key := sha256.Sum256([]byte("test"))
+
+ router := chi.NewRouter()
+
+ //for simplicity, we provide a very small default page for users who have signed out
+ router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
+ _, err := w.Write([]byte("signed out successfully"))
+ if err != nil {
+ log.Printf("error serving logged out page: %v", err)
+ }
+ })
+
+ //the OpenIDProvider interface needs a Storage interface handling various checks and state manipulations
+ //this might be the layer for accessing your database
+ //in this example it will be handled in-memory
+ //the NewMultiStorage is able to handle multiple issuers
+ storage := storage.NewMultiStorage(issuers)
+
+ //creation of the OpenIDProvider with the just created in-memory Storage
+ provider, err := newDynamicOP(ctx, storage, key)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ //the provider will only take care of the OpenID Protocol, so there must be some sort of UI for the login process
+ //for the simplicity of the example this means a simple page with username and password field
+ //be sure to provide an IssuerInterceptor with the IssuerFromRequest from the OP so the login can select / and pass it to the storage
+ l := NewLogin(storage, op.AuthCallbackURL(provider), op.NewIssuerInterceptor(provider.IssuerFromRequest))
+
+ //regardless of how many pages / steps there are in the process, the UI must be registered in the router,
+ //so we will direct all calls to /login to the login UI
+ router.Mount("/login/", http.StripPrefix("/login", l.router))
+
+ //we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
+ //is served on the correct path
+ //
+ //if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
+ //then you would have to set the path prefix (/custom/path/):
+ //router.PathPrefix("/custom/path/").Handler(http.StripPrefix("/custom/path", provider.HttpHandler()))
+ router.Mount("/", provider)
+
+ server := &http.Server{
+ Addr: ":" + port,
+ Handler: router,
+ }
+ err = server.ListenAndServe()
+ if err != nil {
+ log.Fatal(err)
+ }
+ <-ctx.Done()
+}
+
+// newDynamicOP will create an OpenID Provider for localhost on a specified port with a given encryption key
+// and a predefined default logout uri
+// it will enable all options (see descriptions)
+func newDynamicOP(ctx context.Context, storage op.Storage, key [32]byte) (*op.Provider, error) {
+ config := &op.Config{
+ CryptoKey: key,
+
+ //will be used if the end_session endpoint is called without a post_logout_redirect_uri
+ DefaultLogoutRedirectURI: pathLoggedOut,
+
+ //enables code_challenge_method S256 for PKCE (and therefore PKCE in general)
+ CodeMethodS256: true,
+
+ //enables additional client_id/client_secret authentication by form post (not only HTTP Basic Auth)
+ AuthMethodPost: true,
+
+ //enables additional authentication by using private_key_jwt
+ AuthMethodPrivateKeyJWT: true,
+
+ //enables refresh_token grant use
+ GrantTypeRefreshToken: true,
+
+ //enables use of the `request` Object parameter
+ RequestObjectSupported: true,
+
+ //this example has only static texts (in English), so we'll set the here accordingly
+ SupportedUILocales: []language.Tag{language.English},
+ }
+ handler, err := op.NewDynamicOpenIDProvider("/", config, storage,
+ //we must explicitly allow the use of the http issuer
+ op.WithAllowInsecure(),
+ //as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth
+ op.WithCustomAuthEndpoint(op.NewEndpoint("auth")),
+ )
+ if err != nil {
+ return nil, err
+ }
+ return handler, nil
+}
diff --git a/example/server/exampleop/device.go b/example/server/exampleop/device.go
new file mode 100644
index 0000000..99505e4
--- /dev/null
+++ b/example/server/exampleop/device.go
@@ -0,0 +1,204 @@
+package exampleop
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "github.com/go-chi/chi/v5"
+ "github.com/gorilla/securecookie"
+ "github.com/sirupsen/logrus"
+)
+
+type deviceAuthenticate interface {
+ CheckUsernamePasswordSimple(username, password string) error
+ op.DeviceAuthorizationStorage
+
+ // GetDeviceAuthorizationByUserCode resturns the current state of the device authorization flow,
+ // identified by the user code.
+ GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error)
+
+ // CompleteDeviceAuthorization marks a device authorization entry as Completed,
+ // identified by userCode. The Subject is added to the state, so that
+ // GetDeviceAuthorizatonState can use it to create a new Access Token.
+ CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error
+
+ // DenyDeviceAuthorization marks a device authorization entry as Denied.
+ DenyDeviceAuthorization(ctx context.Context, userCode string) error
+}
+
+type deviceLogin struct {
+ storage deviceAuthenticate
+ cookie *securecookie.SecureCookie
+}
+
+func registerDeviceAuth(storage deviceAuthenticate, router chi.Router) {
+ l := &deviceLogin{
+ storage: storage,
+ cookie: securecookie.New(securecookie.GenerateRandomKey(32), nil),
+ }
+
+ router.HandleFunc("/", l.userCodeHandler)
+ router.Post("/login", l.loginHandler)
+ router.HandleFunc("/confirm", l.confirmHandler)
+}
+
+func renderUserCode(w io.Writer, err error) {
+ data := struct {
+ Error string
+ }{
+ Error: errMsg(err),
+ }
+
+ if err := templates.ExecuteTemplate(w, "usercode", data); err != nil {
+ logrus.Error(err)
+ }
+}
+
+func renderDeviceLogin(w http.ResponseWriter, userCode string, err error) {
+ data := &struct {
+ UserCode string
+ Error string
+ }{
+ UserCode: userCode,
+ Error: errMsg(err),
+ }
+ if err = templates.ExecuteTemplate(w, "device_login", data); err != nil {
+ logrus.Error(err)
+ }
+}
+
+func renderConfirmPage(w http.ResponseWriter, username, clientID string, scopes []string) {
+ data := &struct {
+ Username string
+ ClientID string
+ Scopes []string
+ }{
+ Username: username,
+ ClientID: clientID,
+ Scopes: scopes,
+ }
+ if err := templates.ExecuteTemplate(w, "confirm_device", data); err != nil {
+ logrus.Error(err)
+ }
+}
+
+func (d *deviceLogin) userCodeHandler(w http.ResponseWriter, r *http.Request) {
+ err := r.ParseForm()
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ renderUserCode(w, err)
+ return
+ }
+ userCode := r.Form.Get("user_code")
+ if userCode == "" {
+ if prompt, _ := url.QueryUnescape(r.Form.Get("prompt")); prompt != "" {
+ err = errors.New(prompt)
+ }
+ renderUserCode(w, err)
+ return
+ }
+
+ renderDeviceLogin(w, userCode, nil)
+}
+
+func redirectBack(w http.ResponseWriter, r *http.Request, prompt string) {
+ values := make(url.Values)
+ values.Set("prompt", url.QueryEscape(prompt))
+
+ url := url.URL{
+ Path: "/device",
+ RawQuery: values.Encode(),
+ }
+ http.Redirect(w, r, url.String(), http.StatusSeeOther)
+}
+
+const userCodeCookieName = "user_code"
+
+type userCodeCookie struct {
+ UserCode string
+ UserName string
+}
+
+func (d *deviceLogin) loginHandler(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+
+ userCode := r.PostForm.Get("user_code")
+ if userCode == "" {
+ redirectBack(w, r, "missing user_code in request")
+ return
+ }
+ username := r.PostForm.Get("username")
+ if username == "" {
+ redirectBack(w, r, "missing username in request")
+ return
+ }
+ password := r.PostForm.Get("password")
+ if password == "" {
+ redirectBack(w, r, "missing password in request")
+ return
+ }
+
+ if err := d.storage.CheckUsernamePasswordSimple(username, password); err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+ state, err := d.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
+ if err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+
+ encoded, err := d.cookie.Encode(userCodeCookieName, userCodeCookie{userCode, username})
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ cookie := &http.Cookie{
+ Name: userCodeCookieName,
+ Value: encoded,
+ Path: "/",
+ }
+ http.SetCookie(w, cookie)
+ renderConfirmPage(w, username, state.ClientID, state.Scopes)
+}
+
+func (d *deviceLogin) confirmHandler(w http.ResponseWriter, r *http.Request) {
+ cookie, err := r.Cookie(userCodeCookieName)
+ if err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+ data := new(userCodeCookie)
+ if err = d.cookie.Decode(userCodeCookieName, cookie.Value, &data); err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+ if err = r.ParseForm(); err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+
+ action := r.Form.Get("action")
+ switch action {
+ case "allowed":
+ err = d.storage.CompleteDeviceAuthorization(r.Context(), data.UserCode, data.UserName)
+ case "denied":
+ err = d.storage.DenyDeviceAuthorization(r.Context(), data.UserCode)
+ default:
+ err = errors.New("action must be one of \"allow\" or \"deny\"")
+ }
+ if err != nil {
+ redirectBack(w, r, err.Error())
+ return
+ }
+
+ fmt.Fprintf(w, "Device authorization %s. You can now return to the device", action)
+}
diff --git a/example/server/exampleop/login.go b/example/server/exampleop/login.go
new file mode 100644
index 0000000..77a6189
--- /dev/null
+++ b/example/server/exampleop/login.go
@@ -0,0 +1,77 @@
+package exampleop
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "github.com/go-chi/chi/v5"
+)
+
+type login struct {
+ authenticate authenticate
+ router chi.Router
+ callback func(context.Context, string) string
+}
+
+func NewLogin(authenticate authenticate, callback func(context.Context, string) string, issuerInterceptor *op.IssuerInterceptor) *login {
+ l := &login{
+ authenticate: authenticate,
+ callback: callback,
+ }
+ l.createRouter(issuerInterceptor)
+ return l
+}
+
+func (l *login) createRouter(issuerInterceptor *op.IssuerInterceptor) {
+ l.router = chi.NewRouter()
+ l.router.Get("/username", l.loginHandler)
+ l.router.Post("/username", issuerInterceptor.HandlerFunc(l.checkLoginHandler))
+}
+
+type authenticate interface {
+ CheckUsernamePassword(username, password, id string) error
+}
+
+func (l *login) loginHandler(w http.ResponseWriter, r *http.Request) {
+ err := r.ParseForm()
+ if err != nil {
+ http.Error(w, fmt.Sprintf("cannot parse form:%s", err), http.StatusInternalServerError)
+ return
+ }
+ // the oidc package will pass the id of the auth request as query parameter
+ // we will use this id through the login process and therefore pass it to the login page
+ renderLogin(w, r.FormValue(queryAuthRequestID), nil)
+}
+
+func renderLogin(w http.ResponseWriter, id string, err error) {
+ data := &struct {
+ ID string
+ Error string
+ }{
+ ID: id,
+ Error: errMsg(err),
+ }
+ err = templates.ExecuteTemplate(w, "login", data)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+}
+
+func (l *login) checkLoginHandler(w http.ResponseWriter, r *http.Request) {
+ err := r.ParseForm()
+ if err != nil {
+ http.Error(w, fmt.Sprintf("cannot parse form:%s", err), http.StatusInternalServerError)
+ return
+ }
+ username := r.FormValue("username")
+ password := r.FormValue("password")
+ id := r.FormValue("id")
+ err = l.authenticate.CheckUsernamePassword(username, password, id)
+ if err != nil {
+ renderLogin(w, id, err)
+ return
+ }
+ http.Redirect(w, r, l.callback(r.Context(), id), http.StatusFound)
+}
diff --git a/example/server/exampleop/op.go b/example/server/exampleop/op.go
new file mode 100644
index 0000000..e12c755
--- /dev/null
+++ b/example/server/exampleop/op.go
@@ -0,0 +1,136 @@
+package exampleop
+
+import (
+ "crypto/sha256"
+ "log"
+ "log/slog"
+ "net/http"
+ "sync/atomic"
+ "time"
+
+ "github.com/go-chi/chi/v5"
+ "github.com/zitadel/logging"
+ "golang.org/x/text/language"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+const (
+ pathLoggedOut = "/logged-out"
+)
+
+type Storage interface {
+ op.Storage
+ authenticate
+ deviceAuthenticate
+}
+
+// simple counter for request IDs
+var counter atomic.Int64
+
+// SetupServer creates an OIDC server with Issuer=http://localhost:
+//
+// Use one of the pre-made clients in storage/clients.go or register a new one.
+func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer bool, extraOptions ...op.Option) chi.Router {
+ // the OpenID Provider requires a 32-byte key for (token) encryption
+ // be sure to create a proper crypto random key and manage it securely!
+ key := sha256.Sum256([]byte("test"))
+
+ router := chi.NewRouter()
+ router.Use(logging.Middleware(
+ logging.WithLogger(logger),
+ logging.WithIDFunc(func() slog.Attr {
+ return slog.Int64("id", counter.Add(1))
+ }),
+ ))
+
+ // for simplicity, we provide a very small default page for users who have signed out
+ router.HandleFunc(pathLoggedOut, func(w http.ResponseWriter, req *http.Request) {
+ w.Write([]byte("signed out successfully"))
+ // no need to check/log error, this will be handled by the middleware.
+ })
+
+ // creation of the OpenIDProvider with the just created in-memory Storage
+ provider, err := newOP(storage, issuer, key, logger, extraOptions...)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ //the provider will only take care of the OpenID Protocol, so there must be some sort of UI for the login process
+ //for the simplicity of the example this means a simple page with username and password field
+ //be sure to provide an IssuerInterceptor with the IssuerFromRequest from the OP so the login can select / and pass it to the storage
+ l := NewLogin(storage, op.AuthCallbackURL(provider), op.NewIssuerInterceptor(provider.IssuerFromRequest))
+
+ // regardless of how many pages / steps there are in the process, the UI must be registered in the router,
+ // so we will direct all calls to /login to the login UI
+ router.Mount("/login/", http.StripPrefix("/login", l.router))
+
+ router.Route("/device", func(r chi.Router) {
+ registerDeviceAuth(storage, r)
+ })
+
+ handler := http.Handler(provider)
+ if wrapServer {
+ handler = op.RegisterLegacyServer(op.NewLegacyServer(provider, *op.DefaultEndpoints), op.AuthorizeCallbackHandler(provider))
+ }
+
+ // we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
+ // is served on the correct path
+ //
+ // if your issuer ends with a path (e.g. http://localhost:9998/custom/path/),
+ // then you would have to set the path prefix (/custom/path/)
+ router.Mount("/", handler)
+
+ return router
+}
+
+// newOP will create an OpenID Provider for localhost on a specified port with a given encryption key
+// and a predefined default logout uri
+// it will enable all options (see descriptions)
+func newOP(storage op.Storage, issuer string, key [32]byte, logger *slog.Logger, extraOptions ...op.Option) (op.OpenIDProvider, error) {
+ config := &op.Config{
+ CryptoKey: key,
+
+ // will be used if the end_session endpoint is called without a post_logout_redirect_uri
+ DefaultLogoutRedirectURI: pathLoggedOut,
+
+ // enables code_challenge_method S256 for PKCE (and therefore PKCE in general)
+ CodeMethodS256: true,
+
+ // enables additional client_id/client_secret authentication by form post (not only HTTP Basic Auth)
+ AuthMethodPost: true,
+
+ // enables additional authentication by using private_key_jwt
+ AuthMethodPrivateKeyJWT: true,
+
+ // enables refresh_token grant use
+ GrantTypeRefreshToken: true,
+
+ // enables use of the `request` Object parameter
+ RequestObjectSupported: true,
+
+ // this example has only static texts (in English), so we'll set the here accordingly
+ SupportedUILocales: []language.Tag{language.English},
+
+ DeviceAuthorization: op.DeviceAuthorizationConfig{
+ Lifetime: 5 * time.Minute,
+ PollInterval: 5 * time.Second,
+ UserFormPath: "/device",
+ UserCode: op.UserCodeBase20,
+ },
+ }
+ handler, err := op.NewOpenIDProvider(issuer, config, storage,
+ append([]op.Option{
+ //we must explicitly allow the use of the http issuer
+ op.WithAllowInsecure(),
+ // as an example on how to customize an endpoint this will change the authorization_endpoint from /authorize to /auth
+ op.WithCustomAuthEndpoint(op.NewEndpoint("auth")),
+ // Pass our logger to the OP
+ op.WithLogger(logger.WithGroup("op")),
+ }, extraOptions...)...,
+ )
+ if err != nil {
+ return nil, err
+ }
+ return handler, nil
+}
diff --git a/example/server/exampleop/templates.go b/example/server/exampleop/templates.go
new file mode 100644
index 0000000..5b5c966
--- /dev/null
+++ b/example/server/exampleop/templates.go
@@ -0,0 +1,26 @@
+package exampleop
+
+import (
+ "embed"
+ "html/template"
+
+ "github.com/sirupsen/logrus"
+)
+
+var (
+ //go:embed templates
+ templateFS embed.FS
+ templates = template.Must(template.ParseFS(templateFS, "templates/*.html"))
+)
+
+const (
+ queryAuthRequestID = "authRequestID"
+)
+
+func errMsg(err error) string {
+ if err == nil {
+ return ""
+ }
+ logrus.Error(err)
+ return err.Error()
+}
diff --git a/example/server/exampleop/templates/confirm_device.html b/example/server/exampleop/templates/confirm_device.html
new file mode 100644
index 0000000..a6bcdad
--- /dev/null
+++ b/example/server/exampleop/templates/confirm_device.html
@@ -0,0 +1,25 @@
+{{ define "confirm_device" -}}
+
+
+
+
+ Confirm device authorization
+
+
+
+ Welcome back {{.Username}}!
+
+ You are about to grant device {{.ClientID}} access to the following scopes: {{.Scopes}}.
+
+ Allow
+ Deny
+
+
+{{- end }}
diff --git a/example/server/exampleop/templates/device_login.html b/example/server/exampleop/templates/device_login.html
new file mode 100644
index 0000000..cc5b00b
--- /dev/null
+++ b/example/server/exampleop/templates/device_login.html
@@ -0,0 +1,29 @@
+{{ define "device_login" -}}
+
+
+
+
+ Login
+
+
+
+
+
+{{- end }}
diff --git a/example/server/exampleop/templates/login.html b/example/server/exampleop/templates/login.html
new file mode 100644
index 0000000..b048211
--- /dev/null
+++ b/example/server/exampleop/templates/login.html
@@ -0,0 +1,29 @@
+{{ define "login" -}}
+
+
+
+
+ Login
+
+
+
+
+`
+{{- end }}
\ No newline at end of file
diff --git a/example/server/exampleop/templates/usercode.html b/example/server/exampleop/templates/usercode.html
new file mode 100644
index 0000000..fb8fa7f
--- /dev/null
+++ b/example/server/exampleop/templates/usercode.html
@@ -0,0 +1,21 @@
+{{ define "usercode" -}}
+
+
+
+
+ Device authorization
+
+
+
+
+
+{{- end }}
diff --git a/example/server/main.go b/example/server/main.go
new file mode 100644
index 0000000..5bdbb05
--- /dev/null
+++ b/example/server/main.go
@@ -0,0 +1,59 @@
+package main
+
+import (
+ "fmt"
+ "log/slog"
+ "net/http"
+ "os"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/example/server/config"
+ "git.christmann.info/LARA/zitadel-oidc/v3/example/server/exampleop"
+ "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage"
+)
+
+func getUserStore(cfg *config.Config) (storage.UserStore, error) {
+ if cfg.UsersFile == "" {
+ return storage.NewUserStore(fmt.Sprintf("http://localhost:%s/", cfg.Port)), nil
+ }
+ return storage.StoreFromFile(cfg.UsersFile)
+}
+
+func main() {
+ cfg := config.FromEnvVars(&config.Config{Port: "9998"})
+ logger := slog.New(
+ slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
+ AddSource: true,
+ Level: slog.LevelDebug,
+ }),
+ )
+
+ //which gives us the issuer: http://localhost:9998/
+ issuer := fmt.Sprintf("http://localhost:%s/", cfg.Port)
+
+ storage.RegisterClients(
+ storage.NativeClient("native", cfg.RedirectURI...),
+ storage.WebClient("web", "secret", cfg.RedirectURI...),
+ storage.WebClient("api", "secret", cfg.RedirectURI...),
+ )
+
+ // the OpenIDProvider interface needs a Storage interface handling various checks and state manipulations
+ // this might be the layer for accessing your database
+ // in this example it will be handled in-memory
+ store, err := getUserStore(cfg)
+ if err != nil {
+ logger.Error("cannot create UserStore", "error", err)
+ os.Exit(1)
+ }
+ storage := storage.NewStorage(store)
+ router := exampleop.SetupServer(issuer, storage, logger, false)
+
+ server := &http.Server{
+ Addr: ":" + cfg.Port,
+ Handler: router,
+ }
+ logger.Info("server listening, press ctrl+c to stop", "addr", issuer)
+ if server.ListenAndServe() != http.ErrServerClosed {
+ logger.Error("server terminated", "error", err)
+ os.Exit(1)
+ }
+}
diff --git a/example/server/service-key1.json b/example/server/service-key1.json
new file mode 100644
index 0000000..a0d20e8
--- /dev/null
+++ b/example/server/service-key1.json
@@ -0,0 +1 @@
+{"type":"serviceaccount","keyId":"key1","key":"-----BEGIN RSA PRIVATE KEY-----\nMIICXgIBAAKBgQD21E+180rCAzp15zy2X/JOYYHtxYhF51pWCsITeChJd7sFWxp1\ntxSHTiomQYBiBWgcCavsdu/VLPQJhO3PTIyglxc1XRGsM48oDT5MkFsAVDvbjuWk\nF0lstQyw4pr8Wg0Ecf1aL6YlvVKB9h5rAgZ9T+elNJ7q5takMAvNhu7zMQIDAQAB\nAoGAeLRw2qjEaUZM43WWchVPmFcEw/MyZgTyX1tZd03uXacolUDtGp3ScyydXiHw\nF39PX063fabYOCaInNMdvJ9RsQz2OcZuS/K6NOmWhzBfLgs4Y1tU6ijoY/gBjHgu\nCV0KjvoWIfEtKl/On/wTrAnUStFzrc7U4dpKFP1fy2ZTTnECQQD8aP2QOxmKUyfg\nBAjfonpkrNeaTRNwTULTvEHFiLyaeFd1PAvsDiKZtpk6iHLb99mQZkVVtAK5qgQ4\n1OI72jkVAkEA+lcAamuZAM+gIiUhbHA7BfX9OVgyGDD2tx5g/kxhMUmK6hIiO6Ul\n0nw5KfrCEUU3AzrM7HejUg3q61SYcXTgrQJBALhrzbhwNf0HPP9Ec2dSw7KDRxSK\ndEV9bfJefn/hpEwI2X3i3aMfwNAmxlYqFCH8OY5z6vzvhX46ZtNPV+z7SPECQQDq\nApXi5P27YlpgULEzup2R7uZsymLZdjvJ5V3pmOBpwENYlublNnVqkrCk60CqADdy\nj26rxRIoS9ZDcWqm9AhpAkEAyrNXBMJh08ghBMb3NYPFfr/bftRJSrGjhBPuJ5qr\nXzWaXhYVMMh3OSAwzHBJbA1ffdQJuH2ebL99Ur5fpBcbVw==\n-----END RSA PRIVATE KEY-----\n","userId":"service"}
diff --git a/example/server/storage/client.go b/example/server/storage/client.go
new file mode 100644
index 0000000..2b836c0
--- /dev/null
+++ b/example/server/storage/client.go
@@ -0,0 +1,235 @@
+package storage
+
+import (
+ "time"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+var (
+ // we use the default login UI and pass the (auth request) id
+ defaultLoginURL = func(id string) string {
+ return "/login/username?authRequestID=" + id
+ }
+
+ // clients to be used by the storage interface
+ clients = map[string]*Client{}
+)
+
+// Client represents the storage model of an OAuth/OIDC client
+// this could also be your database model
+type Client struct {
+ id string
+ secret string
+ redirectURIs []string
+ applicationType op.ApplicationType
+ authMethod oidc.AuthMethod
+ loginURL func(string) string
+ responseTypes []oidc.ResponseType
+ grantTypes []oidc.GrantType
+ accessTokenType op.AccessTokenType
+ devMode bool
+ idTokenUserinfoClaimsAssertion bool
+ clockSkew time.Duration
+ postLogoutRedirectURIGlobs []string
+ redirectURIGlobs []string
+}
+
+// GetID must return the client_id
+func (c *Client) GetID() string {
+ return c.id
+}
+
+// RedirectURIs must return the registered redirect_uris for Code and Implicit Flow
+func (c *Client) RedirectURIs() []string {
+ return c.redirectURIs
+}
+
+// PostLogoutRedirectURIs must return the registered post_logout_redirect_uris for sign-outs
+func (c *Client) PostLogoutRedirectURIs() []string {
+ return []string{}
+}
+
+// ApplicationType must return the type of the client (app, native, user agent)
+func (c *Client) ApplicationType() op.ApplicationType {
+ return c.applicationType
+}
+
+// AuthMethod must return the authentication method (client_secret_basic, client_secret_post, none, private_key_jwt)
+func (c *Client) AuthMethod() oidc.AuthMethod {
+ return c.authMethod
+}
+
+// ResponseTypes must return all allowed response types (code, id_token token, id_token)
+// these must match with the allowed grant types
+func (c *Client) ResponseTypes() []oidc.ResponseType {
+ return c.responseTypes
+}
+
+// GrantTypes must return all allowed grant types (authorization_code, refresh_token, urn:ietf:params:oauth:grant-type:jwt-bearer)
+func (c *Client) GrantTypes() []oidc.GrantType {
+ return c.grantTypes
+}
+
+// LoginURL will be called to redirect the user (agent) to the login UI
+// you could implement some logic here to redirect the users to different login UIs depending on the client
+func (c *Client) LoginURL(id string) string {
+ return c.loginURL(id)
+}
+
+// AccessTokenType must return the type of access token the client uses (Bearer (opaque) or JWT)
+func (c *Client) AccessTokenType() op.AccessTokenType {
+ return c.accessTokenType
+}
+
+// IDTokenLifetime must return the lifetime of the client's id_tokens
+func (c *Client) IDTokenLifetime() time.Duration {
+ return 1 * time.Hour
+}
+
+// DevMode enables the use of non-compliant configs such as redirect_uris (e.g. http schema for user agent client)
+func (c *Client) DevMode() bool {
+ return c.devMode
+}
+
+// RestrictAdditionalIdTokenScopes allows specifying which custom scopes shall be asserted into the id_token
+func (c *Client) RestrictAdditionalIdTokenScopes() func(scopes []string) []string {
+ return func(scopes []string) []string {
+ return scopes
+ }
+}
+
+// RestrictAdditionalAccessTokenScopes allows specifying which custom scopes shall be asserted into the JWT access_token
+func (c *Client) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
+ return func(scopes []string) []string {
+ return scopes
+ }
+}
+
+// IsScopeAllowed enables Client specific custom scopes validation
+// in this example we allow the CustomScope for all clients
+func (c *Client) IsScopeAllowed(scope string) bool {
+ return scope == CustomScope
+}
+
+// IDTokenUserinfoClaimsAssertion allows specifying if claims of scope profile, email, phone and address are asserted into the id_token
+// even if an access token if issued which violates the OIDC Core spec
+// (5.4. Requesting Claims using Scope Values: https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims)
+// some clients though require that e.g. email is always in the id_token when requested even if an access_token is issued
+func (c *Client) IDTokenUserinfoClaimsAssertion() bool {
+ return c.idTokenUserinfoClaimsAssertion
+}
+
+// ClockSkew enables clients to instruct the OP to apply a clock skew on the various times and expirations
+// (subtract from issued_at, add to expiration, ...)
+func (c *Client) ClockSkew() time.Duration {
+ return c.clockSkew
+}
+
+// RegisterClients enables you to register clients for the example implementation
+// there are some clients (web and native) to try out different cases
+// add more if necessary
+//
+// RegisterClients should be called before the Storage is used so that there are
+// no race conditions.
+func RegisterClients(registerClients ...*Client) {
+ for _, client := range registerClients {
+ clients[client.id] = client
+ }
+}
+
+// NativeClient will create a client of type native, which will always use PKCE and allow the use of refresh tokens
+// user-defined redirectURIs may include:
+// - http://localhost without port specification (e.g. http://localhost/auth/callback)
+// - custom protocol (e.g. custom://auth/callback)
+// (the examples will be used as default, if none is provided)
+func NativeClient(id string, redirectURIs ...string) *Client {
+ if len(redirectURIs) == 0 {
+ redirectURIs = []string{
+ "http://localhost/auth/callback",
+ "custom://auth/callback",
+ }
+ }
+ return &Client{
+ id: id,
+ secret: "", // no secret needed (due to PKCE)
+ redirectURIs: redirectURIs,
+ applicationType: op.ApplicationTypeNative,
+ authMethod: oidc.AuthMethodNone,
+ loginURL: defaultLoginURL,
+ responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
+ grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken},
+ accessTokenType: op.AccessTokenTypeBearer,
+ devMode: false,
+ idTokenUserinfoClaimsAssertion: false,
+ clockSkew: 0,
+ }
+}
+
+// WebClient will create a client of type web, which will always use Basic Auth and allow the use of refresh tokens
+// user-defined redirectURIs may include:
+// - http://localhost with port specification (e.g. http://localhost:9999/auth/callback)
+// (the example will be used as default, if none is provided)
+func WebClient(id, secret string, redirectURIs ...string) *Client {
+ if len(redirectURIs) == 0 {
+ redirectURIs = []string{
+ "http://localhost:9999/auth/callback",
+ }
+ }
+ return &Client{
+ id: id,
+ secret: secret,
+ redirectURIs: redirectURIs,
+ applicationType: op.ApplicationTypeWeb,
+ authMethod: oidc.AuthMethodBasic,
+ loginURL: defaultLoginURL,
+ responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode, oidc.ResponseTypeIDTokenOnly, oidc.ResponseTypeIDToken},
+ grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken, oidc.GrantTypeTokenExchange},
+ accessTokenType: op.AccessTokenTypeBearer,
+ devMode: true,
+ idTokenUserinfoClaimsAssertion: false,
+ clockSkew: 0,
+ }
+}
+
+// DeviceClient creates a device client with Basic authentication.
+func DeviceClient(id, secret string) *Client {
+ return &Client{
+ id: id,
+ secret: secret,
+ redirectURIs: nil,
+ applicationType: op.ApplicationTypeWeb,
+ authMethod: oidc.AuthMethodBasic,
+ loginURL: defaultLoginURL,
+ responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode},
+ grantTypes: []oidc.GrantType{oidc.GrantTypeDeviceCode},
+ accessTokenType: op.AccessTokenTypeBearer,
+ devMode: false,
+ idTokenUserinfoClaimsAssertion: false,
+ clockSkew: 0,
+ }
+}
+
+type hasRedirectGlobs struct {
+ *Client
+}
+
+// RedirectURIGlobs provide wildcarding for additional valid redirects
+func (c hasRedirectGlobs) RedirectURIGlobs() []string {
+ return c.redirectURIGlobs
+}
+
+// PostLogoutRedirectURIGlobs provide extra wildcarding for additional valid redirects
+func (c hasRedirectGlobs) PostLogoutRedirectURIGlobs() []string {
+ return c.postLogoutRedirectURIGlobs
+}
+
+// RedirectGlobsClient wraps the client in a op.HasRedirectGlobs
+// only if DevMode is enabled.
+func RedirectGlobsClient(client *Client) op.Client {
+ if client.devMode {
+ return hasRedirectGlobs{client}
+ }
+ return client
+}
diff --git a/example/server/storage/oidc.go b/example/server/storage/oidc.go
new file mode 100644
index 0000000..9c7f544
--- /dev/null
+++ b/example/server/storage/oidc.go
@@ -0,0 +1,230 @@
+package storage
+
+import (
+ "log/slog"
+ "time"
+
+ "golang.org/x/text/language"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+const (
+ // CustomScope is an example for how to use custom scopes in this library
+ //(in this scenario, when requested, it will return a custom claim)
+ CustomScope = "custom_scope"
+
+ // CustomClaim is an example for how to return custom claims with this library
+ CustomClaim = "custom_claim"
+
+ // CustomScopeImpersonatePrefix is an example scope prefix for passing user id to impersonate using token exchage
+ CustomScopeImpersonatePrefix = "custom_scope:impersonate:"
+)
+
+type AuthRequest struct {
+ ID string
+ CreationDate time.Time
+ ApplicationID string
+ CallbackURI string
+ TransferState string
+ Prompt []string
+ UiLocales []language.Tag
+ LoginHint string
+ MaxAuthAge *time.Duration
+ UserID string
+ Scopes []string
+ ResponseType oidc.ResponseType
+ ResponseMode oidc.ResponseMode
+ Nonce string
+ CodeChallenge *OIDCCodeChallenge
+
+ done bool
+ authTime time.Time
+}
+
+// LogValue allows you to define which fields will be logged.
+// Implements the [slog.LogValuer]
+func (a *AuthRequest) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("id", a.ID),
+ slog.Time("creation_date", a.CreationDate),
+ slog.Any("scopes", a.Scopes),
+ slog.String("response_type", string(a.ResponseType)),
+ slog.String("app_id", a.ApplicationID),
+ slog.String("callback_uri", a.CallbackURI),
+ )
+}
+
+func (a *AuthRequest) GetID() string {
+ return a.ID
+}
+
+func (a *AuthRequest) GetACR() string {
+ return "" // we won't handle acr in this example
+}
+
+func (a *AuthRequest) GetAMR() []string {
+ // this example only uses password for authentication
+ if a.done {
+ return []string{"pwd"}
+ }
+ return nil
+}
+
+func (a *AuthRequest) GetAudience() []string {
+ return []string{a.ApplicationID} // this example will always just use the client_id as audience
+}
+
+func (a *AuthRequest) GetAuthTime() time.Time {
+ return a.authTime
+}
+
+func (a *AuthRequest) GetClientID() string {
+ return a.ApplicationID
+}
+
+func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
+ return CodeChallengeToOIDC(a.CodeChallenge)
+}
+
+func (a *AuthRequest) GetNonce() string {
+ return a.Nonce
+}
+
+func (a *AuthRequest) GetRedirectURI() string {
+ return a.CallbackURI
+}
+
+func (a *AuthRequest) GetResponseType() oidc.ResponseType {
+ return a.ResponseType
+}
+
+func (a *AuthRequest) GetResponseMode() oidc.ResponseMode {
+ return a.ResponseMode
+}
+
+func (a *AuthRequest) GetScopes() []string {
+ return a.Scopes
+}
+
+func (a *AuthRequest) GetState() string {
+ return a.TransferState
+}
+
+func (a *AuthRequest) GetSubject() string {
+ return a.UserID
+}
+
+func (a *AuthRequest) Done() bool {
+ return a.done
+}
+
+func PromptToInternal(oidcPrompt oidc.SpaceDelimitedArray) []string {
+ prompts := make([]string, 0, len(oidcPrompt))
+ for _, oidcPrompt := range oidcPrompt {
+ switch oidcPrompt {
+ case oidc.PromptNone,
+ oidc.PromptLogin,
+ oidc.PromptConsent,
+ oidc.PromptSelectAccount:
+ prompts = append(prompts, oidcPrompt)
+ }
+ }
+ return prompts
+}
+
+func MaxAgeToInternal(maxAge *uint) *time.Duration {
+ if maxAge == nil {
+ return nil
+ }
+ dur := time.Duration(*maxAge) * time.Second
+ return &dur
+}
+
+func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthRequest {
+ return &AuthRequest{
+ CreationDate: time.Now(),
+ ApplicationID: authReq.ClientID,
+ CallbackURI: authReq.RedirectURI,
+ TransferState: authReq.State,
+ Prompt: PromptToInternal(authReq.Prompt),
+ UiLocales: authReq.UILocales,
+ LoginHint: authReq.LoginHint,
+ MaxAuthAge: MaxAgeToInternal(authReq.MaxAge),
+ UserID: userID,
+ Scopes: authReq.Scopes,
+ ResponseType: authReq.ResponseType,
+ ResponseMode: authReq.ResponseMode,
+ Nonce: authReq.Nonce,
+ CodeChallenge: &OIDCCodeChallenge{
+ Challenge: authReq.CodeChallenge,
+ Method: string(authReq.CodeChallengeMethod),
+ },
+ }
+}
+
+type AuthRequestWithSessionState struct {
+ *AuthRequest
+ SessionState string
+}
+
+func (a *AuthRequestWithSessionState) GetSessionState() string {
+ return a.SessionState
+}
+
+type OIDCCodeChallenge struct {
+ Challenge string
+ Method string
+}
+
+func CodeChallengeToOIDC(challenge *OIDCCodeChallenge) *oidc.CodeChallenge {
+ if challenge == nil {
+ return nil
+ }
+ challengeMethod := oidc.CodeChallengeMethodPlain
+ if challenge.Method == "S256" {
+ challengeMethod = oidc.CodeChallengeMethodS256
+ }
+ return &oidc.CodeChallenge{
+ Challenge: challenge.Challenge,
+ Method: challengeMethod,
+ }
+}
+
+// RefreshTokenRequestFromBusiness will simply wrap the storage RefreshToken to implement the op.RefreshTokenRequest interface
+func RefreshTokenRequestFromBusiness(token *RefreshToken) op.RefreshTokenRequest {
+ return &RefreshTokenRequest{token}
+}
+
+type RefreshTokenRequest struct {
+ *RefreshToken
+}
+
+func (r *RefreshTokenRequest) GetAMR() []string {
+ return r.AMR
+}
+
+func (r *RefreshTokenRequest) GetAudience() []string {
+ return r.Audience
+}
+
+func (r *RefreshTokenRequest) GetAuthTime() time.Time {
+ return r.AuthTime
+}
+
+func (r *RefreshTokenRequest) GetClientID() string {
+ return r.ApplicationID
+}
+
+func (r *RefreshTokenRequest) GetScopes() []string {
+ return r.Scopes
+}
+
+func (r *RefreshTokenRequest) GetSubject() string {
+ return r.UserID
+}
+
+func (r *RefreshTokenRequest) SetCurrentScopes(scopes []string) {
+ r.Scopes = scopes
+}
diff --git a/example/server/storage/storage.go b/example/server/storage/storage.go
new file mode 100644
index 0000000..d4315c6
--- /dev/null
+++ b/example/server/storage/storage.go
@@ -0,0 +1,933 @@
+package storage
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "errors"
+ "fmt"
+ "math/big"
+ "strings"
+ "sync"
+ "time"
+
+ jose "github.com/go-jose/go-jose/v4"
+ "github.com/google/uuid"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+// serviceKey1 is a public key which will be used for the JWT Profile Authorization Grant
+// the corresponding private key is in the service-key1.json (for demonstration purposes)
+var serviceKey1 = &rsa.PublicKey{
+ N: func() *big.Int {
+ n, _ := new(big.Int).SetString("00f6d44fb5f34ac2033a75e73cb65ff24e6181edc58845e75a560ac21378284977bb055b1a75b714874e2a2641806205681c09abec76efd52cf40984edcf4c8ca09717355d11ac338f280d3e4c905b00543bdb8ee5a417496cb50cb0e29afc5a0d0471fd5a2fa625bd5281f61e6b02067d4fe7a5349eeae6d6a4300bcd86eef331", 16)
+ return n
+ }(),
+ E: 65537,
+}
+
+var (
+ _ op.Storage = &Storage{}
+ _ op.ClientCredentialsStorage = &Storage{}
+)
+
+// storage implements the op.Storage interface
+// typically you would implement this as a layer on top of your database
+// for simplicity this example keeps everything in-memory
+type Storage struct {
+ lock sync.Mutex
+ authRequests map[string]*AuthRequest
+ codes map[string]string
+ tokens map[string]*Token
+ clients map[string]*Client
+ userStore UserStore
+ services map[string]Service
+ refreshTokens map[string]*RefreshToken
+ signingKey signingKey
+ deviceCodes map[string]deviceAuthorizationEntry
+ userCodes map[string]string
+ serviceUsers map[string]*Client
+}
+
+type signingKey struct {
+ id string
+ algorithm jose.SignatureAlgorithm
+ key *rsa.PrivateKey
+}
+
+func (s *signingKey) SignatureAlgorithm() jose.SignatureAlgorithm {
+ return s.algorithm
+}
+
+func (s *signingKey) Key() any {
+ return s.key
+}
+
+func (s *signingKey) ID() string {
+ return s.id
+}
+
+type publicKey struct {
+ signingKey
+}
+
+func (s *publicKey) ID() string {
+ return s.id
+}
+
+func (s *publicKey) Algorithm() jose.SignatureAlgorithm {
+ return s.algorithm
+}
+
+func (s *publicKey) Use() string {
+ return "sig"
+}
+
+func (s *publicKey) Key() any {
+ return &s.key.PublicKey
+}
+
+func NewStorage(userStore UserStore) *Storage {
+ return NewStorageWithClients(userStore, clients)
+}
+
+func NewStorageWithClients(userStore UserStore, clients map[string]*Client) *Storage {
+ key, _ := rsa.GenerateKey(rand.Reader, 2048)
+ return &Storage{
+ authRequests: make(map[string]*AuthRequest),
+ codes: make(map[string]string),
+ tokens: make(map[string]*Token),
+ refreshTokens: make(map[string]*RefreshToken),
+ clients: clients,
+ userStore: userStore,
+ services: map[string]Service{
+ userStore.ExampleClientID(): {
+ keys: map[string]*rsa.PublicKey{
+ "key1": serviceKey1,
+ },
+ },
+ },
+ signingKey: signingKey{
+ id: uuid.NewString(),
+ algorithm: jose.RS256,
+ key: key,
+ },
+ deviceCodes: make(map[string]deviceAuthorizationEntry),
+ userCodes: make(map[string]string),
+ serviceUsers: map[string]*Client{
+ "sid1": {
+ id: "sid1",
+ secret: "verysecret",
+ grantTypes: []oidc.GrantType{
+ oidc.GrantTypeClientCredentials,
+ },
+ accessTokenType: op.AccessTokenTypeBearer,
+ },
+ },
+ }
+}
+
+// CheckUsernamePassword implements the `authenticate` interface of the login
+func (s *Storage) CheckUsernamePassword(username, password, id string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ request, ok := s.authRequests[id]
+ if !ok {
+ return fmt.Errorf("request not found")
+ }
+
+ // for demonstration purposes we'll check we'll have a simple user store and
+ // a plain text password. For real world scenarios, be sure to have the password
+ // hashed and salted (e.g. using bcrypt)
+ user := s.userStore.GetUserByUsername(username)
+ if user != nil && user.Password == password {
+ // be sure to set user id into the auth request after the user was checked,
+ // so that you'll be able to get more information about the user after the login
+ request.UserID = user.ID
+
+ // you will have to change some state on the request to guide the user through possible multiple steps of the login process
+ // in this example we'll simply check the username / password and set a boolean to true
+ // therefore we will also just check this boolean if the request / login has been finished
+ request.done = true
+
+ request.authTime = time.Now()
+
+ return nil
+ }
+ return fmt.Errorf("username or password wrong")
+}
+
+func (s *Storage) CheckUsernamePasswordSimple(username, password string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ user := s.userStore.GetUserByUsername(username)
+ if user != nil && user.Password == password {
+ return nil
+ }
+ return fmt.Errorf("username or password wrong")
+}
+
+// CreateAuthRequest implements the op.Storage interface
+// it will be called after parsing and validation of the authentication request
+func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if len(authReq.Prompt) == 1 && authReq.Prompt[0] == "none" {
+ // With prompt=none, there is no way for the user to log in
+ // so return error right away.
+ return nil, oidc.ErrLoginRequired()
+ }
+
+ // typically, you'll fill your storage / storage model with the information of the passed object
+ request := authRequestToInternal(authReq, userID)
+
+ // you'll also have to create a unique id for the request (this might be done by your database; we'll use a uuid)
+ request.ID = uuid.NewString()
+
+ // and save it in your database (for demonstration purposed we will use a simple map)
+ s.authRequests[request.ID] = request
+
+ // finally, return the request (which implements the AuthRequest interface of the OP
+ return request, nil
+}
+
+// AuthRequestByID implements the op.Storage interface
+// it will be called after the Login UI redirects back to the OIDC endpoint
+func (s *Storage) AuthRequestByID(ctx context.Context, id string) (op.AuthRequest, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ request, ok := s.authRequests[id]
+ if !ok {
+ return nil, fmt.Errorf("request not found")
+ }
+ return request, nil
+}
+
+// AuthRequestByCode implements the op.Storage interface
+// it will be called after parsing and validation of the token request (in an authorization code flow)
+func (s *Storage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) {
+ // for this example we read the id by code and then get the request by id
+ requestID, ok := func() (string, bool) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ requestID, ok := s.codes[code]
+ return requestID, ok
+ }()
+ if !ok {
+ return nil, fmt.Errorf("code invalid or expired")
+ }
+ return s.AuthRequestByID(ctx, requestID)
+}
+
+// SaveAuthCode implements the op.Storage interface
+// it will be called after the authentication has been successful and before redirecting the user agent to the redirect_uri
+// (in an authorization code flow)
+func (s *Storage) SaveAuthCode(ctx context.Context, id string, code string) error {
+ // for this example we'll just save the authRequestID to the code
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ s.codes[code] = id
+ return nil
+}
+
+// DeleteAuthRequest implements the op.Storage interface
+// it will be called after creating the token response (id and access tokens) for a valid
+// - authentication request (in an implicit flow)
+// - token request (in an authorization code flow)
+func (s *Storage) DeleteAuthRequest(ctx context.Context, id string) error {
+ // you can simply delete all reference to the auth request
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ delete(s.authRequests, id)
+ for code, requestID := range s.codes {
+ if id == requestID {
+ delete(s.codes, code)
+ return nil
+ }
+ }
+ return nil
+}
+
+// CreateAccessToken implements the op.Storage interface
+// it will be called for all requests able to return an access token (Authorization Code Flow, Implicit Flow, JWT Profile, ...)
+func (s *Storage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) {
+ var applicationID string
+ switch req := request.(type) {
+ case *AuthRequest:
+ // if authenticated for an app (auth code / implicit flow) we must save the client_id to the token
+ applicationID = req.ApplicationID
+ case op.TokenExchangeRequest:
+ applicationID = req.GetClientID()
+ }
+
+ token, err := s.accessToken(applicationID, "", request.GetSubject(), request.GetAudience(), request.GetScopes())
+ if err != nil {
+ return "", time.Time{}, err
+ }
+ return token.ID, token.Expiration, nil
+}
+
+// CreateAccessAndRefreshTokens implements the op.Storage interface
+// it will be called for all requests able to return an access and refresh token (Authorization Code Flow, Refresh Token Request)
+func (s *Storage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
+ // generate tokens via token exchange flow if request is relevant
+ if teReq, ok := request.(op.TokenExchangeRequest); ok {
+ return s.exchangeRefreshToken(ctx, teReq)
+ }
+
+ // get the information depending on the request type / implementation
+ applicationID, authTime, amr := getInfoFromRequest(request)
+
+ // if currentRefreshToken is empty (Code Flow) we will have to create a new refresh token
+ if currentRefreshToken == "" {
+ refreshTokenID := uuid.NewString()
+ accessToken, err := s.accessToken(applicationID, refreshTokenID, request.GetSubject(), request.GetAudience(), request.GetScopes())
+ if err != nil {
+ return "", "", time.Time{}, err
+ }
+ refreshToken, err := s.createRefreshToken(accessToken, amr, authTime)
+ if err != nil {
+ return "", "", time.Time{}, err
+ }
+ return accessToken.ID, refreshToken, accessToken.Expiration, nil
+ }
+
+ // if we get here, the currentRefreshToken was not empty, so the call is a refresh token request
+ // we therefore will have to check the currentRefreshToken and renew the refresh token
+
+ newRefreshToken = uuid.NewString()
+
+ accessToken, err := s.accessToken(applicationID, newRefreshToken, request.GetSubject(), request.GetAudience(), request.GetScopes())
+ if err != nil {
+ return "", "", time.Time{}, err
+ }
+
+ if err := s.renewRefreshToken(currentRefreshToken, newRefreshToken, accessToken.ID); err != nil {
+ return "", "", time.Time{}, err
+ }
+
+ return accessToken.ID, newRefreshToken, accessToken.Expiration, nil
+}
+
+func (s *Storage) exchangeRefreshToken(ctx context.Context, request op.TokenExchangeRequest) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
+ applicationID := request.GetClientID()
+ authTime := request.GetAuthTime()
+
+ refreshTokenID := uuid.NewString()
+ accessToken, err := s.accessToken(applicationID, refreshTokenID, request.GetSubject(), request.GetAudience(), request.GetScopes())
+ if err != nil {
+ return "", "", time.Time{}, err
+ }
+
+ refreshToken, err := s.createRefreshToken(accessToken, nil, authTime)
+ if err != nil {
+ return "", "", time.Time{}, err
+ }
+
+ return accessToken.ID, refreshToken, accessToken.Expiration, nil
+}
+
+// TokenRequestByRefreshToken implements the op.Storage interface
+// it will be called after parsing and validation of the refresh token request
+func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ token, ok := s.refreshTokens[refreshToken]
+ if !ok {
+ return nil, fmt.Errorf("invalid refresh_token")
+ }
+ return RefreshTokenRequestFromBusiness(token), nil
+}
+
+// TerminateSession implements the op.Storage interface
+// it will be called after the user signed out, therefore the access and refresh token of the user of this client must be removed
+func (s *Storage) TerminateSession(ctx context.Context, userID string, clientID string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ for _, token := range s.tokens {
+ if token.ApplicationID == clientID && token.Subject == userID {
+ delete(s.tokens, token.ID)
+ delete(s.refreshTokens, token.RefreshTokenID)
+ }
+ }
+ return nil
+}
+
+// GetRefreshTokenInfo looks up a refresh token and returns the token id and user id.
+// If given something that is not a refresh token, it must return error.
+func (s *Storage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) {
+ refreshToken, ok := s.refreshTokens[token]
+ if !ok {
+ return "", "", op.ErrInvalidRefreshToken
+ }
+ return refreshToken.UserID, refreshToken.ID, nil
+}
+
+// RevokeToken implements the op.Storage interface
+// it will be called after parsing and validation of the token revocation request
+func (s *Storage) RevokeToken(ctx context.Context, tokenIDOrToken string, userID string, clientID string) *oidc.Error {
+ // a single token was requested to be removed
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ accessToken, ok := s.tokens[tokenIDOrToken] // tokenID
+ if ok {
+ if accessToken.ApplicationID != clientID {
+ return oidc.ErrInvalidClient().WithDescription("token was not issued for this client")
+ }
+ // if it is an access token, just remove it
+ // you could also remove the corresponding refresh token if really necessary
+ delete(s.tokens, accessToken.ID)
+ return nil
+ }
+ refreshToken, ok := s.refreshTokens[tokenIDOrToken] // token
+ if !ok {
+ // if the token is neither an access nor a refresh token, just ignore it, the expected behaviour of
+ // being not valid (anymore) is achieved
+ return nil
+ }
+ if refreshToken.ApplicationID != clientID {
+ return oidc.ErrInvalidClient().WithDescription("token was not issued for this client")
+ }
+ delete(s.refreshTokens, refreshToken.ID)
+ // if it is a refresh token, you will have to remove the access token as well
+ delete(s.tokens, refreshToken.AccessToken)
+ return nil
+}
+
+// SigningKey implements the op.Storage interface
+// it will be called when creating the OpenID Provider
+func (s *Storage) SigningKey(ctx context.Context) (op.SigningKey, error) {
+ // in this example the signing key is a static rsa.PrivateKey and the algorithm used is RS256
+ // you would obviously have a more complex implementation and store / retrieve the key from your database as well
+ return &s.signingKey, nil
+}
+
+// SignatureAlgorithms implements the op.Storage interface
+// it will be called to get the sign
+func (s *Storage) SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error) {
+ return []jose.SignatureAlgorithm{s.signingKey.algorithm}, nil
+}
+
+// KeySet implements the op.Storage interface
+// it will be called to get the current (public) keys, among others for the keys_endpoint or for validating access_tokens on the userinfo_endpoint, ...
+func (s *Storage) KeySet(ctx context.Context) ([]op.Key, error) {
+ // as mentioned above, this example only has a single signing key without key rotation,
+ // so it will directly use its public key
+ //
+ // when using key rotation you typically would store the public keys alongside the private keys in your database
+ // and give both of them an expiration date, with the public key having a longer lifetime
+ return []op.Key{&publicKey{s.signingKey}}, nil
+}
+
+// GetClientByClientID implements the op.Storage interface
+// it will be called whenever information (type, redirect_uris, ...) about the client behind the client_id is needed
+func (s *Storage) GetClientByClientID(ctx context.Context, clientID string) (op.Client, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ client, ok := s.clients[clientID]
+ if !ok {
+ return nil, fmt.Errorf("client not found")
+ }
+ return RedirectGlobsClient(client), nil
+}
+
+// AuthorizeClientIDSecret implements the op.Storage interface
+// it will be called for validating the client_id, client_secret on token or introspection requests
+func (s *Storage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ client, ok := s.clients[clientID]
+ if !ok {
+ return fmt.Errorf("client not found")
+ }
+ // for this example we directly check the secret
+ // obviously you would not have the secret in plain text, but rather hashed and salted (e.g. using bcrypt)
+ if client.secret != clientSecret {
+ return fmt.Errorf("invalid secret")
+ }
+ return nil
+}
+
+// SetUserinfoFromScopes implements the op.Storage interface.
+// Provide an empty implementation and use SetUserinfoFromRequest instead.
+func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
+ return nil
+}
+
+// SetUserinfoFromRequests implements the op.CanSetUserinfoFromRequest interface. In the
+// next major release, it will be required for op.Storage.
+// It will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
+func (s *Storage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, token op.IDTokenRequest, scopes []string) error {
+ return s.setUserinfo(ctx, userinfo, token.GetSubject(), token.GetClientID(), scopes)
+}
+
+// SetUserinfoFromToken implements the op.Storage interface
+// it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function
+func (s *Storage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {
+ token, ok := func() (*Token, bool) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ token, ok := s.tokens[tokenID]
+ return token, ok
+ }()
+ if !ok {
+ return fmt.Errorf("token is invalid or has expired")
+ }
+ // the userinfo endpoint should support CORS. If it's not possible to specify a specific origin in the CORS handler,
+ // and you have to specify a wildcard (*) origin, then you could also check here if the origin which called the userinfo endpoint here directly
+ // note that the origin can be empty (if called by a web client)
+ //
+ // if origin != "" {
+ // client, ok := s.clients[token.ApplicationID]
+ // if !ok {
+ // return fmt.Errorf("client not found")
+ // }
+ // if err := checkAllowedOrigins(client.allowedOrigins, origin); err != nil {
+ // return err
+ // }
+ //}
+ if token.Expiration.Before(time.Now()) {
+ return fmt.Errorf("token is expired")
+ }
+ return s.setUserinfo(ctx, userinfo, token.Subject, token.ApplicationID, token.Scopes)
+}
+
+// SetIntrospectionFromToken implements the op.Storage interface
+// it will be called for the introspection endpoint, so we read the token and pass the information from that to the private function
+func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
+ token, ok := func() (*Token, bool) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ token, ok := s.tokens[tokenID]
+ return token, ok
+ }()
+ if !ok {
+ return fmt.Errorf("token is invalid or has expired")
+ }
+ // check if the client is part of the requested audience
+ for _, aud := range token.Audience {
+ if aud == clientID {
+ // the introspection response only has to return a boolean (active) if the token is active
+ // this will automatically be done by the library if you don't return an error
+ // you can also return further information about the user / associated token
+ // e.g. the userinfo (equivalent to userinfo endpoint)
+
+ userInfo := new(oidc.UserInfo)
+ err := s.setUserinfo(ctx, userInfo, subject, clientID, token.Scopes)
+ if err != nil {
+ return err
+ }
+ introspection.SetUserInfo(userInfo)
+ //...and also the requested scopes...
+ introspection.Scope = token.Scopes
+ //...and the client the token was issued to
+ introspection.ClientID = token.ApplicationID
+ return nil
+ }
+ }
+ return fmt.Errorf("token is not valid for this client")
+}
+
+// GetPrivateClaimsFromScopes implements the op.Storage interface
+// it will be called for the creation of a JWT access token to assert claims for custom scopes
+func (s *Storage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]any, err error) {
+ return s.getPrivateClaimsFromScopes(ctx, userID, clientID, scopes)
+}
+
+func (s *Storage) getPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]any, err error) {
+ for _, scope := range scopes {
+ switch scope {
+ case CustomScope:
+ claims = appendClaim(claims, CustomClaim, customClaim(clientID))
+ }
+ }
+ return claims, nil
+}
+
+// GetKeyByIDAndClientID implements the op.Storage interface
+// it will be called to validate the signatures of a JWT (JWT Profile Grant and Authentication)
+func (s *Storage) GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ service, ok := s.services[clientID]
+ if !ok {
+ return nil, fmt.Errorf("clientID not found")
+ }
+ key, ok := service.keys[keyID]
+ if !ok {
+ return nil, fmt.Errorf("key not found")
+ }
+ return &jose.JSONWebKey{
+ KeyID: keyID,
+ Use: "sig",
+ Key: key,
+ }, nil
+}
+
+// ValidateJWTProfileScopes implements the op.Storage interface
+// it will be called to validate the scopes of a JWT Profile Authorization Grant request
+func (s *Storage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) {
+ allowedScopes := make([]string, 0)
+ for _, scope := range scopes {
+ if scope == oidc.ScopeOpenID {
+ allowedScopes = append(allowedScopes, scope)
+ }
+ }
+ return allowedScopes, nil
+}
+
+// Health implements the op.Storage interface
+func (s *Storage) Health(ctx context.Context) error {
+ return nil
+}
+
+// createRefreshToken will store a refresh_token in-memory based on the provided information
+func (s *Storage) createRefreshToken(accessToken *Token, amr []string, authTime time.Time) (string, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ token := &RefreshToken{
+ ID: accessToken.RefreshTokenID,
+ Token: accessToken.RefreshTokenID,
+ AuthTime: authTime,
+ AMR: amr,
+ ApplicationID: accessToken.ApplicationID,
+ UserID: accessToken.Subject,
+ Audience: accessToken.Audience,
+ Expiration: time.Now().Add(5 * time.Hour),
+ Scopes: accessToken.Scopes,
+ AccessToken: accessToken.ID,
+ }
+ s.refreshTokens[token.ID] = token
+ return token.Token, nil
+}
+
+// renewRefreshToken checks the provided refresh_token and creates a new one based on the current
+//
+// [Refresh Token Rotation] is implemented.
+//
+// [Refresh Token Rotation]: https://www.rfc-editor.org/rfc/rfc6819#section-5.2.2.3
+func (s *Storage) renewRefreshToken(currentRefreshToken, newRefreshToken, newAccessToken string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ refreshToken, ok := s.refreshTokens[currentRefreshToken]
+ if !ok {
+ return fmt.Errorf("invalid refresh token")
+ }
+ // deletes the refresh token
+ delete(s.refreshTokens, currentRefreshToken)
+
+ // delete the access token which was issued based on this refresh token
+ delete(s.tokens, refreshToken.AccessToken)
+
+ if refreshToken.Expiration.Before(time.Now()) {
+ return fmt.Errorf("expired refresh token")
+ }
+
+ // creates a new refresh token based on the current one
+ refreshToken.Token = newRefreshToken
+ refreshToken.ID = newRefreshToken
+ refreshToken.Expiration = time.Now().Add(5 * time.Hour)
+ refreshToken.AccessToken = newAccessToken
+ s.refreshTokens[newRefreshToken] = refreshToken
+ return nil
+}
+
+// accessToken will store an access_token in-memory based on the provided information
+func (s *Storage) accessToken(applicationID, refreshTokenID, subject string, audience, scopes []string) (*Token, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ token := &Token{
+ ID: uuid.NewString(),
+ ApplicationID: applicationID,
+ RefreshTokenID: refreshTokenID,
+ Subject: subject,
+ Audience: audience,
+ Expiration: time.Now().Add(5 * time.Minute),
+ Scopes: scopes,
+ }
+ s.tokens[token.ID] = token
+ return token, nil
+}
+
+// setUserinfo sets the info based on the user, scopes and if necessary the clientID
+func (s *Storage) setUserinfo(ctx context.Context, userInfo *oidc.UserInfo, userID, clientID string, scopes []string) (err error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+ user := s.userStore.GetUserByID(userID)
+ if user == nil {
+ return fmt.Errorf("user not found")
+ }
+ for _, scope := range scopes {
+ switch scope {
+ case oidc.ScopeOpenID:
+ userInfo.Subject = user.ID
+ case oidc.ScopeEmail:
+ userInfo.Email = user.Email
+ userInfo.EmailVerified = oidc.Bool(user.EmailVerified)
+ case oidc.ScopeProfile:
+ userInfo.PreferredUsername = user.Username
+ userInfo.Name = user.FirstName + " " + user.LastName
+ userInfo.FamilyName = user.LastName
+ userInfo.GivenName = user.FirstName
+ userInfo.Locale = oidc.NewLocale(user.PreferredLanguage)
+ case oidc.ScopePhone:
+ userInfo.PhoneNumber = user.Phone
+ userInfo.PhoneNumberVerified = user.PhoneVerified
+ case CustomScope:
+ // you can also have a custom scope and assert public or custom claims based on that
+ userInfo.AppendClaims(CustomClaim, customClaim(clientID))
+ }
+ }
+ return nil
+}
+
+// ValidateTokenExchangeRequest implements the op.TokenExchangeStorage interface
+// it will be called to validate parsed Token Exchange Grant request
+func (s *Storage) ValidateTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) error {
+ if request.GetRequestedTokenType() == "" {
+ request.SetRequestedTokenType(oidc.RefreshTokenType)
+ }
+
+ // Just an example, some use cases might need this use case
+ if request.GetExchangeSubjectTokenType() == oidc.IDTokenType && request.GetRequestedTokenType() == oidc.RefreshTokenType {
+ return errors.New("exchanging id_token to refresh_token is not supported")
+ }
+
+ // Check impersonation permissions
+ if request.GetExchangeActor() == "" && !s.userStore.GetUserByID(request.GetExchangeSubject()).IsAdmin {
+ return errors.New("user doesn't have impersonation permission")
+ }
+
+ allowedScopes := make([]string, 0)
+ for _, scope := range request.GetScopes() {
+ if scope == oidc.ScopeAddress {
+ continue
+ }
+
+ if strings.HasPrefix(scope, CustomScopeImpersonatePrefix) {
+ subject := strings.TrimPrefix(scope, CustomScopeImpersonatePrefix)
+ request.SetSubject(subject)
+ }
+
+ allowedScopes = append(allowedScopes, scope)
+ }
+
+ request.SetCurrentScopes(allowedScopes)
+
+ return nil
+}
+
+// ValidateTokenExchangeRequest implements the op.TokenExchangeStorage interface
+// Common use case is to store request for audit purposes. For this example we skip the storing.
+func (s *Storage) CreateTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) error {
+ return nil
+}
+
+// GetPrivateClaimsFromScopesForTokenExchange implements the op.TokenExchangeStorage interface
+// it will be called for the creation of an exchanged JWT access token to assert claims for custom scopes
+// plus adding token exchange specific claims related to delegation or impersonation
+func (s *Storage) GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]any, err error) {
+ claims, err = s.getPrivateClaimsFromScopes(ctx, "", request.GetClientID(), request.GetScopes())
+ if err != nil {
+ return nil, err
+ }
+
+ for k, v := range s.getTokenExchangeClaims(ctx, request) {
+ claims = appendClaim(claims, k, v)
+ }
+
+ return claims, nil
+}
+
+// SetUserinfoFromScopesForTokenExchange implements the op.TokenExchangeStorage interface
+// it will be called for the creation of an id_token - we are using the same private function as for other flows,
+// plus adding token exchange specific claims related to delegation or impersonation
+func (s *Storage) SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo *oidc.UserInfo, request op.TokenExchangeRequest) error {
+ err := s.setUserinfo(ctx, userinfo, request.GetSubject(), request.GetClientID(), request.GetScopes())
+ if err != nil {
+ return err
+ }
+
+ for k, v := range s.getTokenExchangeClaims(ctx, request) {
+ userinfo.AppendClaims(k, v)
+ }
+
+ return nil
+}
+
+func (s *Storage) getTokenExchangeClaims(ctx context.Context, request op.TokenExchangeRequest) (claims map[string]any) {
+ for _, scope := range request.GetScopes() {
+ switch {
+ case strings.HasPrefix(scope, CustomScopeImpersonatePrefix) && request.GetExchangeActor() == "":
+ // Set actor subject claim for impersonation flow
+ claims = appendClaim(claims, "act", map[string]any{
+ "sub": request.GetExchangeSubject(),
+ })
+ }
+ }
+
+ // Set actor subject claim for delegation flow
+ // if request.GetExchangeActor() != "" {
+ // claims = appendClaim(claims, "act", map[string]any{
+ // "sub": request.GetExchangeActor(),
+ // })
+ // }
+
+ return claims
+}
+
+// getInfoFromRequest returns the clientID, authTime and amr depending on the op.TokenRequest type / implementation
+func getInfoFromRequest(req op.TokenRequest) (clientID string, authTime time.Time, amr []string) {
+ authReq, ok := req.(*AuthRequest) // Code Flow (with scope offline_access)
+ if ok {
+ return authReq.ApplicationID, authReq.authTime, authReq.GetAMR()
+ }
+ refreshReq, ok := req.(*RefreshTokenRequest) // Refresh Token Request
+ if ok {
+ return refreshReq.ApplicationID, refreshReq.AuthTime, refreshReq.AMR
+ }
+ return "", time.Time{}, nil
+}
+
+// customClaim demonstrates how to return custom claims based on provided information
+func customClaim(clientID string) map[string]any {
+ return map[string]any{
+ "client": clientID,
+ "other": "stuff",
+ }
+}
+
+func appendClaim(claims map[string]any, claim string, value any) map[string]any {
+ if claims == nil {
+ claims = make(map[string]any)
+ }
+ claims[claim] = value
+ return claims
+}
+
+type deviceAuthorizationEntry struct {
+ deviceCode string
+ userCode string
+ state *op.DeviceAuthorizationState
+}
+
+func (s *Storage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if _, ok := s.clients[clientID]; !ok {
+ return errors.New("client not found")
+ }
+
+ if _, ok := s.userCodes[userCode]; ok {
+ return op.ErrDuplicateUserCode
+ }
+
+ s.deviceCodes[deviceCode] = deviceAuthorizationEntry{
+ deviceCode: deviceCode,
+ userCode: userCode,
+ state: &op.DeviceAuthorizationState{
+ ClientID: clientID,
+ Scopes: scopes,
+ Expires: expires,
+ },
+ }
+
+ s.userCodes[userCode] = deviceCode
+ return nil
+}
+
+func (s *Storage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) {
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ entry, ok := s.deviceCodes[deviceCode]
+ if !ok || entry.state.ClientID != clientID {
+ return nil, errors.New("device code not found for client") // is there a standard not found error in the framework?
+ }
+
+ return entry.state, nil
+}
+
+func (s *Storage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ entry, ok := s.deviceCodes[s.userCodes[userCode]]
+ if !ok {
+ return nil, errors.New("user code not found")
+ }
+
+ return entry.state, nil
+}
+
+func (s *Storage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ entry, ok := s.deviceCodes[s.userCodes[userCode]]
+ if !ok {
+ return errors.New("user code not found")
+ }
+
+ entry.state.Subject = subject
+ entry.state.Done = true
+ return nil
+}
+
+func (s *Storage) DenyDeviceAuthorization(ctx context.Context, userCode string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ s.deviceCodes[s.userCodes[userCode]].state.Denied = true
+ return nil
+}
+
+// AuthRequestDone is used by testing and is not required to implement op.Storage
+func (s *Storage) AuthRequestDone(id string) error {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ if req, ok := s.authRequests[id]; ok {
+ req.done = true
+ return nil
+ }
+
+ return errors.New("request not found")
+}
+
+func (s *Storage) ClientCredentials(ctx context.Context, clientID, clientSecret string) (op.Client, error) {
+ s.lock.Lock()
+ defer s.lock.Unlock()
+
+ client, ok := s.serviceUsers[clientID]
+ if !ok {
+ return nil, errors.New("wrong service user or password")
+ }
+ if client.secret != clientSecret {
+ return nil, errors.New("wrong service user or password")
+ }
+
+ return client, nil
+}
+
+func (s *Storage) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (op.TokenRequest, error) {
+ client, ok := s.serviceUsers[clientID]
+ if !ok {
+ return nil, errors.New("wrong service user or password")
+ }
+
+ return &oidc.JWTTokenRequest{
+ Subject: client.id,
+ Audience: []string{clientID},
+ Scopes: scopes,
+ }, nil
+}
diff --git a/example/server/storage/storage_dynamic.go b/example/server/storage/storage_dynamic.go
new file mode 100644
index 0000000..765d29a
--- /dev/null
+++ b/example/server/storage/storage_dynamic.go
@@ -0,0 +1,281 @@
+package storage
+
+import (
+ "context"
+ "time"
+
+ jose "github.com/go-jose/go-jose/v4"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+type multiStorage struct {
+ issuers map[string]*Storage
+}
+
+// NewMultiStorage implements the op.Storage interface by wrapping multiple storage structs
+// and selecting them by the calling issuer
+func NewMultiStorage(issuers []string) *multiStorage {
+ s := make(map[string]*Storage)
+ for _, issuer := range issuers {
+ s[issuer] = NewStorage(NewUserStore(issuer))
+ }
+ return &multiStorage{issuers: s}
+}
+
+// CheckUsernamePassword implements the `authenticate` interface of the login
+func (s *multiStorage) CheckUsernamePassword(ctx context.Context, username, password, id string) error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.CheckUsernamePassword(username, password, id)
+}
+
+// CreateAuthRequest implements the op.Storage interface
+// it will be called after parsing and validation of the authentication request
+func (s *multiStorage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.CreateAuthRequest(ctx, authReq, userID)
+}
+
+// AuthRequestByID implements the op.Storage interface
+// it will be called after the Login UI redirects back to the OIDC endpoint
+func (s *multiStorage) AuthRequestByID(ctx context.Context, id string) (op.AuthRequest, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.AuthRequestByID(ctx, id)
+}
+
+// AuthRequestByCode implements the op.Storage interface
+// it will be called after parsing and validation of the token request (in an authorization code flow)
+func (s *multiStorage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.AuthRequestByCode(ctx, code)
+}
+
+// SaveAuthCode implements the op.Storage interface
+// it will be called after the authentication has been successful and before redirecting the user agent to the redirect_uri
+// (in an authorization code flow)
+func (s *multiStorage) SaveAuthCode(ctx context.Context, id string, code string) error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.SaveAuthCode(ctx, id, code)
+}
+
+// DeleteAuthRequest implements the op.Storage interface
+// it will be called after creating the token response (id and access tokens) for a valid
+// - authentication request (in an implicit flow)
+// - token request (in an authorization code flow)
+func (s *multiStorage) DeleteAuthRequest(ctx context.Context, id string) error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.DeleteAuthRequest(ctx, id)
+}
+
+// CreateAccessToken implements the op.Storage interface
+// it will be called for all requests able to return an access token (Authorization Code Flow, Implicit Flow, JWT Profile, ...)
+func (s *multiStorage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return "", time.Time{}, err
+ }
+ return storage.CreateAccessToken(ctx, request)
+}
+
+// CreateAccessAndRefreshTokens implements the op.Storage interface
+// it will be called for all requests able to return an access and refresh token (Authorization Code Flow, Refresh Token Request)
+func (s *multiStorage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return "", "", time.Time{}, err
+ }
+ return storage.CreateAccessAndRefreshTokens(ctx, request, currentRefreshToken)
+}
+
+// TokenRequestByRefreshToken implements the op.Storage interface
+// it will be called after parsing and validation of the refresh token request
+func (s *multiStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.TokenRequestByRefreshToken(ctx, refreshToken)
+}
+
+// TerminateSession implements the op.Storage interface
+// it will be called after the user signed out, therefore the access and refresh token of the user of this client must be removed
+func (s *multiStorage) TerminateSession(ctx context.Context, userID string, clientID string) error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.TerminateSession(ctx, userID, clientID)
+}
+
+// GetRefreshTokenInfo looks up a refresh token and returns the token id and user id.
+// If given something that is not a refresh token, it must return error.
+func (s *multiStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return "", "", err
+ }
+ return storage.GetRefreshTokenInfo(ctx, clientID, token)
+}
+
+// RevokeToken implements the op.Storage interface
+// it will be called after parsing and validation of the token revocation request
+func (s *multiStorage) RevokeToken(ctx context.Context, token string, userID string, clientID string) *oidc.Error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.RevokeToken(ctx, token, userID, clientID)
+}
+
+// SigningKey implements the op.Storage interface
+// it will be called when creating the OpenID Provider
+func (s *multiStorage) SigningKey(ctx context.Context) (op.SigningKey, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.SigningKey(ctx)
+}
+
+// SignatureAlgorithms implements the op.Storage interface
+// it will be called to get the sign
+func (s *multiStorage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.SignatureAlgorithms(ctx)
+}
+
+// KeySet implements the op.Storage interface
+// it will be called to get the current (public) keys, among others for the keys_endpoint or for validating access_tokens on the userinfo_endpoint, ...
+func (s *multiStorage) KeySet(ctx context.Context) ([]op.Key, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.KeySet(ctx)
+}
+
+// GetClientByClientID implements the op.Storage interface
+// it will be called whenever information (type, redirect_uris, ...) about the client behind the client_id is needed
+func (s *multiStorage) GetClientByClientID(ctx context.Context, clientID string) (op.Client, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.GetClientByClientID(ctx, clientID)
+}
+
+// AuthorizeClientIDSecret implements the op.Storage interface
+// it will be called for validating the client_id, client_secret on token or introspection requests
+func (s *multiStorage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
+}
+
+// SetUserinfoFromScopes implements the op.Storage interface.
+// Provide an empty implementation and use SetUserinfoFromRequest instead.
+func (s *multiStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.SetUserinfoFromScopes(ctx, userinfo, userID, clientID, scopes)
+}
+
+// SetUserinfoFromRequests implements the op.CanSetUserinfoFromRequest interface. In the
+// next major release, it will be required for op.Storage.
+// It will be called for the creation of an id_token, so we'll just pass it to the private function without any further check
+func (s *multiStorage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, token op.IDTokenRequest, scopes []string) error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.SetUserinfoFromRequest(ctx, userinfo, token, scopes)
+}
+
+// SetUserinfoFromToken implements the op.Storage interface
+// it will be called for the userinfo endpoint, so we read the token and pass the information from that to the private function
+func (s *multiStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.SetUserinfoFromToken(ctx, userinfo, tokenID, subject, origin)
+}
+
+// SetIntrospectionFromToken implements the op.Storage interface
+// it will be called for the introspection endpoint, so we read the token and pass the information from that to the private function
+func (s *multiStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return err
+ }
+ return storage.SetIntrospectionFromToken(ctx, introspection, tokenID, subject, clientID)
+}
+
+// GetPrivateClaimsFromScopes implements the op.Storage interface
+// it will be called for the creation of a JWT access token to assert claims for custom scopes
+func (s *multiStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]any, err error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.GetPrivateClaimsFromScopes(ctx, userID, clientID, scopes)
+}
+
+// GetKeyByIDAndClientID implements the op.Storage interface
+// it will be called to validate the signatures of a JWT (JWT Profile Grant and Authentication)
+func (s *multiStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.GetKeyByIDAndClientID(ctx, keyID, userID)
+}
+
+// ValidateJWTProfileScopes implements the op.Storage interface
+// it will be called to validate the scopes of a JWT Profile Authorization Grant request
+func (s *multiStorage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) {
+ storage, err := s.storageFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return storage.ValidateJWTProfileScopes(ctx, userID, scopes)
+}
+
+// Health implements the op.Storage interface
+func (s *multiStorage) Health(ctx context.Context) error {
+ return nil
+}
+
+func (s *multiStorage) storageFromContext(ctx context.Context) (*Storage, *oidc.Error) {
+ storage, ok := s.issuers[op.IssuerFromContext(ctx)]
+ if !ok {
+ return nil, oidc.ErrInvalidRequest().WithDescription("invalid issuer")
+ }
+ return storage, nil
+}
diff --git a/example/server/storage/token.go b/example/server/storage/token.go
new file mode 100644
index 0000000..beab38c
--- /dev/null
+++ b/example/server/storage/token.go
@@ -0,0 +1,26 @@
+package storage
+
+import "time"
+
+type Token struct {
+ ID string
+ ApplicationID string
+ Subject string
+ RefreshTokenID string
+ Audience []string
+ Expiration time.Time
+ Scopes []string
+}
+
+type RefreshToken struct {
+ ID string
+ Token string
+ AuthTime time.Time
+ AMR []string
+ Audience []string
+ UserID string
+ ApplicationID string
+ Expiration time.Time
+ Scopes []string
+ AccessToken string // Token.ID
+}
diff --git a/example/server/storage/user.go b/example/server/storage/user.go
new file mode 100644
index 0000000..ed8cdfa
--- /dev/null
+++ b/example/server/storage/user.go
@@ -0,0 +1,102 @@
+package storage
+
+import (
+ "crypto/rsa"
+ "encoding/json"
+ "os"
+ "strings"
+
+ "golang.org/x/text/language"
+)
+
+type User struct {
+ ID string
+ Username string
+ Password string
+ FirstName string
+ LastName string
+ Email string
+ EmailVerified bool
+ Phone string
+ PhoneVerified bool
+ PreferredLanguage language.Tag
+ IsAdmin bool
+}
+
+type Service struct {
+ keys map[string]*rsa.PublicKey
+}
+
+type UserStore interface {
+ GetUserByID(string) *User
+ GetUserByUsername(string) *User
+ ExampleClientID() string
+}
+
+type userStore struct {
+ users map[string]*User
+}
+
+func StoreFromFile(path string) (UserStore, error) {
+ users := map[string]*User{}
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+ if err := json.Unmarshal(data, &users); err != nil {
+ return nil, err
+ }
+ return userStore{users}, nil
+}
+
+func NewUserStore(issuer string) UserStore {
+ hostname := strings.Split(strings.Split(issuer, "://")[1], ":")[0]
+ return userStore{
+ users: map[string]*User{
+ "id1": {
+ ID: "id1",
+ Username: "test-user@" + hostname,
+ Password: "verysecure",
+ FirstName: "Test",
+ LastName: "User",
+ Email: "test-user@zitadel.ch",
+ EmailVerified: true,
+ Phone: "",
+ PhoneVerified: false,
+ PreferredLanguage: language.German,
+ IsAdmin: true,
+ },
+ "id2": {
+ ID: "id2",
+ Username: "test-user2",
+ Password: "verysecure",
+ FirstName: "Test",
+ LastName: "User2",
+ Email: "test-user2@zitadel.ch",
+ EmailVerified: true,
+ Phone: "",
+ PhoneVerified: false,
+ PreferredLanguage: language.German,
+ IsAdmin: false,
+ },
+ },
+ }
+}
+
+// ExampleClientID is only used in the example server
+func (u userStore) ExampleClientID() string {
+ return "service"
+}
+
+func (u userStore) GetUserByID(id string) *User {
+ return u.users[id]
+}
+
+func (u userStore) GetUserByUsername(username string) *User {
+ for _, user := range u.users {
+ if user.Username == username {
+ return user
+ }
+ }
+ return nil
+}
diff --git a/example/server/storage/user_test.go b/example/server/storage/user_test.go
new file mode 100644
index 0000000..c2e2212
--- /dev/null
+++ b/example/server/storage/user_test.go
@@ -0,0 +1,70 @@
+package storage
+
+import (
+ "os"
+ "path"
+ "reflect"
+ "testing"
+
+ "golang.org/x/text/language"
+)
+
+func TestStoreFromFile(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ pathToFile string
+ content string
+ want UserStore
+ wantErr bool
+ }{
+ {
+ name: "normal user file",
+ pathToFile: "userfile.json",
+ content: `{
+ "id1": {
+ "ID": "id1",
+ "EmailVerified": true,
+ "PreferredLanguage": "DE"
+ }
+ }`,
+ want: userStore{map[string]*User{
+ "id1": {
+ ID: "id1",
+ EmailVerified: true,
+ PreferredLanguage: language.German,
+ },
+ }},
+ },
+ {
+ name: "malformed file",
+ pathToFile: "whatever",
+ content: "not a json just a text",
+ wantErr: true,
+ },
+ {
+ name: "not existing file",
+ pathToFile: "what/ever/file",
+ wantErr: true,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ actualPath := path.Join(t.TempDir(), tc.pathToFile)
+
+ if tc.content != "" && tc.pathToFile != "" {
+ if err := os.WriteFile(actualPath, []byte(tc.content), 0666); err != nil {
+ t.Fatalf("cannot create file with test content: %q", tc.content)
+ }
+ }
+ result, err := StoreFromFile(actualPath)
+ if err != nil && !tc.wantErr {
+ t.Errorf("StoreFromFile(%q) returned unexpected error %q", tc.pathToFile, err)
+ } else if err == nil && tc.wantErr {
+ t.Errorf("StoreFromFile(%q) did not return an expected error", tc.pathToFile)
+ }
+ if !tc.wantErr && !reflect.DeepEqual(tc.want, result.(userStore)) {
+ t.Errorf("expected StoreFromFile(%q) = %v, but got %v",
+ tc.pathToFile, tc.want, result)
+ }
+ })
+ }
+}
diff --git a/go.mod b/go.mod
index 24a1b57..a0f42c4 100644
--- a/go.mod
+++ b/go.mod
@@ -1,25 +1,40 @@
-module github.com/caos/oidc
+module git.christmann.info/LARA/zitadel-oidc/v3
-go 1.13
+go 1.23.7
+
+toolchain go1.24.1
require (
- github.com/caos/logging v0.0.0-20191210002624-b3260f690a6a
- github.com/golang/mock v1.4.3
- github.com/google/go-cmp v0.4.1 // indirect
+ github.com/bmatcuk/doublestar/v4 v4.8.1
+ github.com/go-chi/chi/v5 v5.2.1
+ github.com/go-jose/go-jose/v4 v4.0.5
+ github.com/golang/mock v1.6.0
github.com/google/go-github/v31 v31.0.0
- github.com/google/uuid v1.1.1
- github.com/gorilla/handlers v1.4.2
- github.com/gorilla/mux v1.7.4
- github.com/gorilla/schema v1.1.0
- github.com/gorilla/securecookie v1.1.1
- github.com/kr/pretty v0.1.0 // indirect
- github.com/sirupsen/logrus v1.6.0
- github.com/stretchr/testify v1.6.1
- golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect
- golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933
- golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c
- golang.org/x/text v0.3.3
- google.golang.org/appengine v1.6.5 // indirect
- gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
- gopkg.in/square/go-jose.v2 v2.5.1
+ github.com/google/uuid v1.6.0
+ github.com/gorilla/securecookie v1.1.2
+ github.com/jeremija/gosubmit v0.2.8
+ github.com/muhlemmer/gu v0.3.1
+ github.com/muhlemmer/httpforwarded v0.1.0
+ github.com/rs/cors v1.11.1
+ github.com/sirupsen/logrus v1.9.3
+ github.com/stretchr/testify v1.10.0
+ github.com/zitadel/logging v0.6.2
+ github.com/zitadel/schema v1.3.1
+ go.opentelemetry.io/otel v1.29.0
+ golang.org/x/oauth2 v0.30.0
+ golang.org/x/text v0.26.0
+)
+
+require (
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/go-logr/logr v1.4.2 // indirect
+ github.com/go-logr/stdr v1.2.2 // indirect
+ github.com/google/go-querystring v1.1.0 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ go.opentelemetry.io/otel/metric v1.29.0 // indirect
+ go.opentelemetry.io/otel/trace v1.29.0 // indirect
+ golang.org/x/crypto v0.36.0 // indirect
+ golang.org/x/net v0.38.0 // indirect
+ golang.org/x/sys v0.31.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 5af7d7e..4835505 100644
--- a/go.sum
+++ b/go.sum
@@ -1,104 +1,108 @@
-cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-github.com/caos/logging v0.0.0-20191210002624-b3260f690a6a h1:HOU/3xL/afsZ+2aCstfJlrzRkwYMTFR1TIEgps5ny8s=
-github.com/caos/logging v0.0.0-20191210002624-b3260f690a6a/go.mod h1:9LKiDE2ChuGv6CHYif/kiugrfEXu9AwDiFWSreX7Wp0=
+github.com/bmatcuk/doublestar/v4 v4.8.1 h1:54Bopc5c2cAvhLRAzqOGCYHYyhcDHsFF4wWIR5wKP38=
+github.com/bmatcuk/doublestar/v4 v4.8.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw=
-github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
-github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
+github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
+github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
+github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
+github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
+github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
+github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
+github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
+github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
+github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
+github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
+github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/google/go-cmp v0.4.1 h1:/exdXoGamhu5ONeUJH0deniYLWYvQwW66yvlfiiKTu0=
-github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
+github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-github/v31 v31.0.0 h1:JJUxlP9lFK+ziXKimTCprajMApV1ecWD4NB6CCb0plo=
github.com/google/go-github/v31 v31.0.0/go.mod h1:NQPZol8/1sMoWYGN2yaALIBytu17gAWfhbweiEed3pM=
-github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
-github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
-github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/gorilla/handlers v1.4.2 h1:0QniY0USkHQ1RGCLfKxeNHK9bkDHGRYGNDFBCS+YARg=
-github.com/gorilla/handlers v1.4.2/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ=
-github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc=
-github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
-github.com/gorilla/schema v1.1.0 h1:CamqUDOFUBqzrvxuz2vEwo8+SUdwsluFh7IlzJh30LY=
-github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
-github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
-github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
-github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
-github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=
-github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
-github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
-github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
+github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
+github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
+github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
+github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
+github.com/jeremija/gosubmit v0.2.8 h1:mmSITBz9JxVtu8eqbN+zmmwX7Ij2RidQxhcwRVI4wqA=
+github.com/jeremija/gosubmit v0.2.8/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI=
+github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
+github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM=
+github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM=
+github.com/muhlemmer/httpforwarded v0.1.0 h1:x4DLrzXdliq8mprgUMR0olDvHGkou5BJsK/vWUetyzY=
+github.com/muhlemmer/httpforwarded v0.1.0/go.mod h1:yo9czKedo2pdZhoXe+yDkGVbU0TJ0q9oQ90BVoDEtw0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
-github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
-github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=
-github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
+github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
+github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
+github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
+github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=
-github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
-github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
-github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
-github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
-github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
+github.com/zitadel/logging v0.6.2 h1:MW2kDDR0ieQynPZ0KIZPrh9ote2WkxfBif5QoARDQcU=
+github.com/zitadel/logging v0.6.2/go.mod h1:z6VWLWUkJpnNVDSLzrPSQSQyttysKZ6bCRongw0ROK4=
+github.com/zitadel/schema v1.3.1 h1:QT3kwiRIRXXLVAs6gCK/u044WmUVh6IlbLXUsn6yRQU=
+github.com/zitadel/schema v1.3.1/go.mod h1:071u7D2LQacy1HAN+YnMd/mx1qVE2isb0Mjeqg46xnU=
+go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
+go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
+go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc=
+go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
+go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
+go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 h1:anGSYQpPhQwXlwsu5wmfq0nWkCNaMEMUwAv13Y92hd8=
-golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
+golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
-golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 h1:e6HwijUxhDe+hPNjZQQn9bA5PW3vNmnN64U2ZW759Lk=
-golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
+golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
+golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
-golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c h1:HjRaKPaiWks0f5tA6ELVF7ZfqSppfPwOEEAvsrKUTO4=
-golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
-golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191206220618-eeba5f6aabab h1:FvshnhkKW+LO3HWHodML8kuVX8rnJTxKm9dFPuI68UM=
-golang.org/x/sys v0.0.0-20191206220618-eeba5f6aabab/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
+golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
-golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
-golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
+golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
-golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
-google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
-google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
-google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
-google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w=
-gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
-gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
-gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-rsc.io/quote/v3 v3.1.0 h1:9JKUTTIUgS6kzR9mK1YuGKv6Nl+DijDNIc0ghT58FaY=
-rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
-rsc.io/sampler v1.3.0 h1:7uVkIFmeBqHfdjD+gZwtXXI+RODJ2Wc4O7MPEh/QiW4=
-rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/testutil/gen/gen.go b/internal/testutil/gen/gen.go
new file mode 100644
index 0000000..3e44b7d
--- /dev/null
+++ b/internal/testutil/gen/gen.go
@@ -0,0 +1,58 @@
+// Package gen allows generating of example tokens and claims.
+//
+// go run ./internal/testutil/gen
+package main
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+var custom = map[string]any{
+ "foo": "Hello, World!",
+ "bar": struct {
+ Count int `json:"count,omitempty"`
+ Tags []string `json:"tags,omitempty"`
+ }{
+ Count: 22,
+ Tags: []string{"some", "tags"},
+ },
+}
+
+func main() {
+ enc := json.NewEncoder(os.Stdout)
+ enc.SetIndent("", " ")
+
+ accessToken, atClaims := tu.NewAccessTokenCustom(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration.AddDate(99, 0, 0), tu.ValidJWTID,
+ tu.ValidClientID, tu.ValidSkew, custom,
+ )
+ atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
+ if err != nil {
+ panic(err)
+ }
+
+ idToken, idClaims := tu.NewIDTokenCustom(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration.AddDate(99, 0, 0), tu.ValidAuthTime,
+ tu.ValidNonce, tu.ValidACR, tu.ValidAMR, tu.ValidClientID,
+ tu.ValidSkew, atHash, custom,
+ )
+
+ fmt.Println("access token claims:")
+ if err := enc.Encode(atClaims); err != nil {
+ panic(err)
+ }
+ fmt.Printf("access token:\n%s\n", accessToken)
+
+ fmt.Println("ID token claims:")
+ if err := enc.Encode(idClaims); err != nil {
+ panic(err)
+ }
+ fmt.Printf("ID token:\n%s\n", idToken)
+}
diff --git a/internal/testutil/token.go b/internal/testutil/token.go
new file mode 100644
index 0000000..72d08c5
--- /dev/null
+++ b/internal/testutil/token.go
@@ -0,0 +1,180 @@
+// Package testuril helps setting up required data for testing,
+// such as tokens, claims and verifiers.
+package testutil
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "time"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ jose "github.com/go-jose/go-jose/v4"
+ "github.com/muhlemmer/gu"
+)
+
+// KeySet implements oidc.Keys
+type KeySet struct{}
+
+// VerifySignature implments op.KeySet.
+func (KeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
+ if err = ctx.Err(); err != nil {
+ return nil, err
+ }
+
+ return jws.Verify(WebKey.Public())
+}
+
+// use a reproducible signing key
+const webkeyJSON = `{"kty":"RSA","kid":"1","alg":"PS512","n":"x6JoG8t2Li68JSwPwnh51TvHYFf3z72tQ3wmJG3VosU6MdJF0gSTCIwflOJ38OWE6hYtN1WAeyBy2CYdnXd1QZzkK_apGK4M7hsNA9jCTg8NOZjLPL0ww1jp7313Skla7mbm90uNdg4TUNp2n_r-sCYywI-9cfSlhzLSksxKK_BRdzy6xW20daAcI-mErQXIcvdYIguunJk_uTb8kJedsWMcQ4Mb57QujUok2Z2YabWyb9Fi1_StixXJvd_WEu93SHNMORB0u6ymnO3aZJdATLdhtcP-qsVicQhffpqVazmZQPf7K-7n4I5vJE4g9XXzZ2dSKSp3Ewe_nna_2kvbCw","e":"AQAB","d":"sl3F_QeF2O-CxQegMRYpbL6Tfd47GM6VDxXOkn_cACmNvFPudB4ILPvdf830cjTv06Lq1WS8fcZZNgygK0A_cNc3-pvRK67e-KMMtuIlgU7rdwmwlN1Iw1Ee-w6z1ZjC-PzR4iQMCW28DmKS2I-OnV4TvH7xOe7nMmvTPrvujV__YKfUxvAWXJG7_wtaJBGplezn5nNsKG2Ot9h0mhMdYUgGC36wLxo3Q5d4m79EXQYdhm89EfxogwvMmHRes5PNpHRuDZRHGAI4RZi2KvgmqF07e1Qdq4TqbQnY5pCYrdjqvEFFjGC6jTE-ak_b21FcSVy-9aZHyf04U4g5-cIUEQ","p":"7AaicFryJCHRekdSkx8tfPxaSiyEuN8jhP9cLqs4rLkIbrSHmanPhjnLe-Tlh3icQ8hPoy6WC8ktLwsrzbfGIh4U_zgAfvtD1Y_lZM-YSWZsxqlrGiI5do11iVzzoy4a1XdkgOjHQz9y6J-uoA9jY8ILG7VaEZQnaYwWZV3cspk","q":"2Ide9hlwthXJQJYqI0mibM5BiGBxJ4CafPmF1DYNXggBCczZ6ERGReNTGM_AEhy5mvLXUH6uBSOJlfHTYzx49C1GgIO3hEWVEGAKAytVRL6RfAkVSOXMQUp-HjXKpGg_Nx1SJxQf3rulbW8HXO4KqIlloyIXpPQSK7jB8A4hJUM","dp":"1nmc6F4sRNsaQHRJO_mL21RxM4_KtzfFThjCCoJ6iLHHUNnpkp_1PTKNjrLMRFM8JHgErfMqU-FmlqYfEtvZRq1xRQ39nWX0GT-eIwJljuVtGQVglqnc77bRxJXbqz-9EJdik6VzVM92Op7IDxiMp1zvvSkJhInNWqL6wvgNEZk","dq":"dlHizlAwiw90ndpwxD-khhhfLwqkSpW31br0KnYu78cn6hcKrCVC0UXbTp-XsU4JDmbMyauvpBc7Q7iVbpDI94UWFXvkeF8diYkxb3HqclpAXasI-oC4EKWILTHvvc9JW_Clx7zzfV7Ekvws5dcd8-LAq1gh232TwFiBgY_3BMk","qi":"E1k_9W3odXgcmIP2PCJztE7hB7jeuAL1ElAY88VJBBPY670uwOEjKL2VfQuz9q9IjzLAvcgf7vS9blw2RHP_XqHqSOlJWGwvMQTF0Q8zLknCgKt8q7HQQNWIJcBZ8qdUVn02-qf4E3tgZ3JHaHNs8imA_L-__WoUmzC4z5jH_lM"}`
+
+const SignatureAlgorithm = jose.RS256
+
+var (
+ WebKey jose.JSONWebKey
+ Signer jose.Signer
+)
+
+func init() {
+ err := json.Unmarshal([]byte(webkeyJSON), &WebKey)
+ if err != nil {
+ panic(err)
+ }
+ Signer, err = jose.NewSigner(jose.SigningKey{Algorithm: SignatureAlgorithm, Key: WebKey}, nil)
+ if err != nil {
+ panic(err)
+ }
+}
+
+type JWTProfileKeyStorage struct{}
+
+func (JWTProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID string, clientID string) (*jose.JSONWebKey, error) {
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+
+ return gu.Ptr(WebKey.Public()), nil
+}
+
+func signEncodeTokenClaims(claims any) string {
+ payload, err := json.Marshal(claims)
+ if err != nil {
+ panic(err)
+ }
+ object, err := Signer.Sign(payload)
+ if err != nil {
+ panic(err)
+ }
+ token, err := object.CompactSerialize()
+ if err != nil {
+ panic(err)
+ }
+ return token
+}
+
+func claimsMap(claims any) map[string]any {
+ data, err := json.Marshal(claims)
+ if err != nil {
+ panic(err)
+ }
+ dst := make(map[string]any)
+ if err = json.Unmarshal(data, &dst); err != nil {
+ panic(err)
+ }
+ return dst
+}
+
+func NewIDTokenCustom(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration, atHash string, custom map[string]any) (string, *oidc.IDTokenClaims) {
+ claims := oidc.NewIDTokenClaims(issuer, subject, audience, expiration, authTime, nonce, acr, amr, clientID, skew)
+ claims.AccessTokenHash = atHash
+ claims.Claims = custom
+ token := signEncodeTokenClaims(claims)
+
+ // set this so that assertion in tests will work
+ claims.SignatureAlg = SignatureAlgorithm
+ claims.Claims = claimsMap(claims)
+ return token, claims
+}
+
+// NewIDToken creates a new IDTokenClaims with passed data and returns a signed token and claims.
+func NewIDToken(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration, atHash string) (string, *oidc.IDTokenClaims) {
+ return NewIDTokenCustom(issuer, subject, audience, expiration, authTime, nonce, acr, amr, clientID, skew, atHash, nil)
+}
+
+func NewAccessTokenCustom(issuer, subject string, audience []string, expiration time.Time, jwtid, clientID string, skew time.Duration, custom map[string]any) (string, *oidc.AccessTokenClaims) {
+ claims := oidc.NewAccessTokenClaims(issuer, subject, audience, expiration, jwtid, clientID, skew)
+ claims.Claims = custom
+ token := signEncodeTokenClaims(claims)
+
+ // set this so that assertion in tests will work
+ claims.SignatureAlg = SignatureAlgorithm
+ claims.Claims = claimsMap(claims)
+ return token, claims
+}
+
+// NewAcccessToken creates a new AccessTokenClaims with passed data and returns a signed token and claims.
+func NewAccessToken(issuer, subject string, audience []string, expiration time.Time, jwtid, clientID string, skew time.Duration) (string, *oidc.AccessTokenClaims) {
+ return NewAccessTokenCustom(issuer, subject, audience, expiration, jwtid, clientID, skew, nil)
+}
+
+func NewJWTProfileAssertion(issuer, clientID string, audience []string, issuedAt, expiration time.Time) (string, *oidc.JWTTokenRequest) {
+ req := &oidc.JWTTokenRequest{
+ Issuer: issuer,
+ Subject: clientID,
+ Audience: audience,
+ ExpiresAt: oidc.FromTime(expiration),
+ IssuedAt: oidc.FromTime(issuedAt),
+ }
+ // make sure the private claim map is set correctly
+ data, err := json.Marshal(req)
+ if err != nil {
+ panic(err)
+ }
+ if err = json.Unmarshal(data, req); err != nil {
+ panic(err)
+ }
+ return signEncodeTokenClaims(req), req
+}
+
+const InvalidSignatureToken = `eyJhbGciOiJQUzUxMiJ9.eyJpc3MiOiJsb2NhbC5jb20iLCJzdWIiOiJ0aW1AbG9jYWwuY29tIiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImV4cCI6MTY3Nzg0MDQzMSwiaWF0IjoxNjc3ODQwMzcwLCJhdXRoX3RpbWUiOjE2Nzc4NDAzMTAsIm5vbmNlIjoiMTIzNDUiLCJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF6cCI6IjU1NTY2NiJ9.DtZmvVkuE4Hw48ijBMhRJbxEWCr_WEYuPQBMY73J9TP6MmfeNFkjVJf4nh4omjB9gVLnQ-xhEkNOe62FS5P0BB2VOxPuHZUj34dNspCgG3h98fGxyiMb5vlIYAHDF9T-w_LntlYItohv63MmdYR-hPpAqjXE7KOfErf-wUDGE9R3bfiQ4HpTdyFJB1nsToYrZ9lhP2mzjTCTs58ckZfQ28DFHn_lfHWpR4rJBgvLx7IH4rMrUayr09Ap-PxQLbv0lYMtmgG1z3JK8MXnuYR0UJdZnEIezOzUTlThhCXB-nvuAXYjYxZZTR0FtlgZUHhIpYK0V2abf_Q_Or36akNCUg`
+
+// These variables always result in a valid token
+var (
+ ValidIssuer = "local.com"
+ ValidSubject = "tim@local.com"
+ ValidAudience = []string{"unit", "test"}
+ ValidAuthTime = time.Now().Add(-time.Minute) // authtime is always 1 minute in the past
+ ValidExpiration = ValidAuthTime.Add(2 * time.Minute) // token is always 1 more minute available
+ ValidJWTID = "9876"
+ ValidNonce = "12345"
+ ValidACR = "something"
+ ValidAMR = []string{"foo", "bar"}
+ ValidClientID = "555666"
+ ValidSkew = time.Second
+)
+
+// ValidIDToken returns a token and claims that are in the token.
+// It uses the Valid* global variables and the token will always
+// pass verification.
+func ValidIDToken() (string, *oidc.IDTokenClaims) {
+ return NewIDToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidAuthTime, ValidNonce, ValidACR, ValidAMR, ValidClientID, ValidSkew, "")
+}
+
+// ValidAccessToken returns a token and claims that are in the token.
+// It uses the Valid* global variables and the token always passes
+// verification within the same test run.
+func ValidAccessToken() (string, *oidc.AccessTokenClaims) {
+ return NewAccessToken(ValidIssuer, ValidSubject, ValidAudience, ValidExpiration, ValidJWTID, ValidClientID, ValidSkew)
+}
+
+func ValidJWTProfileAssertion() (string, *oidc.JWTTokenRequest) {
+ return NewJWTProfileAssertion(ValidClientID, ValidClientID, []string{ValidIssuer}, time.Now(), ValidExpiration)
+}
+
+// ACRVerify is a oidc.ACRVerifier func.
+func ACRVerify(acr string) error {
+ if acr != ValidACR {
+ return errors.New("invalid acr")
+ }
+ return nil
+}
diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go
deleted file mode 100644
index c7328e4..0000000
--- a/pkg/cli/cli.go
+++ /dev/null
@@ -1,117 +0,0 @@
-package cli
-
-import (
- "context"
- "fmt"
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/rp"
- "github.com/caos/oidc/pkg/utils"
- "github.com/google/uuid"
- "github.com/sirupsen/logrus"
- "log"
- "net/http"
- "strings"
- "time"
-)
-
-func CodeFlow(rpc *rp.Config, key []byte, callbackPath string, port string) *oidc.Tokens {
- cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
- provider, err := rp.NewDefaultRP(rpc, rp.WithCookieHandler(cookieHandler)) //rp.WithPKCE(cookieHandler)) //,
- if err != nil {
- logrus.Fatalf("error creating provider %s", err.Error())
- }
-
- return codeFlow(provider, callbackPath, port)
-}
-
-func TokenForClient(rpc *rp.Config, key []byte, token *oidc.Tokens) *http.Client {
- cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
- provider, err := rp.NewDefaultRP(rpc, rp.WithCookieHandler(cookieHandler)) //rp.WithPKCE(cookieHandler)) //,
- if err != nil {
- logrus.Fatalf("error creating provider %s", err.Error())
- }
-
- return provider.Client(context.Background(), token.Token)
-}
-
-func CodeFlowForClient(rpc *rp.Config, key []byte, callbackPath string, port string) *http.Client {
- cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
- provider, err := rp.NewDefaultRP(rpc, rp.WithCookieHandler(cookieHandler)) //rp.WithPKCE(cookieHandler)) //,
- if err != nil {
- logrus.Fatalf("error creating provider %s", err.Error())
- }
- token := codeFlow(provider, callbackPath, port)
-
- return provider.Client(context.Background(), token.Token)
-}
-
-func codeFlow(provider rp.DelegationTokenExchangeRP, callbackPath string, port string) *oidc.Tokens {
- loginPath := "/login"
- portStr := port
- if !strings.HasPrefix(port, ":") {
- portStr = strings.Join([]string{":", portStr}, "")
- }
-
- getToken, setToken := getAndSetTokens()
-
- state := uuid.New().String()
- http.Handle(loginPath, provider.AuthURLHandler(state))
-
- marshal := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) {
- setToken(w, tokens)
- }
- http.Handle(callbackPath, provider.CodeExchangeHandler(marshal))
-
- // start http-server
- stopHttpServer := startHttpServer(portStr)
-
- // open browser in different window
- utils.OpenBrowser(strings.Join([]string{"http://localhost", portStr, loginPath}, ""))
-
- // wait until user is logged into browser
- ret := getToken()
-
- // stop http-server as no callback is needed anymore
- stopHttpServer()
-
- // return tokens
- return ret
-}
-
-func startHttpServer(port string) func() {
- srv := &http.Server{Addr: port}
- go func() {
-
- // always returns error. ErrServerClosed on graceful close
- if err := srv.ListenAndServe(); err != http.ErrServerClosed {
- // unexpected error. port in use?
- log.Fatalf("ListenAndServe(): %v", err)
- }
- }()
-
- return func() {
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
-
- if err := srv.Shutdown(ctx); err != nil {
- log.Fatalf("Shutdown(): %v", err)
- }
- }
-}
-
-func getAndSetTokens() (func() *oidc.Tokens, func(w http.ResponseWriter, tokens *oidc.Tokens)) {
- marshalChan := make(chan *oidc.Tokens)
-
- getToken := func() *oidc.Tokens {
- return <-marshalChan
- }
- setToken := func(w http.ResponseWriter, tokens *oidc.Tokens) {
- marshalChan <- tokens
-
- msg := "Success!
"
- msg = msg + "You are authenticated and can now return to the CLI.
"
- fmt.Fprintf(w, msg)
- }
-
- return getToken, setToken
-}
diff --git a/pkg/client/client.go b/pkg/client/client.go
new file mode 100644
index 0000000..2e1f536
--- /dev/null
+++ b/pkg/client/client.go
@@ -0,0 +1,312 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/go-jose/go-jose/v4"
+ "github.com/zitadel/logging"
+ "go.opentelemetry.io/otel"
+ "golang.org/x/oauth2"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+var (
+ Encoder = httphelper.Encoder(oidc.NewEncoder())
+ Tracer = otel.Tracer("github.com/zitadel/oidc/pkg/client")
+)
+
+// Discover calls the discovery endpoint of the provided issuer and returns its configuration
+// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
+func Discover(ctx context.Context, issuer string, httpClient *http.Client, wellKnownUrl ...string) (*oidc.DiscoveryConfiguration, error) {
+ ctx, span := Tracer.Start(ctx, "Discover")
+ defer span.End()
+
+ wellKnown := strings.TrimSuffix(issuer, "/") + oidc.DiscoveryEndpoint
+ if len(wellKnownUrl) == 1 && wellKnownUrl[0] != "" {
+ wellKnown = wellKnownUrl[0]
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
+ if err != nil {
+ return nil, err
+ }
+ discoveryConfig := new(oidc.DiscoveryConfiguration)
+ err = httphelper.HttpRequest(httpClient, req, &discoveryConfig)
+ if err != nil {
+ return nil, errors.Join(oidc.ErrDiscoveryFailed, err)
+ }
+ if logger, ok := logging.FromContext(ctx); ok {
+ logger.Debug("discover", "config", discoveryConfig)
+ }
+
+ if discoveryConfig.Issuer != issuer {
+ return nil, oidc.ErrIssuerInvalid
+ }
+ return discoveryConfig, nil
+}
+
+type TokenEndpointCaller interface {
+ TokenEndpoint() string
+ HttpClient() *http.Client
+}
+
+func CallTokenEndpoint(ctx context.Context, request any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
+ return callTokenEndpoint(ctx, request, nil, caller)
+}
+
+func callTokenEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (newToken *oauth2.Token, err error) {
+ ctx, span := Tracer.Start(ctx, "callTokenEndpoint")
+ defer span.End()
+
+ req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
+ if err != nil {
+ return nil, err
+ }
+ tokenRes := new(oidc.AccessTokenResponse)
+ if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
+ return nil, err
+ }
+ token := &oauth2.Token{
+ AccessToken: tokenRes.AccessToken,
+ TokenType: tokenRes.TokenType,
+ RefreshToken: tokenRes.RefreshToken,
+ Expiry: time.Now().UTC().Add(time.Duration(tokenRes.ExpiresIn) * time.Second),
+ }
+ if tokenRes.IDToken != "" {
+ token = token.WithExtra(map[string]any{
+ "id_token": tokenRes.IDToken,
+ })
+ }
+ return token, nil
+}
+
+type EndSessionCaller interface {
+ GetEndSessionEndpoint() string
+ HttpClient() *http.Client
+}
+
+func CallEndSessionEndpoint(ctx context.Context, request any, authFn any, caller EndSessionCaller) (*url.URL, error) {
+ ctx, span := Tracer.Start(ctx, "CallEndSessionEndpoint")
+ defer span.End()
+
+ endpoint := caller.GetEndSessionEndpoint()
+ if endpoint == "" {
+ return nil, fmt.Errorf("end session %w", ErrEndpointNotSet)
+ }
+
+ req, err := httphelper.FormRequest(ctx, endpoint, request, Encoder, authFn)
+ if err != nil {
+ return nil, err
+ }
+ client := caller.HttpClient()
+ client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
+ return http.ErrUseLastResponse
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 400 {
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+ return nil, fmt.Errorf("EndSession failure, %d status code: %s", resp.StatusCode, string(body))
+ }
+ location, err := resp.Location()
+ if err != nil {
+ if errors.Is(err, http.ErrNoLocation) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ return location, nil
+}
+
+type RevokeCaller interface {
+ GetRevokeEndpoint() string
+ HttpClient() *http.Client
+}
+
+type RevokeRequest struct {
+ Token string `schema:"token"`
+ TokenTypeHint string `schema:"token_type_hint"`
+ ClientID string `schema:"client_id"`
+ ClientSecret string `schema:"client_secret"`
+}
+
+func CallRevokeEndpoint(ctx context.Context, request any, authFn any, caller RevokeCaller) error {
+ ctx, span := Tracer.Start(ctx, "CallRevokeEndpoint")
+ defer span.End()
+
+ endpoint := caller.GetRevokeEndpoint()
+ if endpoint == "" {
+ return fmt.Errorf("revoke %w", ErrEndpointNotSet)
+ }
+
+ req, err := httphelper.FormRequest(ctx, endpoint, request, Encoder, authFn)
+ if err != nil {
+ return err
+ }
+ client := caller.HttpClient()
+ client.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
+ return http.ErrUseLastResponse
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ // According to RFC7009 in section 2.2:
+ // "The content of the response body is ignored by the client as all
+ // necessary information is conveyed in the response code."
+ if resp.StatusCode != 200 {
+ body, err := io.ReadAll(resp.Body)
+ if err == nil {
+ return fmt.Errorf("revoke returned status %d and text: %s", resp.StatusCode, string(body))
+ } else {
+ return fmt.Errorf("revoke returned status %d", resp.StatusCode)
+ }
+ }
+ return nil
+}
+
+func CallTokenExchangeEndpoint(ctx context.Context, request any, authFn any, caller TokenEndpointCaller) (resp *oidc.TokenExchangeResponse, err error) {
+ ctx, span := Tracer.Start(ctx, "CallTokenExchangeEndpoint")
+ defer span.End()
+
+ req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, authFn)
+ if err != nil {
+ return nil, err
+ }
+ tokenRes := new(oidc.TokenExchangeResponse)
+ if err := httphelper.HttpRequest(caller.HttpClient(), req, &tokenRes); err != nil {
+ return nil, err
+ }
+ return tokenRes, nil
+}
+
+func NewSignerFromPrivateKeyByte(key []byte, keyID string) (jose.Signer, error) {
+ privateKey, algorithm, err := crypto.BytesToPrivateKey(key)
+ if err != nil {
+ return nil, err
+ }
+ signingKey := jose.SigningKey{
+ Algorithm: algorithm,
+ Key: &jose.JSONWebKey{Key: privateKey, KeyID: keyID},
+ }
+ return jose.NewSigner(signingKey, &jose.SignerOptions{})
+}
+
+func SignedJWTProfileAssertion(clientID string, audience []string, expiration time.Duration, signer jose.Signer) (string, error) {
+ iat := time.Now()
+ exp := iat.Add(expiration)
+ return crypto.Sign(&oidc.JWTTokenRequest{
+ Issuer: clientID,
+ Subject: clientID,
+ Audience: audience,
+ ExpiresAt: oidc.FromTime(exp),
+ IssuedAt: oidc.FromTime(iat),
+ }, signer)
+}
+
+type DeviceAuthorizationCaller interface {
+ GetDeviceAuthorizationEndpoint() string
+ HttpClient() *http.Client
+}
+
+func CallDeviceAuthorizationEndpoint(ctx context.Context, request *oidc.ClientCredentialsRequest, caller DeviceAuthorizationCaller, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
+ ctx, span := Tracer.Start(ctx, "CallDeviceAuthorizationEndpoint")
+ defer span.End()
+
+ endpoint := caller.GetDeviceAuthorizationEndpoint()
+ if endpoint == "" {
+ return nil, fmt.Errorf("device authorization %w", ErrEndpointNotSet)
+ }
+
+ req, err := httphelper.FormRequest(ctx, endpoint, request, Encoder, authFn)
+ if err != nil {
+ return nil, err
+ }
+ if request.ClientSecret != "" {
+ req.SetBasicAuth(request.ClientID, request.ClientSecret)
+ }
+
+ resp := new(oidc.DeviceAuthorizationResponse)
+ if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
+
+type DeviceAccessTokenRequest struct {
+ *oidc.ClientCredentialsRequest
+ oidc.DeviceAccessTokenRequest
+}
+
+func CallDeviceAccessTokenEndpoint(ctx context.Context, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
+ ctx, span := Tracer.Start(ctx, "CallDeviceAccessTokenEndpoint")
+ defer span.End()
+
+ req, err := httphelper.FormRequest(ctx, caller.TokenEndpoint(), request, Encoder, nil)
+ if err != nil {
+ return nil, err
+ }
+ if request.ClientSecret != "" {
+ req.SetBasicAuth(request.ClientID, request.ClientSecret)
+ }
+
+ resp := new(oidc.AccessTokenResponse)
+ if err := httphelper.HttpRequest(caller.HttpClient(), req, &resp); err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
+
+func PollDeviceAccessTokenEndpoint(ctx context.Context, interval time.Duration, request *DeviceAccessTokenRequest, caller TokenEndpointCaller) (*oidc.AccessTokenResponse, error) {
+ ctx, span := Tracer.Start(ctx, "PollDeviceAccessTokenEndpoint")
+ defer span.End()
+
+ for {
+ timer := time.After(interval)
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-timer:
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, interval)
+ defer cancel()
+
+ resp, err := CallDeviceAccessTokenEndpoint(ctx, request, caller)
+ if err == nil {
+ return resp, nil
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ interval += 5 * time.Second
+ }
+ var target *oidc.Error
+ if !errors.As(err, &target) {
+ return nil, err
+ }
+ switch target.ErrorType {
+ case oidc.AuthorizationPending:
+ continue
+ case oidc.SlowDown:
+ interval += 5 * time.Second
+ continue
+ default:
+ return nil, err
+ }
+ }
+}
diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go
new file mode 100644
index 0000000..9e21e8e
--- /dev/null
+++ b/pkg/client/client_test.go
@@ -0,0 +1,59 @@
+package client
+
+import (
+ "context"
+ "net/http"
+ "testing"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestDiscover(t *testing.T) {
+ type wantFields struct {
+ UILocalesSupported bool
+ }
+
+ type args struct {
+ issuer string
+ wellKnownUrl []string
+ }
+ tests := []struct {
+ name string
+ args args
+ wantFields *wantFields
+ wantErr error
+ }{
+ {
+ name: "spotify", // https://github.com/zitadel/oidc/issues/406
+ args: args{
+ issuer: "https://accounts.spotify.com",
+ },
+ wantFields: &wantFields{
+ UILocalesSupported: true,
+ },
+ wantErr: nil,
+ },
+ {
+ name: "discovery failed",
+ args: args{
+ issuer: "https://example.com",
+ },
+ wantErr: oidc.ErrDiscoveryFailed,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := Discover(context.Background(), tt.args.issuer, http.DefaultClient, tt.args.wellKnownUrl...)
+ require.ErrorIs(t, err, tt.wantErr)
+ if tt.wantFields == nil {
+ return
+ }
+ assert.Equal(t, tt.args.issuer, got.Issuer)
+ if tt.wantFields.UILocalesSupported {
+ assert.NotEmpty(t, got.UILocalesSupported)
+ }
+ })
+ }
+}
diff --git a/pkg/client/errors.go b/pkg/client/errors.go
new file mode 100644
index 0000000..47210e5
--- /dev/null
+++ b/pkg/client/errors.go
@@ -0,0 +1,5 @@
+package client
+
+import "errors"
+
+var ErrEndpointNotSet = errors.New("endpoint not set")
diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go
new file mode 100644
index 0000000..86a9ab7
--- /dev/null
+++ b/pkg/client/integration_test.go
@@ -0,0 +1,594 @@
+package client_test
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "math/rand"
+ "net/http"
+ "net/http/cookiejar"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "os/signal"
+ "strconv"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/jeremija/gosubmit"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/oauth2"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/example/server/exampleop"
+ "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rs"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/tokenexchange"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+var Logger = slog.New(
+ slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
+ AddSource: true,
+ Level: slog.LevelDebug,
+ }),
+)
+
+var CTX context.Context
+
+func TestMain(m *testing.M) {
+ os.Exit(func() int {
+ ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGINT)
+ defer cancel()
+ CTX, cancel = context.WithTimeout(ctx, time.Minute)
+ defer cancel()
+ return m.Run()
+ }())
+}
+
+func TestRelyingPartySession(t *testing.T) {
+ for _, wrapServer := range []bool{false, true} {
+ t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
+ testRelyingPartySession(t, wrapServer)
+ })
+ }
+}
+
+func testRelyingPartySession(t *testing.T, wrapServer bool) {
+ t.Log("------- start example OP ------")
+ targetURL := "http://local-site"
+ exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
+ var dh deferredHandler
+ opServer := httptest.NewServer(&dh)
+ defer opServer.Close()
+ t.Logf("auth server at %s", opServer.URL)
+ dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
+
+ seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
+ clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
+
+ t.Log("------- run authorization code flow ------")
+ provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret")
+
+ t.Log("------- refresh tokens ------")
+
+ newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
+ require.NoError(t, err, "refresh token")
+ assert.NotNil(t, newTokens, "access token")
+ t.Logf("new access token %s", newTokens.AccessToken)
+ t.Logf("new refresh token %s", newTokens.RefreshToken)
+ t.Logf("new token type %s", newTokens.TokenType)
+ t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339))
+ require.NotEmpty(t, newTokens.AccessToken, "new accessToken")
+ assert.NotEmpty(t, newTokens.IDToken, "new idToken")
+ assert.NotNil(t, newTokens.IDTokenClaims)
+ assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject)
+
+ t.Log("------ end session (logout) ------")
+
+ newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
+ require.NoError(t, err, "logout")
+ if newLoc != nil {
+ t.Logf("redirect to %s", newLoc)
+ } else {
+ t.Logf("no redirect")
+ }
+
+ t.Log("------ attempt refresh again (should fail) ------")
+ t.Log("trying original refresh token", tokens.RefreshToken)
+ _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
+ assert.Errorf(t, err, "refresh with original")
+ if newTokens.RefreshToken != "" {
+ t.Log("trying replacement refresh token", newTokens.RefreshToken)
+ _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "")
+ assert.Errorf(t, err, "refresh with replacement")
+ }
+}
+
+func TestRelyingPartyWithSigningAlgsFromDiscovery(t *testing.T) {
+ targetURL := "http://local-site"
+ localURL, err := url.Parse(targetURL + "/login?requestID=1234")
+ require.NoError(t, err, "local url")
+
+ t.Log("------- start example OP ------")
+ seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
+ clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
+ clientSecret := "secret"
+ client := storage.WebClient(clientID, clientSecret, targetURL)
+ storage.RegisterClients(client)
+ exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
+ var dh deferredHandler
+ opServer := httptest.NewServer(&dh)
+ defer opServer.Close()
+ dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, true)
+
+ t.Log("------- create RP ------")
+ provider, err := rp.NewRelyingPartyOIDC(
+ CTX,
+ opServer.URL,
+ clientID,
+ clientSecret,
+ targetURL,
+ []string{"openid"},
+ rp.WithSigningAlgsFromDiscovery(),
+ )
+ require.NoError(t, err, "new rp")
+
+ t.Log("------- run authorization code flow ------")
+ jar, err := cookiejar.New(nil)
+ require.NoError(t, err, "create cookie jar")
+ httpClient := &http.Client{
+ Timeout: time.Second * 5,
+ CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ Jar: jar,
+ }
+ state := "state-" + strconv.FormatInt(seed.Int63(), 25)
+ capturedW := httptest.NewRecorder()
+ get := httptest.NewRequest("GET", localURL.String(), nil)
+ rp.AuthURLHandler(func() string { return state }, provider,
+ rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
+ rp.WithURLParam("custom", "param"),
+ )(capturedW, get)
+ defer func() {
+ if t.Failed() {
+ t.Log("response body (redirect from RP to OP)", capturedW.Body.String())
+ }
+ }()
+ resp := capturedW.Result()
+ startAuthURL, err := resp.Location()
+ require.NoError(t, err, "get redirect")
+ loginPageURL := getRedirect(t, "get redirect to login page", httpClient, startAuthURL)
+ form := getForm(t, "get login form", httpClient, loginPageURL)
+ defer func() {
+ if t.Failed() {
+ t.Logf("login form (unfilled): %s", string(form))
+ }
+ }()
+ postLoginRedirectURL := fillForm(t, "fill login form", httpClient, form, loginPageURL,
+ gosubmit.Set("username", "test-user@local-site"),
+ gosubmit.Set("password", "verysecure"),
+ )
+ codeBearingURL := getRedirect(t, "get redirect with code", httpClient, postLoginRedirectURL)
+ capturedW = httptest.NewRecorder()
+ get = httptest.NewRequest("GET", codeBearingURL.String(), nil)
+ var idToken string
+ redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
+ idToken = newTokens.IDToken
+ http.Redirect(w, r, targetURL, http.StatusFound)
+ }
+ rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider)(capturedW, get)
+ defer func() {
+ if t.Failed() {
+ t.Log("token exchange response body", capturedW.Body.String())
+ require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
+ }
+ }()
+
+ t.Log("------- verify id token ------")
+ _, err = rp.VerifyIDToken[*oidc.IDTokenClaims](CTX, idToken, provider.IDTokenVerifier())
+ require.NoError(t, err, "verify id token")
+}
+
+func TestResourceServerTokenExchange(t *testing.T) {
+ for _, wrapServer := range []bool{false, true} {
+ t.Run(fmt.Sprint("wrapServer ", wrapServer), func(t *testing.T) {
+ testResourceServerTokenExchange(t, wrapServer)
+ })
+ }
+}
+
+func testResourceServerTokenExchange(t *testing.T, wrapServer bool) {
+ t.Log("------- start example OP ------")
+ targetURL := "http://local-site"
+ exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
+ var dh deferredHandler
+ opServer := httptest.NewServer(&dh)
+ defer opServer.Close()
+ t.Logf("auth server at %s", opServer.URL)
+ dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, wrapServer)
+
+ seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
+ clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
+ clientSecret := "secret"
+
+ t.Log("------- run authorization code flow ------")
+ provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret)
+
+ resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret)
+ require.NoError(t, err, "new resource server")
+
+ t.Log("------- exchage refresh tokens (impersonation) ------")
+
+ tokenExchangeResponse, err := tokenexchange.ExchangeToken(
+ CTX,
+ resourceServer,
+ tokens.RefreshToken,
+ oidc.RefreshTokenType,
+ "",
+ "",
+ []string{},
+ []string{},
+ []string{"profile", "custom_scope:impersonate:id2"},
+ oidc.RefreshTokenType,
+ )
+ require.NoError(t, err, "refresh token")
+ require.NotNil(t, tokenExchangeResponse, "token exchange response")
+ assert.Equal(t, tokenExchangeResponse.IssuedTokenType, oidc.RefreshTokenType)
+ assert.NotEmpty(t, tokenExchangeResponse.AccessToken, "access token")
+ assert.NotEmpty(t, tokenExchangeResponse.RefreshToken, "refresh token")
+ assert.Equal(t, []string(tokenExchangeResponse.Scopes), []string{"profile", "custom_scope:impersonate:id2"})
+
+ t.Log("------ end session (logout) ------")
+
+ newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "")
+ require.NoError(t, err, "logout")
+ if newLoc != nil {
+ t.Logf("redirect to %s", newLoc)
+ } else {
+ t.Logf("no redirect")
+ }
+
+ t.Log("------- attempt exchage again (should fail) ------")
+
+ tokenExchangeResponse, err = tokenexchange.ExchangeToken(
+ CTX,
+ resourceServer,
+ tokens.RefreshToken,
+ oidc.RefreshTokenType,
+ "",
+ "",
+ []string{},
+ []string{},
+ []string{"profile", "custom_scope:impersonate:id2"},
+ oidc.RefreshTokenType,
+ )
+ require.Error(t, err, "refresh token")
+ assert.Contains(t, err.Error(), "subject_token is invalid")
+ require.Nil(t, tokenExchangeResponse, "token exchange response")
+}
+
+func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) {
+ targetURL := "http://local-site"
+ localURL, err := url.Parse(targetURL + "/login?requestID=1234")
+ require.NoError(t, err, "local url")
+
+ client := storage.WebClient(clientID, clientSecret, targetURL)
+ storage.RegisterClients(client)
+
+ jar, err := cookiejar.New(nil)
+ require.NoError(t, err, "create cookie jar")
+ httpClient := &http.Client{
+ Timeout: time.Second * 5,
+ CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ Jar: jar,
+ }
+
+ t.Log("------- create RP ------")
+ key := []byte("test1234test1234")
+ cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
+ provider, err = rp.NewRelyingPartyOIDC(
+ CTX,
+ opServer.URL,
+ clientID,
+ clientSecret,
+ targetURL,
+ []string{"openid", "email", "profile", "offline_access"},
+ rp.WithPKCE(cookieHandler),
+ rp.WithAuthStyle(oauth2.AuthStyleInHeader),
+ rp.WithVerifierOpts(
+ rp.WithIssuedAtOffset(5*time.Second),
+ rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"),
+ ),
+ )
+ require.NoError(t, err, "new rp")
+
+ t.Log("------- get redirect from local client (rp) to OP ------")
+ seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
+ state := "state-" + strconv.FormatInt(seed.Int63(), 25)
+ capturedW := httptest.NewRecorder()
+ get := httptest.NewRequest("GET", localURL.String(), nil)
+ rp.AuthURLHandler(func() string { return state }, provider,
+ rp.WithPromptURLParam("Hello, World!", "Goodbye, World!"),
+ rp.WithURLParam("custom", "param"),
+ )(capturedW, get)
+
+ defer func() {
+ if t.Failed() {
+ t.Log("response body (redirect from RP to OP)", capturedW.Body.String())
+ }
+ }()
+ require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
+ require.Less(t, capturedW.Code, 400, "captured response code")
+ require.Contains(t, capturedW.Body.String(), `prompt=Hello%2C+World%21+Goodbye%2C+World%21`)
+ require.Contains(t, capturedW.Body.String(), `custom=param`)
+
+ //nolint:bodyclose
+ resp := capturedW.Result()
+ jar.SetCookies(localURL, resp.Cookies())
+
+ startAuthURL, err := resp.Location()
+ require.NoError(t, err, "get redirect")
+ assert.NotEmpty(t, startAuthURL, "login url")
+ t.Log("Starting auth at", startAuthURL)
+
+ t.Log("------- get redirect to OP to login page ------")
+ loginPageURL := getRedirect(t, "get redirect to login page", httpClient, startAuthURL)
+ t.Log("login page URL", loginPageURL)
+
+ t.Log("------- get login form ------")
+ form := getForm(t, "get login form", httpClient, loginPageURL)
+ t.Log("login form (unfilled)", string(form))
+ defer func() {
+ if t.Failed() {
+ t.Logf("login form (unfilled): %s", string(form))
+ }
+ }()
+
+ t.Log("------- post to login form, get redirect to OP ------")
+ postLoginRedirectURL := fillForm(t, "fill login form", httpClient, form, loginPageURL,
+ gosubmit.Set("username", "test-user@local-site"),
+ gosubmit.Set("password", "verysecure"))
+ t.Logf("Get redirect from %s", postLoginRedirectURL)
+
+ t.Log("------- redirect from OP back to RP ------")
+ codeBearingURL := getRedirect(t, "get redirect with code", httpClient, postLoginRedirectURL)
+ t.Logf("Redirect with code %s", codeBearingURL)
+
+ t.Log("------- exchange code for tokens ------")
+ capturedW = httptest.NewRecorder()
+ get = httptest.NewRequest("GET", codeBearingURL.String(), nil)
+ for _, cookie := range jar.Cookies(codeBearingURL) {
+ get.Header["Cookie"] = append(get.Header["Cookie"], cookie.String())
+ t.Logf("setting cookie %s", cookie)
+ }
+
+ var email string
+ redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) {
+ tokens = newTokens
+ require.NotNil(t, tokens, "tokens")
+ require.NotNil(t, info, "info")
+ t.Log("access token", tokens.AccessToken)
+ t.Log("refresh token", tokens.RefreshToken)
+ t.Log("id token", tokens.IDToken)
+ t.Log("email", info.Email)
+
+ email = info.Email
+ http.Redirect(w, r, targetURL, 302)
+ }
+ rp.CodeExchangeHandler(rp.UserinfoCallback(redirect), provider, rp.WithURLParam("custom", "param"))(capturedW, get)
+
+ defer func() {
+ if t.Failed() {
+ t.Log("token exchange response body", capturedW.Body.String())
+ require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
+ }
+ }()
+ require.Less(t, capturedW.Code, 400, "token exchange response code")
+ // TODO: how to check the custom header was sent to the server?
+
+ //nolint:bodyclose
+ resp = capturedW.Result()
+
+ authorizedURL, err := resp.Location()
+ require.NoError(t, err, "get fully-authorizied redirect location")
+ require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location")
+
+ require.NotEmpty(t, tokens.IDToken, "id token")
+ assert.NotEmpty(t, tokens.RefreshToken, "refresh token")
+ assert.NotEmpty(t, tokens.AccessToken, "access token")
+ assert.NotEmpty(t, email, "email")
+
+ return provider, tokens
+}
+
+func TestClientCredentials(t *testing.T) {
+ targetURL := "http://local-site"
+ exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
+ var dh deferredHandler
+ opServer := httptest.NewServer(&dh)
+ defer opServer.Close()
+ t.Logf("auth server at %s", opServer.URL)
+ dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, true)
+
+ provider, err := rp.NewRelyingPartyOIDC(
+ CTX,
+ opServer.URL,
+ "sid1",
+ "verysecret",
+ targetURL,
+ []string{"openid"},
+ )
+ require.NoError(t, err, "new rp")
+
+ token, err := rp.ClientCredentials(CTX, provider, nil)
+ require.NoError(t, err, "ClientCredentials call")
+ require.NotNil(t, token)
+ assert.NotEmpty(t, token.AccessToken)
+}
+
+func TestErrorFromPromptNone(t *testing.T) {
+ jar, err := cookiejar.New(nil)
+ require.NoError(t, err, "create cookie jar")
+ httpClient := &http.Client{
+ Timeout: time.Second * 5,
+ CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ Jar: jar,
+ }
+
+ t.Log("------- start example OP ------")
+ targetURL := "http://local-site"
+ exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
+ var dh deferredHandler
+ opServer := httptest.NewServer(&dh)
+ defer opServer.Close()
+ t.Logf("auth server at %s", opServer.URL)
+ dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, false, op.WithHttpInterceptors(
+ func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ t.Logf("request to %s", r.URL)
+ next.ServeHTTP(w, r)
+ })
+ },
+ ))
+ seed := rand.New(rand.NewSource(int64(os.Getpid()) + time.Now().UnixNano()))
+ clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25)
+ clientSecret := "secret"
+ client := storage.WebClient(clientID, clientSecret, targetURL)
+ storage.RegisterClients(client)
+
+ t.Log("------- create RP ------")
+ key := []byte("test1234test1234")
+ cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
+ provider, err := rp.NewRelyingPartyOIDC(
+ CTX,
+ opServer.URL,
+ clientID,
+ clientSecret,
+ targetURL,
+ []string{"openid", "email", "profile", "offline_access"},
+ rp.WithPKCE(cookieHandler),
+ rp.WithVerifierOpts(
+ rp.WithIssuedAtOffset(5*time.Second),
+ rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"),
+ ),
+ )
+ require.NoError(t, err, "new rp")
+
+ t.Log("------- start auth flow with prompt=none ------- ")
+ state := "state-32892"
+ capturedW := httptest.NewRecorder()
+ localURL, err := url.Parse(targetURL + "/login")
+ require.NoError(t, err)
+
+ get := httptest.NewRequest("GET", localURL.String(), nil)
+ rp.AuthURLHandler(func() string { return state }, provider,
+ rp.WithPromptURLParam("none"),
+ rp.WithResponseModeURLParam(oidc.ResponseModeFragment),
+ )(capturedW, get)
+
+ defer func() {
+ if t.Failed() {
+ t.Log("response body (redirect from RP to OP)", capturedW.Body.String())
+ }
+ }()
+ require.GreaterOrEqual(t, capturedW.Code, 200, "captured response code")
+ require.Less(t, capturedW.Code, 400, "captured response code")
+
+ //nolint:bodyclose
+ resp := capturedW.Result()
+ jar.SetCookies(localURL, resp.Cookies())
+
+ startAuthURL, err := resp.Location()
+ require.NoError(t, err, "get redirect")
+ assert.NotEmpty(t, startAuthURL, "login url")
+ t.Log("Starting auth at", startAuthURL)
+
+ t.Log("------- get redirect from OP ------")
+ loginPageURL := getRedirect(t, "get redirect to login page", httpClient, startAuthURL)
+ t.Log("login page URL", loginPageURL)
+
+ require.Contains(t, loginPageURL.String(), `error=login_required`, "prompt=none should error")
+ require.Contains(t, loginPageURL.String(), `local-site#error=`, "response_mode=fragment means '#' instead of '?'")
+}
+
+type deferredHandler struct {
+ http.Handler
+}
+
+func getRedirect(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) *url.URL {
+ req := &http.Request{
+ Method: "GET",
+ URL: uri,
+ Header: make(http.Header),
+ }
+ resp, err := httpClient.Do(req)
+ require.NoError(t, err, "GET "+uri.String())
+
+ defer func() {
+ if t.Failed() {
+ body, _ := io.ReadAll(resp.Body)
+ t.Logf("%s: GET %s: body: %s", desc, uri, string(body))
+ }
+ }()
+
+ //nolint:errcheck
+ defer resp.Body.Close()
+ redirect, err := resp.Location()
+ require.NoErrorf(t, err, "%s: get redirect %s", desc, uri)
+ require.NotEmptyf(t, redirect, "%s: get redirect %s", desc, uri)
+ return redirect
+}
+
+func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) []byte {
+ req := &http.Request{
+ Method: "GET",
+ URL: uri,
+ Header: make(http.Header),
+ }
+ resp, err := httpClient.Do(req)
+ require.NoErrorf(t, err, "%s: GET %s", desc, uri)
+ //nolint:errcheck
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err, "%s: read GET %s", desc, uri)
+ return body
+}
+
+func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL {
+ // TODO: switch to io.NopCloser when go1.15 support is dropped
+ req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest(
+ append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)...,
+ )
+ if req.URL.Scheme == "" {
+ req.URL = uri
+ t.Log("request lost it's proto..., adding back... request now", req.URL)
+ }
+ req.RequestURI = "" // bug in gosubmit?
+ resp, err := httpClient.Do(req)
+ require.NoErrorf(t, err, "%s: POST %s", desc, uri)
+
+ //nolint:errcheck
+ defer resp.Body.Close()
+ defer func() {
+ if t.Failed() {
+ body, _ := io.ReadAll(resp.Body)
+ t.Logf("%s: GET %s: body: %s", desc, uri, string(body))
+ }
+ }()
+
+ redirect, err := resp.Location()
+ require.NoErrorf(t, err, "%s: redirect for POST %s", desc, uri)
+ return redirect
+}
diff --git a/pkg/client/jwt_profile.go b/pkg/client/jwt_profile.go
new file mode 100644
index 0000000..98a54fd
--- /dev/null
+++ b/pkg/client/jwt_profile.go
@@ -0,0 +1,30 @@
+package client
+
+import (
+ "context"
+ "net/url"
+
+ "golang.org/x/oauth2"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+// JWTProfileExchange handles the oauth2 jwt profile exchange
+func JWTProfileExchange(ctx context.Context, jwtProfileGrantRequest *oidc.JWTProfileGrantRequest, caller TokenEndpointCaller) (*oauth2.Token, error) {
+ return CallTokenEndpoint(ctx, jwtProfileGrantRequest, caller)
+}
+
+func ClientAssertionCodeOptions(assertion string) []oauth2.AuthCodeOption {
+ return []oauth2.AuthCodeOption{
+ oauth2.SetAuthURLParam("client_assertion", assertion),
+ oauth2.SetAuthURLParam("client_assertion_type", oidc.ClientAssertionTypeJWTAssertion),
+ }
+}
+
+func ClientAssertionFormAuthorization(assertion string) http.FormAuthorization {
+ return func(values url.Values) {
+ values.Set("client_assertion", assertion)
+ values.Set("client_assertion_type", oidc.ClientAssertionTypeJWTAssertion)
+ }
+}
diff --git a/pkg/client/key.go b/pkg/client/key.go
new file mode 100644
index 0000000..7f38311
--- /dev/null
+++ b/pkg/client/key.go
@@ -0,0 +1,40 @@
+package client
+
+import (
+ "encoding/json"
+ "os"
+)
+
+const (
+ serviceAccountKey = "serviceaccount"
+ applicationKey = "application"
+)
+
+type KeyFile struct {
+ Type string `json:"type"` // serviceaccount or application
+ KeyID string `json:"keyId"`
+ Key string `json:"key"`
+ Issuer string `json:"issuer"` // not yet in file
+
+ // serviceaccount
+ UserID string `json:"userId"`
+
+ // application
+ ClientID string `json:"clientId"`
+}
+
+func ConfigFromKeyFile(path string) (*KeyFile, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+ return ConfigFromKeyFileData(data)
+}
+
+func ConfigFromKeyFileData(data []byte) (*KeyFile, error) {
+ var f KeyFile
+ if err := json.Unmarshal(data, &f); err != nil {
+ return nil, err
+ }
+ return &f, nil
+}
diff --git a/pkg/client/profile/jwt_profile.go b/pkg/client/profile/jwt_profile.go
new file mode 100644
index 0000000..fb351f0
--- /dev/null
+++ b/pkg/client/profile/jwt_profile.go
@@ -0,0 +1,118 @@
+package profile
+
+import (
+ "context"
+ "net/http"
+ "time"
+
+ jose "github.com/go-jose/go-jose/v4"
+ "golang.org/x/oauth2"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type TokenSource interface {
+ oauth2.TokenSource
+ TokenCtx(context.Context) (*oauth2.Token, error)
+}
+
+// jwtProfileTokenSource implement the oauth2.TokenSource
+// it will request a token using the OAuth2 JWT Profile Grant
+// therefore sending an `assertion` by signing a JWT with the provided private key
+type jwtProfileTokenSource struct {
+ clientID string
+ audience []string
+ signer jose.Signer
+ scopes []string
+ httpClient *http.Client
+ tokenEndpoint string
+}
+
+// NewJWTProfileTokenSourceFromKeyFile returns an implementation of TokenSource
+// It will request a token using the OAuth2 JWT Profile Grant,
+// therefore sending an `assertion` by singing a JWT with the provided private key from jsonFile.
+//
+// The passed context is only used for the call to the Discover endpoint.
+func NewJWTProfileTokenSourceFromKeyFile(ctx context.Context, issuer, jsonFile string, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
+ keyData, err := client.ConfigFromKeyFile(jsonFile)
+ if err != nil {
+ return nil, err
+ }
+ return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
+}
+
+// NewJWTProfileTokenSourceFromKeyFileData returns an implementation of oauth2.TokenSource
+// It will request a token using the OAuth2 JWT Profile Grant,
+// therefore sending an `assertion` by singing a JWT with the provided private key in jsonData.
+//
+// The passed context is only used for the call to the Discover endpoint.
+func NewJWTProfileTokenSourceFromKeyFileData(ctx context.Context, issuer string, jsonData []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
+ keyData, err := client.ConfigFromKeyFileData(jsonData)
+ if err != nil {
+ return nil, err
+ }
+ return NewJWTProfileTokenSource(ctx, issuer, keyData.UserID, keyData.KeyID, []byte(keyData.Key), scopes, options...)
+}
+
+// NewJWTProfileSource returns an implementation of oauth2.TokenSource
+// It will request a token using the OAuth2 JWT Profile Grant,
+// therefore sending an `assertion` by singing a JWT with the provided private key.
+//
+// The passed context is only used for the call to the Discover endpoint.
+func NewJWTProfileTokenSource(ctx context.Context, issuer, clientID, keyID string, key []byte, scopes []string, options ...func(source *jwtProfileTokenSource)) (TokenSource, error) {
+ signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
+ if err != nil {
+ return nil, err
+ }
+ source := &jwtProfileTokenSource{
+ clientID: clientID,
+ audience: []string{issuer},
+ signer: signer,
+ scopes: scopes,
+ httpClient: http.DefaultClient,
+ }
+ for _, opt := range options {
+ opt(source)
+ }
+ if source.tokenEndpoint == "" {
+ config, err := client.Discover(ctx, issuer, source.httpClient)
+ if err != nil {
+ return nil, err
+ }
+ source.tokenEndpoint = config.TokenEndpoint
+ }
+ return source, nil
+}
+
+func WithHTTPClient(client *http.Client) func(source *jwtProfileTokenSource) {
+ return func(source *jwtProfileTokenSource) {
+ source.httpClient = client
+ }
+}
+
+func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(source *jwtProfileTokenSource) {
+ return func(source *jwtProfileTokenSource) {
+ source.tokenEndpoint = tokenEndpoint
+ }
+}
+
+func (j *jwtProfileTokenSource) TokenEndpoint() string {
+ return j.tokenEndpoint
+}
+
+func (j *jwtProfileTokenSource) HttpClient() *http.Client {
+ return j.httpClient
+}
+
+func (j *jwtProfileTokenSource) Token() (*oauth2.Token, error) {
+ return j.TokenCtx(context.Background())
+}
+
+func (j *jwtProfileTokenSource) TokenCtx(ctx context.Context) (*oauth2.Token, error) {
+ assertion, err := client.SignedJWTProfileAssertion(j.clientID, j.audience, time.Hour, j.signer)
+ if err != nil {
+ return nil, err
+ }
+ return client.JWTProfileExchange(ctx, oidc.NewJWTProfileGrantRequest(assertion, j.scopes...), j)
+}
diff --git a/pkg/utils/browser.go b/pkg/client/rp/cli/browser.go
similarity index 96%
rename from pkg/utils/browser.go
rename to pkg/client/rp/cli/browser.go
index dca75e4..1948427 100644
--- a/pkg/utils/browser.go
+++ b/pkg/client/rp/cli/browser.go
@@ -1,4 +1,4 @@
-package utils
+package cli
import (
"fmt"
diff --git a/pkg/client/rp/cli/cli.go b/pkg/client/rp/cli/cli.go
new file mode 100644
index 0000000..10edaa7
--- /dev/null
+++ b/pkg/client/rp/cli/cli.go
@@ -0,0 +1,36 @@
+package cli
+
+import (
+ "context"
+ "net/http"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+const (
+ loginPath = "/login"
+)
+
+func CodeFlow[C oidc.IDClaims](ctx context.Context, relyingParty rp.RelyingParty, callbackPath, port string, stateProvider func() string) *oidc.Tokens[C] {
+ codeflowCtx, codeflowCancel := context.WithCancel(ctx)
+ defer codeflowCancel()
+
+ tokenChan := make(chan *oidc.Tokens[C], 1)
+
+ callback := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp rp.RelyingParty) {
+ tokenChan <- tokens
+ msg := "Success!
"
+ msg = msg + "You are authenticated and can now return to the CLI.
"
+ w.Write([]byte(msg))
+ }
+ http.Handle(loginPath, rp.AuthURLHandler(stateProvider, relyingParty))
+ http.Handle(callbackPath, rp.CodeExchangeHandler(callback, relyingParty))
+
+ httphelper.StartServer(codeflowCtx, ":"+port)
+
+ OpenBrowser("http://localhost:" + port + loginPath)
+
+ return <-tokenChan
+}
diff --git a/pkg/client/rp/delegation.go b/pkg/client/rp/delegation.go
new file mode 100644
index 0000000..fb4fc63
--- /dev/null
+++ b/pkg/client/rp/delegation.go
@@ -0,0 +1,13 @@
+package rp
+
+import (
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc/grants/tokenexchange"
+)
+
+// DelegationTokenRequest is an implementation of TokenExchangeRequest
+// it exchanges an "urn:ietf:params:oauth:token-type:access_token" with an optional
+// "urn:ietf:params:oauth:token-type:access_token" actor token for an
+// "urn:ietf:params:oauth:token-type:access_token" delegation token
+func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
+ return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
+}
diff --git a/pkg/client/rp/device.go b/pkg/client/rp/device.go
new file mode 100644
index 0000000..1fadd56
--- /dev/null
+++ b/pkg/client/rp/device.go
@@ -0,0 +1,69 @@
+package rp
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+func newDeviceClientCredentialsRequest(scopes []string, rp RelyingParty) (*oidc.ClientCredentialsRequest, error) {
+ confg := rp.OAuthConfig()
+ req := &oidc.ClientCredentialsRequest{
+ Scope: scopes,
+ ClientID: confg.ClientID,
+ ClientSecret: confg.ClientSecret,
+ }
+
+ if signer := rp.Signer(); signer != nil {
+ assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, signer)
+ if err != nil {
+ return nil, fmt.Errorf("failed to build assertion: %w", err)
+ }
+ req.ClientAssertion = assertion
+ req.ClientAssertionType = oidc.ClientAssertionTypeJWTAssertion
+ }
+
+ return req, nil
+}
+
+// DeviceAuthorization starts a new Device Authorization flow as defined
+// in RFC 8628, section 3.1 and 3.2:
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.1
+func DeviceAuthorization(ctx context.Context, scopes []string, rp RelyingParty, authFn any) (*oidc.DeviceAuthorizationResponse, error) {
+ ctx, span := client.Tracer.Start(ctx, "DeviceAuthorization")
+ defer span.End()
+
+ ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAuthorization")
+ req, err := newDeviceClientCredentialsRequest(scopes, rp)
+ if err != nil {
+ return nil, err
+ }
+
+ return client.CallDeviceAuthorizationEndpoint(ctx, req, rp, authFn)
+}
+
+// DeviceAccessToken attempts to obtain tokens from a Device Authorization,
+// by means of polling as defined in RFC, section 3.3 and 3.4:
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.4
+func DeviceAccessToken(ctx context.Context, deviceCode string, interval time.Duration, rp RelyingParty) (resp *oidc.AccessTokenResponse, err error) {
+ ctx, span := client.Tracer.Start(ctx, "DeviceAccessToken")
+ defer span.End()
+
+ ctx = logCtxWithRPData(ctx, rp, "function", "DeviceAccessToken")
+ req := &client.DeviceAccessTokenRequest{
+ DeviceAccessTokenRequest: oidc.DeviceAccessTokenRequest{
+ GrantType: oidc.GrantTypeDeviceCode,
+ DeviceCode: deviceCode,
+ },
+ }
+
+ req.ClientCredentialsRequest, err = newDeviceClientCredentialsRequest(nil, rp)
+ if err != nil {
+ return nil, err
+ }
+
+ return client.PollDeviceAccessTokenEndpoint(ctx, interval, req, tokenEndpointCaller{rp})
+}
diff --git a/pkg/client/rp/errors.go b/pkg/client/rp/errors.go
new file mode 100644
index 0000000..b95420b
--- /dev/null
+++ b/pkg/client/rp/errors.go
@@ -0,0 +1,5 @@
+package rp
+
+import "errors"
+
+var ErrRelyingPartyNotSupportRevokeCaller = errors.New("RelyingParty does not support RevokeCaller")
diff --git a/pkg/client/rp/jwks.go b/pkg/client/rp/jwks.go
new file mode 100644
index 0000000..0ccbad2
--- /dev/null
+++ b/pkg/client/rp/jwks.go
@@ -0,0 +1,255 @@
+package rp
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "sync"
+
+ jose "github.com/go-jose/go-jose/v4"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
+ keyset := &remoteKeySet{httpClient: client, jwksURL: jwksURL}
+ for _, opt := range opts {
+ opt(keyset)
+ }
+ return keyset
+}
+
+// SkipRemoteCheck will suppress checking for new remote keys if signature validation fails with cached keys
+// and no kid header is set in the JWT
+//
+// this might be handy to save some unnecessary round trips in cases where the JWT does not contain a kid header and
+// there is only a single remote key
+// please notice that remote keys will then only be fetched if cached keys are empty
+func SkipRemoteCheck() func(set *remoteKeySet) {
+ return func(set *remoteKeySet) {
+ set.skipRemoteCheck = true
+ }
+}
+
+type remoteKeySet struct {
+ jwksURL string
+ httpClient *http.Client
+ defaultAlg string
+ skipRemoteCheck bool
+
+ // guard all other fields
+ mu sync.Mutex
+
+ // inflight suppresses parallel execution of updateKeys and allows
+ // multiple goroutines to wait for its result.
+ inflight *inflight
+
+ // A set of cached keys and their expiry.
+ cachedKeys []jose.JSONWebKey
+}
+
+// inflight is used to wait on some in-flight request from multiple goroutines.
+type inflight struct {
+ doneCh chan struct{}
+
+ keys []jose.JSONWebKey
+ err error
+}
+
+func newInflight() *inflight {
+ return &inflight{doneCh: make(chan struct{})}
+}
+
+// wait returns a channel that multiple goroutines can receive on. Once it returns
+// a value, the inflight request is done and result() can be inspected.
+func (i *inflight) wait() <-chan struct{} {
+ return i.doneCh
+}
+
+// done can only be called by a single goroutine. It records the result of the
+// inflight request and signals other goroutines that the result is safe to
+// inspect.
+func (i *inflight) done(keys []jose.JSONWebKey, err error) {
+ i.keys = keys
+ i.err = err
+ close(i.doneCh)
+}
+
+// result cannot be called until the wait() channel has returned a value.
+func (i *inflight) result() ([]jose.JSONWebKey, error) {
+ return i.keys, i.err
+}
+
+func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
+ ctx, span := client.Tracer.Start(ctx, "VerifySignature")
+ defer span.End()
+
+ keyID, alg := oidc.GetKeyIDAndAlg(jws)
+ if alg == "" {
+ alg = r.defaultAlg
+ }
+ payload, err := r.verifySignatureCached(jws, keyID, alg)
+ if payload != nil {
+ return payload, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ return r.verifySignatureRemote(ctx, jws, keyID, alg)
+}
+
+// verifySignatureCached checks for a matching key in the cached key list
+//
+// if there is only one possible, it tries to verify the signature and will return the payload if successful
+//
+// it only returns an error if signature validation fails and keys exactMatch which is if either:
+// - both kid are empty and skipRemoteCheck is set to true
+// - or both (JWT and JWK) kid are equal
+//
+// otherwise it will return no error (so remote keys will be loaded)
+func (r *remoteKeySet) verifySignatureCached(jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
+ keys := r.keysFromCache()
+ if len(keys) == 0 {
+ return nil, nil
+ }
+ key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
+ if err != nil {
+ // no key / multiple found, try with remote keys
+ return nil, nil //nolint:nilerr
+ }
+ payload, err := jws.Verify(&key)
+ if payload != nil {
+ return payload, nil
+ }
+ if !r.exactMatch(key.KeyID, keyID) {
+ // no exact key match, try getting better match with remote keys
+ return nil, nil
+ }
+ return nil, fmt.Errorf("signature verification failed: %w", err)
+}
+
+func (r *remoteKeySet) exactMatch(jwkID, jwsID string) bool {
+ if jwkID == "" && jwsID == "" {
+ return r.skipRemoteCheck
+ }
+ return jwkID == jwsID
+}
+
+func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
+ ctx, span := client.Tracer.Start(ctx, "verifySignatureRemote")
+ defer span.End()
+
+ keys, err := r.keysFromRemote(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("unable to fetch key for signature validation: %w", err)
+ }
+ key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, keys...)
+ if err != nil {
+ return nil, fmt.Errorf("unable to validate signature: %w", err)
+ }
+ payload, err := jws.Verify(&key)
+ if err != nil {
+ return nil, fmt.Errorf("signature verification failed: %w", err)
+ }
+ return payload, nil
+}
+
+func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return r.cachedKeys
+}
+
+// keysFromRemote syncs the key set from the remote set, records the values in the
+// cache, and returns the key set.
+func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
+ ctx, span := client.Tracer.Start(ctx, "keysFromRemote")
+ defer span.End()
+
+ // Need to lock to inspect the inflight request field.
+ r.mu.Lock()
+ // If there's not a current inflight request, create one.
+ if r.inflight == nil {
+ r.inflight = newInflight()
+
+ // This goroutine has exclusive ownership over the current inflight
+ // request. It releases the resource by nil'ing the inflight field
+ // once the goroutine is done.
+ go r.updateKeys(ctx)
+ }
+ inflight := r.inflight
+ r.mu.Unlock()
+
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-inflight.wait():
+ return inflight.result()
+ }
+}
+
+func (r *remoteKeySet) updateKeys(ctx context.Context) {
+ ctx, span := client.Tracer.Start(ctx, "updateKeys")
+ defer span.End()
+
+ // Sync keys and finish inflight when that's done.
+ keys, err := r.fetchRemoteKeys(ctx)
+
+ r.inflight.done(keys, err)
+
+ // Lock to update the keys and indicate that there is no longer an
+ // inflight request.
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if err == nil {
+ r.cachedKeys = keys
+ }
+
+ // Free inflight so a different request can run.
+ r.inflight = nil
+}
+
+func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
+ ctx, span := client.Tracer.Start(ctx, "fetchRemoteKeys")
+ defer span.End()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", r.jwksURL, nil)
+ if err != nil {
+ return nil, fmt.Errorf("oidc: can't create request: %v", err)
+ }
+
+ keySet := new(jsonWebKeySet)
+ if err = httphelper.HttpRequest(r.httpClient, req, keySet); err != nil {
+ return nil, fmt.Errorf("oidc: failed to get keys: %v", err)
+ }
+ return keySet.Keys, nil
+}
+
+// jsonWebKeySet is an alias for jose.JSONWebKeySet which ignores unknown key types (kty)
+type jsonWebKeySet jose.JSONWebKeySet
+
+// UnmarshalJSON overrides the default jose.JSONWebKeySet method to ignore any error
+// which might occur because of unknown key types (kty)
+func (k *jsonWebKeySet) UnmarshalJSON(data []byte) (err error) {
+ var raw rawJSONWebKeySet
+ err = json.Unmarshal(data, &raw)
+ if err != nil {
+ return err
+ }
+ for _, key := range raw.Keys {
+ webKey := new(jose.JSONWebKey)
+ err = webKey.UnmarshalJSON(key)
+ if err == nil {
+ k.Keys = append(k.Keys, *webKey)
+ }
+ }
+ return nil
+}
+
+type rawJSONWebKeySet struct {
+ Keys []json.RawMessage `json:"keys"`
+}
diff --git a/pkg/client/rp/log.go b/pkg/client/rp/log.go
new file mode 100644
index 0000000..556220c
--- /dev/null
+++ b/pkg/client/rp/log.go
@@ -0,0 +1,17 @@
+package rp
+
+import (
+ "context"
+ "log/slog"
+
+ "github.com/zitadel/logging"
+)
+
+func logCtxWithRPData(ctx context.Context, rp RelyingParty, attrs ...any) context.Context {
+ logger, ok := rp.Logger(ctx)
+ if !ok {
+ return ctx
+ }
+ logger = logger.With(slog.Group("rp", attrs...))
+ return logging.ToContext(ctx, logger)
+}
diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go
new file mode 100644
index 0000000..c2759a2
--- /dev/null
+++ b/pkg/client/rp/relying_party.go
@@ -0,0 +1,820 @@
+package rp
+
+import (
+ "context"
+ "encoding/base64"
+ "errors"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "time"
+
+ "github.com/go-jose/go-jose/v4"
+ "github.com/google/uuid"
+ "golang.org/x/oauth2"
+ "golang.org/x/oauth2/clientcredentials"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/zitadel/logging"
+)
+
+const (
+ idTokenKey = "id_token"
+ stateParam = "state"
+ pkceCode = "pkce"
+)
+
+var ErrUserInfoSubNotMatching = errors.New("sub from userinfo does not match the sub from the id_token")
+
+// RelyingParty declares the minimal interface for oidc clients
+type RelyingParty interface {
+ // OAuthConfig returns the oauth2 Config
+ OAuthConfig() *oauth2.Config
+
+ // Issuer returns the issuer of the oidc config
+ Issuer() string
+
+ // IsPKCE returns if authorization is done using `Authorization Code Flow with Proof Key for Code Exchange (PKCE)`
+ IsPKCE() bool
+
+ // CookieHandler returns a http cookie handler used for various state transfer cookies
+ CookieHandler() *httphelper.CookieHandler
+
+ // HttpClient returns a http client used for calls to the openid provider, e.g. calling token endpoint
+ HttpClient() *http.Client
+
+ // IsOAuth2Only specifies whether relaying party handles only oauth2 or oidc calls
+ IsOAuth2Only() bool
+
+ // Signer is used if the relaying party uses the JWT Profile
+ Signer() jose.Signer
+
+ // GetEndSessionEndpoint returns the endpoint to sign out on a IDP
+ GetEndSessionEndpoint() string
+
+ // GetRevokeEndpoint returns the endpoint to revoke a specific token
+ GetRevokeEndpoint() string
+
+ // UserinfoEndpoint returns the userinfo
+ UserinfoEndpoint() string
+
+ // GetDeviceAuthorizationEndpoint returns the endpoint which can
+ // be used to start a DeviceAuthorization flow.
+ GetDeviceAuthorizationEndpoint() string
+
+ // IDTokenVerifier returns the verifier used for oidc id_token verification
+ IDTokenVerifier() *IDTokenVerifier
+
+ // ErrorHandler returns the handler used for callback errors
+ ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
+
+ // Logger from the context, or a fallback if set.
+ Logger(context.Context) (logger *slog.Logger, ok bool)
+}
+
+type HasUnauthorizedHandler interface {
+ // UnauthorizedHandler returns the handler used for unauthorized errors
+ UnauthorizedHandler() func(w http.ResponseWriter, r *http.Request, desc string, state string)
+}
+
+type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
+type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string)
+
+var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
+ http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
+}
+var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter, r *http.Request, desc string, state string) {
+ http.Error(w, desc, http.StatusUnauthorized)
+}
+
+type relyingParty struct {
+ issuer string
+ DiscoveryEndpoint string
+ endpoints Endpoints
+ oauthConfig *oauth2.Config
+ oauth2Only bool
+ pkce bool
+ useSigningAlgsFromDiscovery bool
+
+ httpClient *http.Client
+ cookieHandler *httphelper.CookieHandler
+
+ oauthAuthStyle oauth2.AuthStyle
+
+ errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
+ unauthorizedHandler func(http.ResponseWriter, *http.Request, string, string)
+ idTokenVerifier *IDTokenVerifier
+ verifierOpts []VerifierOption
+ signer jose.Signer
+ logger *slog.Logger
+}
+
+func (rp *relyingParty) OAuthConfig() *oauth2.Config {
+ return rp.oauthConfig
+}
+
+func (rp *relyingParty) Issuer() string {
+ return rp.issuer
+}
+
+func (rp *relyingParty) IsPKCE() bool {
+ return rp.pkce
+}
+
+func (rp *relyingParty) CookieHandler() *httphelper.CookieHandler {
+ return rp.cookieHandler
+}
+
+func (rp *relyingParty) HttpClient() *http.Client {
+ return rp.httpClient
+}
+
+func (rp *relyingParty) IsOAuth2Only() bool {
+ return rp.oauth2Only
+}
+
+func (rp *relyingParty) Signer() jose.Signer {
+ return rp.signer
+}
+
+func (rp *relyingParty) UserinfoEndpoint() string {
+ return rp.endpoints.UserinfoURL
+}
+
+func (rp *relyingParty) GetDeviceAuthorizationEndpoint() string {
+ return rp.endpoints.DeviceAuthorizationURL
+}
+
+func (rp *relyingParty) GetEndSessionEndpoint() string {
+ return rp.endpoints.EndSessionURL
+}
+
+func (rp *relyingParty) GetRevokeEndpoint() string {
+ return rp.endpoints.RevokeURL
+}
+
+func (rp *relyingParty) IDTokenVerifier() *IDTokenVerifier {
+ if rp.idTokenVerifier == nil {
+ rp.idTokenVerifier = NewIDTokenVerifier(rp.issuer, rp.oauthConfig.ClientID, NewRemoteKeySet(rp.httpClient, rp.endpoints.JKWsURL), rp.verifierOpts...)
+ }
+ return rp.idTokenVerifier
+}
+
+func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string) {
+ if rp.errorHandler == nil {
+ rp.errorHandler = DefaultErrorHandler
+ }
+ return rp.errorHandler
+}
+
+func (rp *relyingParty) UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) {
+ if rp.unauthorizedHandler == nil {
+ rp.unauthorizedHandler = DefaultUnauthorizedHandler
+ }
+ return rp.unauthorizedHandler
+}
+
+func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
+ logger, ok = logging.FromContext(ctx)
+ if ok {
+ return logger, ok
+ }
+ return rp.logger, rp.logger != nil
+}
+
+// NewRelyingPartyOAuth creates an (OAuth2) RelyingParty with the given
+// OAuth2 Config and possible configOptions
+// it will use the AuthURL and TokenURL set in config
+func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingParty, error) {
+ rp := &relyingParty{
+ oauthConfig: config,
+ httpClient: httphelper.DefaultHTTPClient,
+ oauth2Only: true,
+ unauthorizedHandler: DefaultUnauthorizedHandler,
+ oauthAuthStyle: oauth2.AuthStyleAutoDetect,
+ }
+
+ for _, optFunc := range options {
+ if err := optFunc(rp); err != nil {
+ return nil, err
+ }
+ }
+
+ rp.oauthConfig.Endpoint.AuthStyle = rp.oauthAuthStyle
+
+ // avoid races by calling these early
+ _ = rp.IDTokenVerifier() // sets idTokenVerifier
+ _ = rp.ErrorHandler() // sets errorHandler
+ _ = rp.UnauthorizedHandler() // sets unauthorizedHandler
+
+ return rp, nil
+}
+
+// NewRelyingPartyOIDC creates an (OIDC) RelyingParty with the given
+// issuer, clientID, clientSecret, redirectURI, scopes and possible configOptions
+// it will run discovery on the provided issuer and use the found endpoints
+func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...Option) (RelyingParty, error) {
+ rp := &relyingParty{
+ issuer: issuer,
+ oauthConfig: &oauth2.Config{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ RedirectURL: redirectURI,
+ Scopes: scopes,
+ },
+ httpClient: httphelper.DefaultHTTPClient,
+ oauth2Only: false,
+ oauthAuthStyle: oauth2.AuthStyleAutoDetect,
+ }
+
+ for _, optFunc := range options {
+ if err := optFunc(rp); err != nil {
+ return nil, err
+ }
+ }
+ ctx = logCtxWithRPData(ctx, rp, "function", "NewRelyingPartyOIDC")
+ discoveryConfiguration, err := client.Discover(ctx, rp.issuer, rp.httpClient, rp.DiscoveryEndpoint)
+ if err != nil {
+ return nil, err
+ }
+ if rp.useSigningAlgsFromDiscovery {
+ rp.verifierOpts = append(rp.verifierOpts, WithSupportedSigningAlgorithms(discoveryConfiguration.IDTokenSigningAlgValuesSupported...))
+ }
+ endpoints := GetEndpoints(discoveryConfiguration)
+ rp.oauthConfig.Endpoint = endpoints.Endpoint
+ rp.endpoints = endpoints
+
+ rp.oauthConfig.Endpoint.AuthStyle = rp.oauthAuthStyle
+ rp.endpoints.Endpoint.AuthStyle = rp.oauthAuthStyle
+
+ // avoid races by calling these early
+ _ = rp.IDTokenVerifier() // sets idTokenVerifier
+ _ = rp.ErrorHandler() // sets errorHandler
+ _ = rp.UnauthorizedHandler() // sets unauthorizedHandler
+
+ return rp, nil
+}
+
+// Option is the type for providing dynamic options to the relyingParty
+type Option func(*relyingParty) error
+
+func WithCustomDiscoveryUrl(url string) Option {
+ return func(rp *relyingParty) error {
+ rp.DiscoveryEndpoint = url
+ return nil
+ }
+}
+
+// WithCookieHandler set a `CookieHandler` for securing the various redirects
+func WithCookieHandler(cookieHandler *httphelper.CookieHandler) Option {
+ return func(rp *relyingParty) error {
+ rp.cookieHandler = cookieHandler
+ return nil
+ }
+}
+
+// WithPKCE sets the RP to use PKCE (oauth2 code challenge)
+// it also sets a `CookieHandler` for securing the various redirects
+// and exchanging the code challenge
+func WithPKCE(cookieHandler *httphelper.CookieHandler) Option {
+ return func(rp *relyingParty) error {
+ rp.pkce = true
+ rp.cookieHandler = cookieHandler
+ return nil
+ }
+}
+
+// WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier
+func WithHTTPClient(client *http.Client) Option {
+ return func(rp *relyingParty) error {
+ rp.httpClient = client
+ return nil
+ }
+}
+
+func WithErrorHandler(errorHandler ErrorHandler) Option {
+ return func(rp *relyingParty) error {
+ rp.errorHandler = errorHandler
+ return nil
+ }
+}
+
+func WithUnauthorizedHandler(unauthorizedHandler UnauthorizedHandler) Option {
+ return func(rp *relyingParty) error {
+ rp.unauthorizedHandler = unauthorizedHandler
+ return nil
+ }
+}
+
+func WithAuthStyle(oauthAuthStyle oauth2.AuthStyle) Option {
+ return func(rp *relyingParty) error {
+ rp.oauthAuthStyle = oauthAuthStyle
+ return nil
+ }
+}
+
+func WithVerifierOpts(opts ...VerifierOption) Option {
+ return func(rp *relyingParty) error {
+ rp.verifierOpts = opts
+ return nil
+ }
+}
+
+// WithClientKey specifies the path to the key.json to be used for the JWT Profile Client Authentication on the token endpoint
+//
+// deprecated: use WithJWTProfile(SignerFromKeyPath(path)) instead
+func WithClientKey(path string) Option {
+ return WithJWTProfile(SignerFromKeyPath(path))
+}
+
+// WithJWTProfile creates a signer used for the JWT Profile Client Authentication on the token endpoint
+// When creating the signer, be sure to include the KeyID in the SigningKey.
+// See client.NewSignerFromPrivateKeyByte for an example.
+func WithJWTProfile(signerFromKey SignerFromKey) Option {
+ return func(rp *relyingParty) error {
+ signer, err := signerFromKey()
+ if err != nil {
+ return err
+ }
+ rp.signer = signer
+ return nil
+ }
+}
+
+// WithLogger sets a logger that is used
+// in case the request context does not contain a logger.
+func WithLogger(logger *slog.Logger) Option {
+ return func(rp *relyingParty) error {
+ rp.logger = logger
+ return nil
+ }
+}
+
+// WithSigningAlgsFromDiscovery appends the [WithSupportedSigningAlgorithms] option to the Verifier Options.
+// The algorithms returned in the `id_token_signing_alg_values_supported` from the discovery response will be set.
+func WithSigningAlgsFromDiscovery() Option {
+ return func(rp *relyingParty) error {
+ rp.useSigningAlgsFromDiscovery = true
+ return nil
+ }
+}
+
+type SignerFromKey func() (jose.Signer, error)
+
+func SignerFromKeyPath(path string) SignerFromKey {
+ return func() (jose.Signer, error) {
+ config, err := client.ConfigFromKeyFile(path)
+ if err != nil {
+ return nil, err
+ }
+ return client.NewSignerFromPrivateKeyByte([]byte(config.Key), config.KeyID)
+ }
+}
+
+func SignerFromKeyFile(fileData []byte) SignerFromKey {
+ return func() (jose.Signer, error) {
+ config, err := client.ConfigFromKeyFileData(fileData)
+ if err != nil {
+ return nil, err
+ }
+ return client.NewSignerFromPrivateKeyByte([]byte(config.Key), config.KeyID)
+ }
+}
+
+func SignerFromKeyAndKeyID(key []byte, keyID string) SignerFromKey {
+ return func() (jose.Signer, error) {
+ return client.NewSignerFromPrivateKeyByte(key, keyID)
+ }
+}
+
+// AuthURL returns the auth request url
+// (wrapping the oauth2 `AuthCodeURL`)
+func AuthURL(state string, rp RelyingParty, opts ...AuthURLOpt) string {
+ authOpts := make([]oauth2.AuthCodeOption, 0)
+ for _, opt := range opts {
+ authOpts = append(authOpts, opt()...)
+ }
+ return rp.OAuthConfig().AuthCodeURL(state, authOpts...)
+}
+
+// AuthURLHandler extends the `AuthURL` method with a http redirect handler
+// including handling setting cookie for secure `state` transfer.
+// Custom parameters can optionally be set to the redirect URL.
+func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ opts := make([]AuthURLOpt, len(urlParam))
+ for i, p := range urlParam {
+ opts[i] = AuthURLOpt(p)
+ }
+
+ state := stateFn()
+ if err := trySetStateCookie(w, state, rp); err != nil {
+ unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp)
+ return
+ }
+ if rp.IsPKCE() {
+ codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp)
+ if err != nil {
+ unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp)
+ return
+ }
+ opts = append(opts, WithCodeChallenge(codeChallenge))
+ }
+
+ http.Redirect(w, r, AuthURL(state, rp, opts...), http.StatusFound)
+ }
+}
+
+// GenerateAndStoreCodeChallenge generates a PKCE code challenge and stores its verifier into a secure cookie
+func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (string, error) {
+ codeVerifier := base64.RawURLEncoding.EncodeToString([]byte(uuid.New().String()))
+ if err := rp.CookieHandler().SetCookie(w, pkceCode, codeVerifier); err != nil {
+ return "", err
+ }
+ return oidc.NewSHACodeChallenge(codeVerifier), nil
+}
+
+// ErrMissingIDToken is returned when an id_token was expected,
+// but not received in the token response.
+var ErrMissingIDToken = errors.New("id_token missing")
+
+func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) {
+ ctx, span := client.Tracer.Start(ctx, "verifyTokenResponse")
+ defer span.End()
+
+ if rp.IsOAuth2Only() {
+ return &oidc.Tokens[C]{Token: token}, nil
+ }
+ idTokenString, ok := token.Extra(idTokenKey).(string)
+ if !ok {
+ return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken
+ }
+ idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier())
+ if err != nil {
+ return nil, err
+ }
+ return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
+}
+
+// CodeExchange handles the oauth2 code exchange, extracting and validating the id_token
+// returning it parsed together with the oauth2 tokens (access, refresh)
+func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) {
+ ctx, codeExchangeSpan := client.Tracer.Start(ctx, "CodeExchange")
+ defer codeExchangeSpan.End()
+
+ ctx = logCtxWithRPData(ctx, rp, "function", "CodeExchange")
+ ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
+ codeOpts := make([]oauth2.AuthCodeOption, 0)
+ for _, opt := range opts {
+ codeOpts = append(codeOpts, opt()...)
+ }
+
+ ctx, oauthExchangeSpan := client.Tracer.Start(ctx, "OAuthExchange")
+ token, err := rp.OAuthConfig().Exchange(ctx, code, codeOpts...)
+ if err != nil {
+ return nil, err
+ }
+ oauthExchangeSpan.End()
+ return verifyTokenResponse[C](ctx, token, rp)
+}
+
+// ClientCredentials requests an access token using the `client_credentials` grant,
+// as defined in [RFC 6749, section 4.4].
+//
+// As there is no user associated to the request an ID Token can never be returned.
+// Client Credentials are undefined in OpenID Connect and is a pure OAuth2 grant.
+// Furthermore the server SHOULD NOT return a refresh token.
+//
+// [RFC 6749, section 4.4]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
+func ClientCredentials(ctx context.Context, rp RelyingParty, endpointParams url.Values) (token *oauth2.Token, err error) {
+ ctx = logCtxWithRPData(ctx, rp, "function", "ClientCredentials")
+ ctx, span := client.Tracer.Start(ctx, "ClientCredentials")
+ defer span.End()
+
+ ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
+ config := clientcredentials.Config{
+ ClientID: rp.OAuthConfig().ClientID,
+ ClientSecret: rp.OAuthConfig().ClientSecret,
+ TokenURL: rp.OAuthConfig().Endpoint.TokenURL,
+ Scopes: rp.OAuthConfig().Scopes,
+ EndpointParams: endpointParams,
+ AuthStyle: rp.OAuthConfig().Endpoint.AuthStyle,
+ }
+ return config.Token(ctx)
+}
+
+type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
+
+// CodeExchangeHandler extends the `CodeExchange` method with a http handler
+// including cookie handling for secure `state` transfer
+// and optional PKCE code verifier checking.
+// Custom parameters can optionally be set to the token URL.
+func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ ctx, span := client.Tracer.Start(r.Context(), "CodeExchangeHandler")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ state, err := tryReadStateCookie(w, r, rp)
+ if err != nil {
+ unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp)
+ return
+ }
+ if errValue := r.FormValue("error"); errValue != "" {
+ rp.ErrorHandler()(w, r, errValue, r.FormValue("error_description"), state)
+ return
+ }
+ codeOpts := make([]CodeExchangeOpt, len(urlParam))
+ for i, p := range urlParam {
+ codeOpts[i] = CodeExchangeOpt(p)
+ }
+
+ if rp.IsPKCE() {
+ codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
+ if err != nil {
+ unauthorizedError(w, r, "failed to get code verifier: "+err.Error(), state, rp)
+ return
+ }
+ codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
+ rp.CookieHandler().DeleteCookie(w, pkceCode)
+ }
+ if rp.Signer() != nil {
+ assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer(), rp.OAuthConfig().Endpoint.TokenURL}, time.Hour, rp.Signer())
+ if err != nil {
+ unauthorizedError(w, r, "failed to build assertion: "+err.Error(), state, rp)
+ return
+ }
+ codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
+ }
+ tokens, err := CodeExchange[C](r.Context(), r.FormValue("code"), rp, codeOpts...)
+ if err != nil {
+ unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp)
+ return
+ }
+ callback(w, r, tokens, state, rp)
+ }
+}
+
+type SubjectGetter interface {
+ GetSubject() string
+}
+
+type CodeExchangeUserinfoCallback[C oidc.IDClaims, U SubjectGetter] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, provider RelyingParty, info U)
+
+// UserinfoCallback wraps the callback function of the CodeExchangeHandler
+// and calls the userinfo endpoint with the access token
+// on success it will pass the userinfo into its callback function as well
+func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCallback[C, U]) CodeExchangeCallback[C] {
+ return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
+ ctx, span := client.Tracer.Start(r.Context(), "UserinfoCallback")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
+ if err != nil {
+ unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp)
+ return
+ }
+ f(w, r, tokens, state, rp, info)
+ }
+}
+
+// Userinfo will call the OIDC [UserInfo] Endpoint with the provided token and returns
+// the response in an instance of type U.
+// [*oidc.UserInfo] can be used as a good example, or use a custom type if type-safe
+// access to custom claims is needed.
+//
+// [UserInfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
+func Userinfo[U SubjectGetter](ctx context.Context, token, tokenType, subject string, rp RelyingParty) (userinfo U, err error) {
+ var nilU U
+ ctx = logCtxWithRPData(ctx, rp, "function", "Userinfo")
+ ctx, span := client.Tracer.Start(ctx, "Userinfo")
+ defer span.End()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, rp.UserinfoEndpoint(), nil)
+ if err != nil {
+ return nilU, err
+ }
+ req.Header.Set("authorization", tokenType+" "+token)
+ if err := httphelper.HttpRequest(rp.HttpClient(), req, &userinfo); err != nil {
+ return nilU, err
+ }
+ if userinfo.GetSubject() != subject {
+ return nilU, ErrUserInfoSubNotMatching
+ }
+ return userinfo, nil
+}
+
+func trySetStateCookie(w http.ResponseWriter, state string, rp RelyingParty) error {
+ if rp.CookieHandler() != nil {
+ if err := rp.CookieHandler().SetCookie(w, stateParam, state); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func tryReadStateCookie(w http.ResponseWriter, r *http.Request, rp RelyingParty) (state string, err error) {
+ if rp.CookieHandler() == nil {
+ return r.FormValue(stateParam), nil
+ }
+ state, err = rp.CookieHandler().CheckQueryCookie(r, stateParam)
+ if err != nil {
+ return "", err
+ }
+ rp.CookieHandler().DeleteCookie(w, stateParam)
+ return state, nil
+}
+
+type OptionFunc func(RelyingParty)
+
+type Endpoints struct {
+ oauth2.Endpoint
+ IntrospectURL string
+ UserinfoURL string
+ JKWsURL string
+ EndSessionURL string
+ RevokeURL string
+ DeviceAuthorizationURL string
+}
+
+func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
+ return Endpoints{
+ Endpoint: oauth2.Endpoint{
+ AuthURL: discoveryConfig.AuthorizationEndpoint,
+ TokenURL: discoveryConfig.TokenEndpoint,
+ },
+ IntrospectURL: discoveryConfig.IntrospectionEndpoint,
+ UserinfoURL: discoveryConfig.UserinfoEndpoint,
+ JKWsURL: discoveryConfig.JwksURI,
+ EndSessionURL: discoveryConfig.EndSessionEndpoint,
+ RevokeURL: discoveryConfig.RevocationEndpoint,
+ DeviceAuthorizationURL: discoveryConfig.DeviceAuthorizationEndpoint,
+ }
+}
+
+// withURLParam sets custom url parameters.
+// This is the generalized, unexported, function used by both
+// URLParamOpt and AuthURLOpt.
+func withURLParam(key, value string) func() []oauth2.AuthCodeOption {
+ return func() []oauth2.AuthCodeOption {
+ return []oauth2.AuthCodeOption{
+ oauth2.SetAuthURLParam(key, value),
+ }
+ }
+}
+
+// withPrompt sets the `prompt` params in the auth request
+// This is the generalized, unexported, function used by both
+// URLParamOpt and AuthURLOpt.
+func withPrompt(prompt ...string) func() []oauth2.AuthCodeOption {
+ return withURLParam("prompt", oidc.SpaceDelimitedArray(prompt).String())
+}
+
+type URLParamOpt func() []oauth2.AuthCodeOption
+
+// WithURLParam allows setting custom key-vale pairs
+// to an OAuth2 URL.
+func WithURLParam(key, value string) URLParamOpt {
+ return withURLParam(key, value)
+}
+
+// WithPromptURLParam sets the `prompt` parameter in a URL.
+func WithPromptURLParam(prompt ...string) URLParamOpt {
+ return withPrompt(prompt...)
+}
+
+// WithResponseModeURLParam sets the `response_mode` parameter in a URL.
+func WithResponseModeURLParam(mode oidc.ResponseMode) URLParamOpt {
+ return withURLParam("response_mode", string(mode))
+}
+
+type AuthURLOpt func() []oauth2.AuthCodeOption
+
+// WithCodeChallenge sets the `code_challenge` params in the auth request
+func WithCodeChallenge(codeChallenge string) AuthURLOpt {
+ return func() []oauth2.AuthCodeOption {
+ return []oauth2.AuthCodeOption{
+ oauth2.SetAuthURLParam("code_challenge", codeChallenge),
+ oauth2.SetAuthURLParam("code_challenge_method", "S256"),
+ }
+ }
+}
+
+// WithPrompt sets the `prompt` params in the auth request
+func WithPrompt(prompt ...string) AuthURLOpt {
+ return withPrompt(prompt...)
+}
+
+type CodeExchangeOpt func() []oauth2.AuthCodeOption
+
+// WithCodeVerifier sets the `code_verifier` param in the token request
+func WithCodeVerifier(codeVerifier string) CodeExchangeOpt {
+ return func() []oauth2.AuthCodeOption {
+ return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("code_verifier", codeVerifier)}
+ }
+}
+
+// WithClientAssertionJWT sets the `client_assertion` param in the token request
+func WithClientAssertionJWT(clientAssertion string) CodeExchangeOpt {
+ return func() []oauth2.AuthCodeOption {
+ return client.ClientAssertionCodeOptions(clientAssertion)
+ }
+}
+
+type tokenEndpointCaller struct {
+ RelyingParty
+}
+
+func (t tokenEndpointCaller) TokenEndpoint() string {
+ return t.OAuthConfig().Endpoint.TokenURL
+}
+
+type RefreshTokenRequest struct {
+ RefreshToken string `schema:"refresh_token"`
+ Scopes oidc.SpaceDelimitedArray `schema:"scope,omitempty"`
+ ClientID string `schema:"client_id,omitempty"`
+ ClientSecret string `schema:"client_secret,omitempty"`
+ ClientAssertion string `schema:"client_assertion,omitempty"`
+ ClientAssertionType string `schema:"client_assertion_type,omitempty"`
+ GrantType oidc.GrantType `schema:"grant_type"`
+}
+
+// RefreshTokens performs a token refresh. If it doesn't error, it will always
+// provide a new AccessToken. It may provide a new RefreshToken, and if it does, then
+// the old one should be considered invalid.
+//
+// In case the RP is not OAuth2 only and an IDToken was part of the response,
+// the IDToken and AccessToken will be verified
+// and the IDToken and IDTokenClaims fields will be populated in the returned object.
+func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) {
+ ctx, span := client.Tracer.Start(ctx, "RefreshTokens")
+ defer span.End()
+
+ ctx = logCtxWithRPData(ctx, rp, "function", "RefreshTokens")
+ request := RefreshTokenRequest{
+ RefreshToken: refreshToken,
+ Scopes: rp.OAuthConfig().Scopes,
+ ClientID: rp.OAuthConfig().ClientID,
+ ClientSecret: rp.OAuthConfig().ClientSecret,
+ ClientAssertion: clientAssertion,
+ ClientAssertionType: clientAssertionType,
+ GrantType: oidc.GrantTypeRefreshToken,
+ }
+ newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp})
+ if err != nil {
+ return nil, err
+ }
+ tokens, err := verifyTokenResponse[C](ctx, newToken, rp)
+ if err == nil || errors.Is(err, ErrMissingIDToken) {
+ // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
+ // ...except that it might not contain an id_token.
+ return tokens, nil
+ }
+ return nil, err
+}
+
+func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) {
+ ctx = logCtxWithRPData(ctx, rp, "function", "EndSession")
+ ctx, span := client.Tracer.Start(ctx, "RefreshTokens")
+ defer span.End()
+
+ request := oidc.EndSessionRequest{
+ IdTokenHint: idToken,
+ ClientID: rp.OAuthConfig().ClientID,
+ PostLogoutRedirectURI: optionalRedirectURI,
+ State: optionalState,
+ }
+ return client.CallEndSessionEndpoint(ctx, request, nil, rp)
+}
+
+// RevokeToken requires a RelyingParty that is also a client.RevokeCaller. The RelyingParty
+// returned by NewRelyingPartyOIDC() meets that criteria, but the one returned by
+// NewRelyingPartyOAuth() does not.
+//
+// tokenTypeHint should be either "id_token" or "refresh_token".
+func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHint string) error {
+ ctx = logCtxWithRPData(ctx, rp, "function", "RevokeToken")
+ ctx, span := client.Tracer.Start(ctx, "RefreshTokens")
+ defer span.End()
+ request := client.RevokeRequest{
+ Token: token,
+ TokenTypeHint: tokenTypeHint,
+ ClientID: rp.OAuthConfig().ClientID,
+ ClientSecret: rp.OAuthConfig().ClientSecret,
+ }
+ if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" {
+ return client.CallRevokeEndpoint(ctx, request, nil, rc)
+ }
+ return ErrRelyingPartyNotSupportRevokeCaller
+}
+
+func unauthorizedError(w http.ResponseWriter, r *http.Request, desc string, state string, rp RelyingParty) {
+ if rp, ok := rp.(HasUnauthorizedHandler); ok {
+ rp.UnauthorizedHandler()(w, r, desc, state)
+ return
+ }
+ http.Error(w, desc, http.StatusUnauthorized)
+}
diff --git a/pkg/client/rp/relying_party_test.go b/pkg/client/rp/relying_party_test.go
new file mode 100644
index 0000000..b3bb6ee
--- /dev/null
+++ b/pkg/client/rp/relying_party_test.go
@@ -0,0 +1,107 @@
+package rp
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/oauth2"
+)
+
+func Test_verifyTokenResponse(t *testing.T) {
+ verifier := &IDTokenVerifier{
+ Issuer: tu.ValidIssuer,
+ MaxAgeIAT: 2 * time.Minute,
+ ClientID: tu.ValidClientID,
+ Offset: time.Second,
+ SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
+ KeySet: tu.KeySet{},
+ MaxAge: 2 * time.Minute,
+ ACR: tu.ACRVerify,
+ Nonce: func(context.Context) string { return tu.ValidNonce },
+ }
+ tests := []struct {
+ name string
+ oauth2Only bool
+ tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims])
+ wantErr error
+ }{
+ {
+ name: "succes, oauth2 only",
+ oauth2Only: true,
+ tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
+ accesToken, _ := tu.ValidAccessToken()
+ token := &oauth2.Token{
+ AccessToken: accesToken,
+ }
+ return token, &oidc.Tokens[*oidc.IDTokenClaims]{
+ Token: token,
+ }
+ },
+ },
+ {
+ name: "id_token missing error",
+ oauth2Only: false,
+ tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
+ accesToken, _ := tu.ValidAccessToken()
+ token := &oauth2.Token{
+ AccessToken: accesToken,
+ }
+ return token, &oidc.Tokens[*oidc.IDTokenClaims]{
+ Token: token,
+ }
+ },
+ wantErr: ErrMissingIDToken,
+ },
+ {
+ name: "verify tokens error",
+ oauth2Only: false,
+ tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
+ accesToken, _ := tu.ValidAccessToken()
+ token := &oauth2.Token{
+ AccessToken: accesToken,
+ }
+ token = token.WithExtra(map[string]any{
+ "id_token": "foobar",
+ })
+ return token, nil
+ },
+ wantErr: oidc.ErrParse,
+ },
+ {
+ name: "success, with id_token",
+ oauth2Only: false,
+ tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) {
+ accesToken, _ := tu.ValidAccessToken()
+ token := &oauth2.Token{
+ AccessToken: accesToken,
+ }
+ idToken, claims := tu.ValidIDToken()
+ token = token.WithExtra(map[string]any{
+ "id_token": idToken,
+ })
+ return token, &oidc.Tokens[*oidc.IDTokenClaims]{
+ Token: token,
+ IDTokenClaims: claims,
+ IDToken: idToken,
+ }
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ rp := &relyingParty{
+ oauth2Only: tt.oauth2Only,
+ idTokenVerifier: verifier,
+ }
+ token, want := tt.tokens()
+ got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, want, got)
+ })
+ }
+}
diff --git a/pkg/client/rp/tockenexchange.go b/pkg/client/rp/tockenexchange.go
new file mode 100644
index 0000000..aa2cf99
--- /dev/null
+++ b/pkg/client/rp/tockenexchange.go
@@ -0,0 +1,27 @@
+package rp
+
+import (
+ "context"
+
+ "golang.org/x/oauth2"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc/grants/tokenexchange"
+)
+
+// TokenExchangeRP extends the `RelyingParty` interface for the *draft* oauth2 `Token Exchange`
+type TokenExchangeRP interface {
+ RelyingParty
+
+ // TokenExchange implement the `Token Exchange Grant` exchanging some token for an other
+ TokenExchange(context.Context, *tokenexchange.TokenExchangeRequest) (*oauth2.Token, error)
+}
+
+// DelegationTokenExchangeRP extends the `TokenExchangeRP` interface
+// for the specific `delegation token` request
+type DelegationTokenExchangeRP interface {
+ TokenExchangeRP
+
+ // DelegationTokenExchange implement the `Token Exchange Grant`
+ // providing an access token in request for a `delegation` token for a given resource / audience
+ DelegationTokenExchange(context.Context, string, ...tokenexchange.TokenExchangeOption) (*oauth2.Token, error)
+}
diff --git a/pkg/client/rp/userinfo_example_test.go b/pkg/client/rp/userinfo_example_test.go
new file mode 100644
index 0000000..78e014e
--- /dev/null
+++ b/pkg/client/rp/userinfo_example_test.go
@@ -0,0 +1,45 @@
+package rp_test
+
+import (
+ "context"
+ "fmt"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type UserInfo struct {
+ Subject string `json:"sub,omitempty"`
+ oidc.UserInfoProfile
+ oidc.UserInfoEmail
+ oidc.UserInfoPhone
+ Address *oidc.UserInfoAddress `json:"address,omitempty"`
+
+ // Foo and Bar are custom claims
+ Foo string `json:"foo,omitempty"`
+ Bar struct {
+ Val1 string `json:"val_1,omitempty"`
+ Val2 string `json:"val_2,omitempty"`
+ } `json:"bar,omitempty"`
+
+ // Claims are all the combined claims, including custom.
+ Claims map[string]any `json:"-,omitempty"`
+}
+
+func (u *UserInfo) GetSubject() string {
+ return u.Subject
+}
+
+func ExampleUserinfo_custom() {
+ rpo, err := rp.NewRelyingPartyOIDC(context.TODO(), "http://localhost:8080", "clientid", "clientsecret", "http://example.com/redirect", []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone})
+ if err != nil {
+ panic(err)
+ }
+
+ info, err := rp.Userinfo[*UserInfo](context.TODO(), "accesstokenstring", "Bearer", "userid", rpo)
+ if err != nil {
+ panic(err)
+ }
+
+ fmt.Println(info)
+}
diff --git a/pkg/client/rp/verifier.go b/pkg/client/rp/verifier.go
new file mode 100644
index 0000000..0088b81
--- /dev/null
+++ b/pkg/client/rp/verifier.go
@@ -0,0 +1,174 @@
+package rp
+
+import (
+ "context"
+ "time"
+
+ jose "github.com/go-jose/go-jose/v4"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+// VerifyTokens implement the Token Response Validation as defined in OIDC specification
+// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
+func VerifyTokens[C oidc.IDClaims](ctx context.Context, accessToken, idToken string, v *IDTokenVerifier) (claims C, err error) {
+ ctx, span := client.Tracer.Start(ctx, "VerifyTokens")
+ defer span.End()
+
+ var nilClaims C
+
+ claims, err = VerifyIDToken[C](ctx, idToken, v)
+ if err != nil {
+ return nilClaims, err
+ }
+ if err := VerifyAccessToken(accessToken, claims.GetAccessTokenHash(), claims.GetSignatureAlgorithm()); err != nil {
+ return nilClaims, err
+ }
+ return claims, nil
+}
+
+// VerifyIDToken validates the id token according to
+// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
+func VerifyIDToken[C oidc.Claims](ctx context.Context, token string, v *IDTokenVerifier) (claims C, err error) {
+ ctx, span := client.Tracer.Start(ctx, "VerifyIDToken")
+ defer span.End()
+
+ var nilClaims C
+
+ decrypted, err := oidc.DecryptToken(token)
+ if err != nil {
+ return nilClaims, err
+ }
+ payload, err := oidc.ParseToken(decrypted, &claims)
+ if err != nil {
+ return nilClaims, err
+ }
+
+ if err := oidc.CheckSubject(claims); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckIssuer(claims, v.Issuer); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckAudience(claims, v.ClientID); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckAuthorizedParty(claims, v.ClientID); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
+ return nilClaims, err
+ }
+
+ if v.Nonce != nil {
+ if err = oidc.CheckNonce(claims, v.Nonce(ctx)); err != nil {
+ return nilClaims, err
+ }
+ }
+
+ if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
+ return nilClaims, err
+ }
+ return claims, nil
+}
+
+type IDTokenVerifier oidc.Verifier
+
+// VerifyAccessToken validates the access token according to
+// https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
+func VerifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
+ if atHash == "" {
+ return nil
+ }
+
+ actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
+ if err != nil {
+ return err
+ }
+ if actual != atHash {
+ return oidc.ErrAtHash
+ }
+ return nil
+}
+
+// NewIDTokenVerifier returns a oidc.Verifier suitable for ID token verification.
+func NewIDTokenVerifier(issuer, clientID string, keySet oidc.KeySet, options ...VerifierOption) *IDTokenVerifier {
+ v := &IDTokenVerifier{
+ Issuer: issuer,
+ ClientID: clientID,
+ KeySet: keySet,
+ Offset: time.Second,
+ Nonce: func(_ context.Context) string {
+ return ""
+ },
+ }
+
+ for _, opts := range options {
+ opts(v)
+ }
+
+ return v
+}
+
+// VerifierOption is the type for providing dynamic options to the IDTokenVerifier
+type VerifierOption func(*IDTokenVerifier)
+
+// WithIssuedAtOffset mitigates the risk of iat to be in the future
+// because of clock skews with the ability to add an offset to the current time
+func WithIssuedAtOffset(offset time.Duration) VerifierOption {
+ return func(v *IDTokenVerifier) {
+ v.Offset = offset
+ }
+}
+
+// WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
+func WithIssuedAtMaxAge(maxAge time.Duration) VerifierOption {
+ return func(v *IDTokenVerifier) {
+ v.MaxAgeIAT = maxAge
+ }
+}
+
+// WithNonce sets the function to check the nonce
+func WithNonce(nonce func(context.Context) string) VerifierOption {
+ return func(v *IDTokenVerifier) {
+ v.Nonce = nonce
+ }
+}
+
+// WithACRVerifier sets the verifier for the acr claim
+func WithACRVerifier(verifier oidc.ACRVerifier) VerifierOption {
+ return func(v *IDTokenVerifier) {
+ v.ACR = verifier
+ }
+}
+
+// WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
+func WithAuthTimeMaxAge(maxAge time.Duration) VerifierOption {
+ return func(v *IDTokenVerifier) {
+ v.MaxAge = maxAge
+ }
+}
+
+// WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
+func WithSupportedSigningAlgorithms(algs ...string) VerifierOption {
+ return func(v *IDTokenVerifier) {
+ v.SupportedSignAlgs = algs
+ }
+}
diff --git a/pkg/client/rp/verifier_test.go b/pkg/client/rp/verifier_test.go
new file mode 100644
index 0000000..38f5a4a
--- /dev/null
+++ b/pkg/client/rp/verifier_test.go
@@ -0,0 +1,359 @@
+package rp
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ jose "github.com/go-jose/go-jose/v4"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestVerifyTokens(t *testing.T) {
+ verifier := &IDTokenVerifier{
+ Issuer: tu.ValidIssuer,
+ MaxAgeIAT: 2 * time.Minute,
+ Offset: time.Second,
+ SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
+ KeySet: tu.KeySet{},
+ MaxAge: 2 * time.Minute,
+ ACR: tu.ACRVerify,
+ Nonce: func(context.Context) string { return tu.ValidNonce },
+ ClientID: tu.ValidClientID,
+ }
+ accessToken, _ := tu.ValidAccessToken()
+ atHash, err := oidc.ClaimHash(accessToken, tu.SignatureAlgorithm)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ accessToken string
+ idTokenClaims func() (string, *oidc.IDTokenClaims)
+ wantErr bool
+ }{
+ {
+ name: "without access token",
+ idTokenClaims: tu.ValidIDToken,
+ },
+ {
+ name: "with access token",
+ accessToken: accessToken,
+ idTokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
+ )
+ },
+ },
+ {
+ name: "expired id token",
+ accessToken: accessToken,
+ idTokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, atHash,
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "wrong access token",
+ accessToken: accessToken,
+ idTokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "~~~",
+ )
+ },
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ idToken, want := tt.idTokenClaims()
+ got, err := VerifyTokens[*oidc.IDTokenClaims](context.Background(), tt.accessToken, idToken, verifier)
+ if tt.wantErr {
+ assert.Error(t, err)
+ assert.Nil(t, got)
+ return
+ }
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ assert.Equal(t, got, want)
+ })
+ }
+}
+
+func TestVerifyIDToken(t *testing.T) {
+ verifier := &IDTokenVerifier{
+ Issuer: tu.ValidIssuer,
+ MaxAgeIAT: 2 * time.Minute,
+ Offset: time.Second,
+ SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
+ KeySet: tu.KeySet{},
+ MaxAge: 2 * time.Minute,
+ ACR: tu.ACRVerify,
+ Nonce: func(context.Context) string { return tu.ValidNonce },
+ ClientID: tu.ValidClientID,
+ }
+
+ tests := []struct {
+ name string
+ tokenClaims func() (string, *oidc.IDTokenClaims)
+ customVerifier func(verifier *IDTokenVerifier)
+ wantErr bool
+ }{
+ {
+ name: "success",
+ tokenClaims: tu.ValidIDToken,
+ },
+ {
+ name: "custom claims",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDTokenCustom(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ map[string]any{"some": "thing"},
+ )
+ },
+ },
+ {
+ name: "skip nonce check",
+ customVerifier: func(verifier *IDTokenVerifier) {
+ verifier.Nonce = nil
+ },
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, "foo",
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ },
+ {
+ name: "parse err",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
+ wantErr: true,
+ },
+ {
+ name: "invalid signature",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
+ wantErr: true,
+ },
+ {
+ name: "empty subject",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, "", tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "wrong issuer",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ "foo", tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "wrong clientID",
+ customVerifier: func(verifier *IDTokenVerifier) {
+ verifier.ClientID = "foo"
+ },
+ tokenClaims: tu.ValidIDToken,
+ wantErr: true,
+ },
+ {
+ name: "expired",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "wrong IAT",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "wrong acr",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ "else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "expired auth",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "wrong nonce",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, "foo",
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ token, want := tt.tokenClaims()
+ if tt.customVerifier != nil {
+ tt.customVerifier(verifier)
+ }
+
+ got, err := VerifyIDToken[*oidc.IDTokenClaims](context.Background(), token, verifier)
+ if tt.wantErr {
+ assert.Error(t, err)
+ assert.Nil(t, got)
+ return
+ }
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ assert.Equal(t, got, want)
+ })
+ }
+}
+
+func TestVerifyAccessToken(t *testing.T) {
+ token, _ := tu.ValidAccessToken()
+ hash, err := oidc.ClaimHash(token, tu.SignatureAlgorithm)
+ require.NoError(t, err)
+
+ type args struct {
+ accessToken string
+ atHash string
+ sigAlgorithm jose.SignatureAlgorithm
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ }{
+ {
+ name: "empty hash",
+ },
+ {
+ name: "success",
+ args: args{
+ accessToken: token,
+ atHash: hash,
+ sigAlgorithm: tu.SignatureAlgorithm,
+ },
+ },
+ {
+ name: "invalid algorithm",
+ args: args{
+ accessToken: token,
+ atHash: hash,
+ sigAlgorithm: "foo",
+ },
+ wantErr: true,
+ },
+ {
+ name: "mismatch",
+ args: args{
+ accessToken: token,
+ atHash: "~~",
+ sigAlgorithm: tu.SignatureAlgorithm,
+ },
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := VerifyAccessToken(tt.args.accessToken, tt.args.atHash, tt.args.sigAlgorithm)
+ if tt.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ })
+ }
+}
+
+func TestNewIDTokenVerifier(t *testing.T) {
+ type args struct {
+ issuer string
+ clientID string
+ keySet oidc.KeySet
+ options []VerifierOption
+ }
+ tests := []struct {
+ name string
+ args args
+ want *IDTokenVerifier
+ }{
+ {
+ name: "nil nonce", // otherwise assert.Equal will fail on the function
+ args: args{
+ issuer: tu.ValidIssuer,
+ clientID: tu.ValidClientID,
+ keySet: tu.KeySet{},
+ options: []VerifierOption{
+ WithIssuedAtOffset(time.Minute),
+ WithIssuedAtMaxAge(time.Hour),
+ WithNonce(nil), // otherwise assert.Equal will fail on the function
+ WithACRVerifier(nil),
+ WithAuthTimeMaxAge(2 * time.Hour),
+ WithSupportedSigningAlgorithms("ABC", "DEF"),
+ },
+ },
+ want: &IDTokenVerifier{
+ Issuer: tu.ValidIssuer,
+ Offset: time.Minute,
+ MaxAgeIAT: time.Hour,
+ ClientID: tu.ValidClientID,
+ KeySet: tu.KeySet{},
+ Nonce: nil,
+ ACR: nil,
+ MaxAge: 2 * time.Hour,
+ SupportedSignAlgs: []string{"ABC", "DEF"},
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := NewIDTokenVerifier(tt.args.issuer, tt.args.clientID, tt.args.keySet, tt.args.options...)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
diff --git a/pkg/client/rp/verifier_tokens_example_test.go b/pkg/client/rp/verifier_tokens_example_test.go
new file mode 100644
index 0000000..7ae68d6
--- /dev/null
+++ b/pkg/client/rp/verifier_tokens_example_test.go
@@ -0,0 +1,86 @@
+package rp_test
+
+import (
+ "context"
+ "fmt"
+
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rp"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+// MyCustomClaims extends the TokenClaims base,
+// so it implmeents the oidc.Claims interface.
+// Instead of carrying a map, we add needed fields// to the struct for type safe access.
+type MyCustomClaims struct {
+ oidc.TokenClaims
+ NotBefore oidc.Time `json:"nbf,omitempty"`
+ AccessTokenHash string `json:"at_hash,omitempty"`
+ Foo string `json:"foo,omitempty"`
+ Bar *Nested `json:"bar,omitempty"`
+}
+
+// GetAccessTokenHash is required to implement
+// the oidc.IDClaims interface.
+func (c *MyCustomClaims) GetAccessTokenHash() string {
+ return c.AccessTokenHash
+}
+
+// Nested struct types are also possible.
+type Nested struct {
+ Count int `json:"count,omitempty"`
+ Tags []string `json:"tags,omitempty"`
+}
+
+/*
+idToken carries the following claims. foo and bar are custom claims
+
+ {
+ "acr": "something",
+ "amr": [
+ "foo",
+ "bar"
+ ],
+ "at_hash": "2dzbm_vIxy-7eRtqUIGPPw",
+ "aud": [
+ "unit",
+ "test",
+ "555666"
+ ],
+ "auth_time": 1678100961,
+ "azp": "555666",
+ "bar": {
+ "count": 22,
+ "tags": [
+ "some",
+ "tags"
+ ]
+ },
+ "client_id": "555666",
+ "exp": 4802238682,
+ "foo": "Hello, World!",
+ "iat": 1678101021,
+ "iss": "local.com",
+ "jti": "9876",
+ "nbf": 1678101021,
+ "nonce": "12345",
+ "sub": "tim@local.com"
+ }
+*/
+const idToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhY3IiOiJzb21ldGhpbmciLCJhbXIiOlsiZm9vIiwiYmFyIl0sImF0X2hhc2giOiIyZHpibV92SXh5LTdlUnRxVUlHUFB3IiwiYXVkIjpbInVuaXQiLCJ0ZXN0IiwiNTU1NjY2Il0sImF1dGhfdGltZSI6MTY3ODEwMDk2MSwiYXpwIjoiNTU1NjY2IiwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiY2xpZW50X2lkIjoiNTU1NjY2IiwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJub25jZSI6IjEyMzQ1Iiwic3ViIjoidGltQGxvY2FsLmNvbSJ9.t3GXSfVNNwiW1Suv9_84v0sdn2_-RWHVxhphhRozDXnsO7SDNOlGnEioemXABESxSzMclM7gB7mYy5Qah2ZUNx7eP5t2njoxEYfavgHwx7UJZ2NCg8NDPQyr-hlxelEcfdXK-I0oTd-FRDvF4rqPkD9Us52IpnplChCxnHFgh4wKwPqZZjv2IXVCtn0ilKW3hff1rMOYKEuLRcN2YP0gkyuqyHvcf2dMmjod0t4sLOTJ82rsCbMBC5CLpqv3nIC9HOGITkt1Kd-Am0n1LrdZvWwTo6RFe8AnzF0gpqjcB5Wg4Qeh58DIjZOz4f_8wnmJ_gCqyRh5vfSW4XHdbum0Tw`
+const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM4NjgyLCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MTAxMDIxLCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MTAxMDIxLCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.Zrz3LWSRjCMJZUMaI5dUbW4vGdSmEeJQ3ouhaX0bcW9rdFFLgBI4K2FWJhNivq8JDmCGSxwLu3mI680GWmDaEoAx1M5sCO9lqfIZHGZh-lfAXk27e6FPLlkTDBq8Bx4o4DJ9Fw0hRJGjUTjnYv5cq1vo2-UqldasL6CwTbkzNC_4oQFfRtuodC4Ql7dZ1HRv5LXuYx7KPkOssLZtV9cwtJp5nFzKjcf2zEE_tlbjcpynMwypornRUp1EhCWKRUGkJhJeiP71ECY5pQhShfjBu9Nc5wDpSnZmnk2S4YsPrRK3QkE-iEkas8BfsOCrGoErHjEJexAIDjasGO5PFLWfCA`
+
+func ExampleVerifyTokens_customClaims() {
+ v := rp.NewIDTokenVerifier("local.com", "555666", tu.KeySet{},
+ rp.WithNonce(func(ctx context.Context) string { return "12345" }),
+ )
+
+ // VerifyAccessToken can be called with the *MyCustomClaims.
+ claims, err := rp.VerifyTokens[*MyCustomClaims](context.TODO(), accessToken, idToken, v)
+ if err != nil {
+ panic(err)
+ }
+ // Here we have typesafe access to the custom claims
+ fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
+ // Output: Hello, World! 22 [some tags]
+}
diff --git a/pkg/client/rs/introspect_example_test.go b/pkg/client/rs/introspect_example_test.go
new file mode 100644
index 0000000..1f67d11
--- /dev/null
+++ b/pkg/client/rs/introspect_example_test.go
@@ -0,0 +1,52 @@
+package rs_test
+
+import (
+ "context"
+ "fmt"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client/rs"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type IntrospectionResponse struct {
+ Active bool `json:"active"`
+ Scope oidc.SpaceDelimitedArray `json:"scope,omitempty"`
+ ClientID string `json:"client_id,omitempty"`
+ TokenType string `json:"token_type,omitempty"`
+ Expiration oidc.Time `json:"exp,omitempty"`
+ IssuedAt oidc.Time `json:"iat,omitempty"`
+ NotBefore oidc.Time `json:"nbf,omitempty"`
+ Subject string `json:"sub,omitempty"`
+ Audience oidc.Audience `json:"aud,omitempty"`
+ Issuer string `json:"iss,omitempty"`
+ JWTID string `json:"jti,omitempty"`
+ Username string `json:"username,omitempty"`
+ oidc.UserInfoProfile
+ oidc.UserInfoEmail
+ oidc.UserInfoPhone
+ Address *oidc.UserInfoAddress `json:"address,omitempty"`
+
+ // Foo and Bar are custom claims
+ Foo string `json:"foo,omitempty"`
+ Bar struct {
+ Val1 string `json:"val_1,omitempty"`
+ Val2 string `json:"val_2,omitempty"`
+ } `json:"bar,omitempty"`
+
+ // Claims are all the combined claims, including custom.
+ Claims map[string]any `json:"-,omitempty"`
+}
+
+func ExampleIntrospect_custom() {
+ rss, err := rs.NewResourceServerClientCredentials(context.TODO(), "http://localhost:8080", "clientid", "clientsecret")
+ if err != nil {
+ panic(err)
+ }
+
+ resp, err := rs.Introspect[*IntrospectionResponse](context.TODO(), rss, "accesstokenstring")
+ if err != nil {
+ panic(err)
+ }
+
+ fmt.Println(resp)
+}
diff --git a/pkg/client/rs/resource_server.go b/pkg/client/rs/resource_server.go
new file mode 100644
index 0000000..993796e
--- /dev/null
+++ b/pkg/client/rs/resource_server.go
@@ -0,0 +1,145 @@
+package rs
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "time"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type ResourceServer interface {
+ IntrospectionURL() string
+ TokenEndpoint() string
+ HttpClient() *http.Client
+ AuthFn() (any, error)
+}
+
+type resourceServer struct {
+ issuer string
+ tokenURL string
+ introspectURL string
+ httpClient *http.Client
+ authFn func() (any, error)
+}
+
+func (r *resourceServer) IntrospectionURL() string {
+ return r.introspectURL
+}
+
+func (r *resourceServer) TokenEndpoint() string {
+ return r.tokenURL
+}
+
+func (r *resourceServer) HttpClient() *http.Client {
+ return r.httpClient
+}
+
+func (r *resourceServer) AuthFn() (any, error) {
+ return r.authFn()
+}
+
+func NewResourceServerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, option ...Option) (ResourceServer, error) {
+ authorizer := func() (any, error) {
+ return httphelper.AuthorizeBasic(clientID, clientSecret), nil
+ }
+ return newResourceServer(ctx, issuer, authorizer, option...)
+}
+
+func NewResourceServerJWTProfile(ctx context.Context, issuer, clientID, keyID string, key []byte, options ...Option) (ResourceServer, error) {
+ signer, err := client.NewSignerFromPrivateKeyByte(key, keyID)
+ if err != nil {
+ return nil, err
+ }
+ authorizer := func() (any, error) {
+ assertion, err := client.SignedJWTProfileAssertion(clientID, []string{issuer}, time.Hour, signer)
+ if err != nil {
+ return nil, err
+ }
+ return client.ClientAssertionFormAuthorization(assertion), nil
+ }
+ return newResourceServer(ctx, issuer, authorizer, options...)
+}
+
+func newResourceServer(ctx context.Context, issuer string, authorizer func() (any, error), options ...Option) (*resourceServer, error) {
+ rs := &resourceServer{
+ issuer: issuer,
+ httpClient: httphelper.DefaultHTTPClient,
+ }
+ for _, optFunc := range options {
+ optFunc(rs)
+ }
+ if rs.introspectURL == "" || rs.tokenURL == "" {
+ config, err := client.Discover(ctx, rs.issuer, rs.httpClient)
+ if err != nil {
+ return nil, err
+ }
+ if rs.tokenURL == "" {
+ rs.tokenURL = config.TokenEndpoint
+ }
+ if rs.introspectURL == "" {
+ rs.introspectURL = config.IntrospectionEndpoint
+ }
+ }
+ if rs.tokenURL == "" {
+ return nil, errors.New("tokenURL is empty: please provide with either `WithStaticEndpoints` or a discovery url")
+ }
+ rs.authFn = authorizer
+ return rs, nil
+}
+
+func NewResourceServerFromKeyFile(ctx context.Context, issuer, path string, options ...Option) (ResourceServer, error) {
+ c, err := client.ConfigFromKeyFile(path)
+ if err != nil {
+ return nil, err
+ }
+ return NewResourceServerJWTProfile(ctx, issuer, c.ClientID, c.KeyID, []byte(c.Key), options...)
+}
+
+type Option func(*resourceServer)
+
+// WithClient provides the ability to set an http client to be used for the resource server
+func WithClient(client *http.Client) Option {
+ return func(server *resourceServer) {
+ server.httpClient = client
+ }
+}
+
+// WithStaticEndpoints provides the ability to set static token and introspect URL
+func WithStaticEndpoints(tokenURL, introspectURL string) Option {
+ return func(server *resourceServer) {
+ server.tokenURL = tokenURL
+ server.introspectURL = introspectURL
+ }
+}
+
+// Introspect calls the [RFC7662] Token Introspection
+// endpoint and returns the response in an instance of type R.
+// [*oidc.IntrospectionResponse] can be used as a good example, or use a custom type if type-safe
+// access to custom claims is needed.
+//
+// [RFC7662]: https://www.rfc-editor.org/rfc/rfc7662
+func Introspect[R any](ctx context.Context, rp ResourceServer, token string) (resp R, err error) {
+ ctx, span := client.Tracer.Start(ctx, "Introspect")
+ defer span.End()
+
+ if rp.IntrospectionURL() == "" {
+ return resp, errors.New("resource server: introspection URL is empty")
+ }
+ authFn, err := rp.AuthFn()
+ if err != nil {
+ return resp, err
+ }
+ req, err := httphelper.FormRequest(ctx, rp.IntrospectionURL(), &oidc.IntrospectionRequest{Token: token}, client.Encoder, authFn)
+ if err != nil {
+ return resp, err
+ }
+
+ if err := httphelper.HttpRequest(rp.HttpClient(), req, &resp); err != nil {
+ return resp, err
+ }
+ return resp, nil
+}
diff --git a/pkg/client/rs/resource_server_test.go b/pkg/client/rs/resource_server_test.go
new file mode 100644
index 0000000..afd7441
--- /dev/null
+++ b/pkg/client/rs/resource_server_test.go
@@ -0,0 +1,221 @@
+package rs
+
+import (
+ "context"
+ "testing"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewResourceServer(t *testing.T) {
+ type args struct {
+ issuer string
+ authorizer func() (any, error)
+ options []Option
+ }
+ type wantFields struct {
+ issuer string
+ tokenURL string
+ introspectURL string
+ authFn func() (any, error)
+ }
+ tests := []struct {
+ name string
+ args args
+ wantFields *wantFields
+ wantErr bool
+ }{
+ {
+ name: "spotify-full-discovery",
+ args: args{
+ issuer: "https://accounts.spotify.com",
+ authorizer: nil,
+ options: []Option{},
+ },
+ wantFields: &wantFields{
+ issuer: "https://accounts.spotify.com",
+ tokenURL: "https://accounts.spotify.com/api/token",
+ introspectURL: "",
+ authFn: nil,
+ },
+ wantErr: false,
+ },
+ {
+ name: "spotify-with-static-tokenurl",
+ args: args{
+ issuer: "https://accounts.spotify.com",
+ authorizer: nil,
+ options: []Option{
+ WithStaticEndpoints(
+ "https://some.host/token-url",
+ "",
+ ),
+ },
+ },
+ wantFields: &wantFields{
+ issuer: "https://accounts.spotify.com",
+ tokenURL: "https://some.host/token-url",
+ introspectURL: "",
+ authFn: nil,
+ },
+ wantErr: false,
+ },
+ {
+ name: "spotify-with-static-introspecturl",
+ args: args{
+ issuer: "https://accounts.spotify.com",
+ authorizer: nil,
+ options: []Option{
+ WithStaticEndpoints(
+ "",
+ "https://some.host/instrospect-url",
+ ),
+ },
+ },
+ wantFields: &wantFields{
+ issuer: "https://accounts.spotify.com",
+ tokenURL: "https://accounts.spotify.com/api/token",
+ introspectURL: "https://some.host/instrospect-url",
+ authFn: nil,
+ },
+ wantErr: false,
+ },
+ {
+ name: "spotify-with-all-static-endpoints",
+ args: args{
+ issuer: "https://accounts.spotify.com",
+ authorizer: nil,
+ options: []Option{
+ WithStaticEndpoints(
+ "https://some.host/token-url",
+ "https://some.host/instrospect-url",
+ ),
+ },
+ },
+ wantFields: &wantFields{
+ issuer: "https://accounts.spotify.com",
+ tokenURL: "https://some.host/token-url",
+ introspectURL: "https://some.host/instrospect-url",
+ authFn: nil,
+ },
+ wantErr: false,
+ },
+ {
+ name: "bad-discovery",
+ args: args{
+ issuer: "https://127.0.0.1:65535",
+ authorizer: nil,
+ options: []Option{},
+ },
+ wantFields: nil,
+ wantErr: true,
+ },
+ {
+ name: "bad-discovery-with-static-tokenurl",
+ args: args{
+ issuer: "https://127.0.0.1:65535",
+ authorizer: nil,
+ options: []Option{
+ WithStaticEndpoints(
+ "https://some.host/token-url",
+ "",
+ ),
+ },
+ },
+ wantFields: nil,
+ wantErr: true,
+ },
+ {
+ name: "bad-discovery-with-static-introspecturl",
+ args: args{
+ issuer: "https://127.0.0.1:65535",
+ authorizer: nil,
+ options: []Option{
+ WithStaticEndpoints(
+ "",
+ "https://some.host/instrospect-url",
+ ),
+ },
+ },
+ wantFields: nil,
+ wantErr: true,
+ },
+ {
+ name: "bad-discovery-with-all-static-endpoints",
+ args: args{
+ issuer: "https://127.0.0.1:65535",
+ authorizer: nil,
+ options: []Option{
+ WithStaticEndpoints(
+ "https://some.host/token-url",
+ "https://some.host/instrospect-url",
+ ),
+ },
+ },
+ wantFields: &wantFields{
+ issuer: "https://127.0.0.1:65535",
+ tokenURL: "https://some.host/token-url",
+ introspectURL: "https://some.host/instrospect-url",
+ authFn: nil,
+ },
+ wantErr: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := newResourceServer(context.Background(), tt.args.issuer, tt.args.authorizer, tt.args.options...)
+ if tt.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ if tt.wantFields == nil {
+ return
+ }
+ assert.Equal(t, tt.wantFields.issuer, got.issuer)
+ assert.Equal(t, tt.wantFields.tokenURL, got.tokenURL)
+ assert.Equal(t, tt.wantFields.introspectURL, got.introspectURL)
+ })
+ }
+}
+
+func TestIntrospect(t *testing.T) {
+ type args struct {
+ ctx context.Context
+ rp ResourceServer
+ token string
+ }
+ rp, err := newResourceServer(
+ context.Background(),
+ "https://accounts.spotify.com",
+ nil,
+ )
+ require.NoError(t, err)
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ }{
+ {
+ name: "missing-introspect-url",
+ args: args{
+ ctx: context.Background(),
+ rp: rp,
+ token: "my-token",
+ },
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := Introspect[*oidc.IntrospectionResponse](tt.args.ctx, tt.args.rp, tt.args.token)
+ if tt.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ })
+ }
+}
diff --git a/pkg/client/tokenexchange/tokenexchange.go b/pkg/client/tokenexchange/tokenexchange.go
new file mode 100644
index 0000000..9cc1328
--- /dev/null
+++ b/pkg/client/tokenexchange/tokenexchange.go
@@ -0,0 +1,145 @@
+package tokenexchange
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "time"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/go-jose/go-jose/v4"
+)
+
+type TokenExchanger interface {
+ TokenEndpoint() string
+ HttpClient() *http.Client
+ AuthFn() (any, error)
+}
+
+type OAuthTokenExchange struct {
+ httpClient *http.Client
+ tokenEndpoint string
+ authFn func() (any, error)
+}
+
+func NewTokenExchanger(ctx context.Context, issuer string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
+ return newOAuthTokenExchange(ctx, issuer, nil, options...)
+}
+
+func NewTokenExchangerClientCredentials(ctx context.Context, issuer, clientID, clientSecret string, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
+ authorizer := func() (any, error) {
+ return httphelper.AuthorizeBasic(clientID, clientSecret), nil
+ }
+ return newOAuthTokenExchange(ctx, issuer, authorizer, options...)
+}
+
+func NewTokenExchangerJWTProfile(ctx context.Context, issuer, clientID string, signer jose.Signer, options ...func(source *OAuthTokenExchange)) (TokenExchanger, error) {
+ authorizer := func() (any, error) {
+ assertion, err := client.SignedJWTProfileAssertion(clientID, []string{issuer}, time.Hour, signer)
+ if err != nil {
+ return nil, err
+ }
+ return client.ClientAssertionFormAuthorization(assertion), nil
+ }
+ return newOAuthTokenExchange(ctx, issuer, authorizer, options...)
+}
+
+func newOAuthTokenExchange(ctx context.Context, issuer string, authorizer func() (any, error), options ...func(source *OAuthTokenExchange)) (*OAuthTokenExchange, error) {
+ te := &OAuthTokenExchange{
+ httpClient: httphelper.DefaultHTTPClient,
+ }
+ for _, opt := range options {
+ opt(te)
+ }
+
+ if te.tokenEndpoint == "" {
+ config, err := client.Discover(ctx, issuer, te.httpClient)
+ if err != nil {
+ return nil, err
+ }
+
+ te.tokenEndpoint = config.TokenEndpoint
+ }
+
+ if te.tokenEndpoint == "" {
+ return nil, errors.New("tokenURL is empty: please provide with either `WithStaticTokenEndpoint` or a discovery url")
+ }
+
+ te.authFn = authorizer
+
+ return te, nil
+}
+
+func WithHTTPClient(client *http.Client) func(*OAuthTokenExchange) {
+ return func(source *OAuthTokenExchange) {
+ source.httpClient = client
+ }
+}
+
+func WithStaticTokenEndpoint(issuer, tokenEndpoint string) func(*OAuthTokenExchange) {
+ return func(source *OAuthTokenExchange) {
+ source.tokenEndpoint = tokenEndpoint
+ }
+}
+
+func (te *OAuthTokenExchange) TokenEndpoint() string {
+ return te.tokenEndpoint
+}
+
+func (te *OAuthTokenExchange) HttpClient() *http.Client {
+ return te.httpClient
+}
+
+func (te *OAuthTokenExchange) AuthFn() (any, error) {
+ if te.authFn != nil {
+ return te.authFn()
+ }
+
+ return nil, nil
+}
+
+// ExchangeToken sends a token exchange request (rfc 8693) to te's token endpoint.
+// SubjectToken and SubjectTokenType are required parameters.
+func ExchangeToken(
+ ctx context.Context,
+ te TokenExchanger,
+ SubjectToken string,
+ SubjectTokenType oidc.TokenType,
+ ActorToken string,
+ ActorTokenType oidc.TokenType,
+ Resource []string,
+ Audience []string,
+ Scopes []string,
+ RequestedTokenType oidc.TokenType,
+) (*oidc.TokenExchangeResponse, error) {
+ ctx, span := client.Tracer.Start(ctx, "ExchangeToken")
+ defer span.End()
+
+ if SubjectToken == "" {
+ return nil, errors.New("empty subject_token")
+ }
+ if SubjectTokenType == "" {
+ return nil, errors.New("empty subject_token_type")
+ }
+
+ authFn, err := te.AuthFn()
+ if err != nil {
+ return nil, err
+ }
+
+ request := oidc.TokenExchangeRequest{
+ GrantType: oidc.GrantTypeTokenExchange,
+ SubjectToken: SubjectToken,
+ SubjectTokenType: SubjectTokenType,
+ ActorToken: ActorToken,
+ ActorTokenType: ActorTokenType,
+ Resource: Resource,
+ Audience: Audience,
+ Scopes: Scopes,
+ RequestedTokenType: RequestedTokenType,
+ }
+
+ return client.CallTokenExchangeEndpoint(ctx, request, authFn, te)
+}
diff --git a/pkg/utils/crypto.go b/pkg/crypto/crypto.go
similarity index 82%
rename from pkg/utils/crypto.go
rename to pkg/crypto/crypto.go
index 05acb75..109fa0b 100644
--- a/pkg/utils/crypto.go
+++ b/pkg/crypto/crypto.go
@@ -1,4 +1,4 @@
-package utils
+package crypto
import (
"crypto/aes"
@@ -9,17 +9,18 @@ import (
"io"
)
+var ErrCipherTextBlockSize = errors.New("ciphertext block size is too short")
+
func EncryptAES(data string, key string) (string, error) {
encrypted, err := EncryptBytesAES([]byte(data), key)
if err != nil {
return "", err
}
- return base64.URLEncoding.EncodeToString(encrypted), nil
+ return base64.RawURLEncoding.EncodeToString(encrypted), nil
}
func EncryptBytesAES(plainText []byte, key string) ([]byte, error) {
-
block, err := aes.NewCipher([]byte(key))
if err != nil {
return nil, err
@@ -38,9 +39,9 @@ func EncryptBytesAES(plainText []byte, key string) ([]byte, error) {
}
func DecryptAES(data string, key string) (string, error) {
- text, err := base64.URLEncoding.DecodeString(data)
+ text, err := base64.RawURLEncoding.DecodeString(data)
if err != nil {
- return "", nil
+ return "", err
}
decrypted, err := DecryptBytesAES(text, key)
if err != nil {
@@ -50,15 +51,13 @@ func DecryptAES(data string, key string) (string, error) {
}
func DecryptBytesAES(cipherText []byte, key string) ([]byte, error) {
-
block, err := aes.NewCipher([]byte(key))
if err != nil {
return nil, err
}
if len(cipherText) < aes.BlockSize {
- err = errors.New("Ciphertext block size is too short!")
- return nil, err
+ return nil, ErrCipherTextBlockSize
}
iv := cipherText[:aes.BlockSize]
cipherText = cipherText[aes.BlockSize:]
diff --git a/pkg/crypto/hash.go b/pkg/crypto/hash.go
new file mode 100644
index 0000000..14acdee
--- /dev/null
+++ b/pkg/crypto/hash.go
@@ -0,0 +1,49 @@
+package crypto
+
+import (
+ "crypto/sha256"
+ "crypto/sha512"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "hash"
+
+ jose "github.com/go-jose/go-jose/v4"
+)
+
+var ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm")
+
+func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
+ switch sigAlgorithm {
+ case jose.RS256, jose.ES256, jose.PS256:
+ return sha256.New(), nil
+ case jose.RS384, jose.ES384, jose.PS384:
+ return sha512.New384(), nil
+ case jose.RS512, jose.ES512, jose.PS512:
+ return sha512.New(), nil
+
+ // There is no published spec for this yet, but we have confirmation it will get published.
+ // There is consensus here: https://bitbucket.org/openid/connect/issues/1125/_hash-algorithm-for-eddsa-id-tokens
+ // Currently Go and go-jose only supports the ed25519 curve key for EdDSA, so we can safely assume sha512 here.
+ // It is unlikely ed448 will ever be supported: https://github.com/golang/go/issues/29390
+ case jose.EdDSA:
+ return sha512.New(), nil
+
+ default:
+ return nil, fmt.Errorf("%w: %q", ErrUnsupportedAlgorithm, sigAlgorithm)
+ }
+}
+
+func HashString(hash hash.Hash, s string, firstHalf bool) string {
+ if hash == nil {
+ return s
+ }
+ //nolint:errcheck
+ hash.Write([]byte(s))
+ size := hash.Size()
+ if firstHalf {
+ size = size / 2
+ }
+ sum := hash.Sum(nil)[:size]
+ return base64.RawURLEncoding.EncodeToString(sum)
+}
diff --git a/pkg/crypto/key.go b/pkg/crypto/key.go
new file mode 100644
index 0000000..12bca28
--- /dev/null
+++ b/pkg/crypto/key.go
@@ -0,0 +1,45 @@
+package crypto
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/ed25519"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+
+ "github.com/go-jose/go-jose/v4"
+)
+
+var (
+ ErrPEMDecode = errors.New("PEM decode failed")
+ ErrUnsupportedFormat = errors.New("key is neither in PKCS#1 nor PKCS#8 format")
+ ErrUnsupportedPrivateKey = errors.New("unsupported key type, must be RSA, ECDSA or ED25519 private key")
+)
+
+func BytesToPrivateKey(b []byte) (crypto.PublicKey, jose.SignatureAlgorithm, error) {
+ block, _ := pem.Decode(b)
+ if block == nil {
+ return nil, "", ErrPEMDecode
+ }
+
+ privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ if err == nil {
+ return privateKey, jose.RS256, nil
+ }
+ key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
+ if err != nil {
+ return nil, "", ErrUnsupportedFormat
+ }
+ switch privateKey := key.(type) {
+ case *rsa.PrivateKey:
+ return privateKey, jose.RS256, nil
+ case ed25519.PrivateKey:
+ return privateKey, jose.EdDSA, nil
+ case *ecdsa.PrivateKey:
+ return privateKey, jose.ES256, nil
+ default:
+ return nil, "", ErrUnsupportedPrivateKey
+ }
+}
diff --git a/pkg/crypto/key_test.go b/pkg/crypto/key_test.go
new file mode 100644
index 0000000..a6fa493
--- /dev/null
+++ b/pkg/crypto/key_test.go
@@ -0,0 +1,134 @@
+package crypto_test
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/ed25519"
+ "crypto/rsa"
+ "testing"
+
+ "github.com/go-jose/go-jose/v4"
+ "github.com/stretchr/testify/assert"
+
+ zcrypto "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto"
+)
+
+func TestBytesToPrivateKey(t *testing.T) {
+ type args struct {
+ key []byte
+ }
+ type want struct {
+ key crypto.Signer
+ algorithm jose.SignatureAlgorithm
+ err error
+ }
+ tests := []struct {
+ name string
+ args args
+ want want
+ }{
+ {
+ name: "PEMDecodeError",
+ args: args{
+ key: []byte("The non-PEM sequence"),
+ },
+ want: want{
+ err: zcrypto.ErrPEMDecode,
+ },
+ },
+ {
+ name: "PKCS#1 RSA",
+ args: args{
+ key: []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu
+KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm
+o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k
+TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7
+9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy
+v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs
+/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00
+-----END RSA PRIVATE KEY-----`),
+ },
+ want: want{
+ key: &rsa.PrivateKey{},
+ algorithm: jose.RS256,
+ err: nil,
+ },
+ },
+ {
+ name: "PKCS#8 RSA",
+ args: args{
+ key: []byte(`-----BEGIN PRIVATE KEY-----
+MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCfaDB7pK/fmP/I
+7IusSK8lTCBnPZghqIbVLt2QHYAMoEF1CaF4F4rxo2vl1Mt8gwsq4T3osQFZMvnL
+YHb7KNyUoJgTjLxJQADv2u4Q3U38heAzK5Tp4ry4MCnuyJIqAPK1GiruwEq4zQrx
++WzVix8otO37SuW9tzklqlNGMiAYBL0TBKHvS5XMbjP1idBMB8erMz29w/TVQnEB
+Kj0vCdZjrbVPKygptt5kcSrL5f4xCZwU+ufz7cp0GLwpRMJ+shG9YJJFBxb0itPF
+sy51vAyEtdBC7jgAU96ZVeQ06nryDq1D2EpoVMElqNyL46Jo3lnKbGquGKzXzQYU
+BN32/scDAgMBAAECggEBAJE/mo3PLgILo2YtQ8ekIxNVHmF0Gl7w9IrjvTdH6hmX
+HI3MTLjkmtI7GmG9V/0IWvCjdInGX3grnrjWGRQZ04QKIQgPQLFuBGyJjEsJm7nx
+MqztlS7YTyV1nX/aenSTkJO8WEpcJLnm+4YoxCaAMdAhrIdBY71OamALpv1bRysa
+FaiCGcemT2yqZn0GqIS8O26Tz5zIqrTN2G1eSmgh7DG+7FoddMz35cute8R10xUG
+hF5YU+6fcXiRQ/Kh7nlxelPGqdZFPMk7LpVHzkQKwdJ+N0P23lPDIfNsvpG1n0OP
+3g5km7gHSrSU2yZ3eFl6DB9x1IFNS9BaQQuSxYJtKwECgYEA1C8jjzpXZDLvlYsV
+2jlMzkrbsIrX2dzblVrNsPs2jRbjYU8mg2DUDO6lOhtxHfqZG6sO+gmWi/zvoy9l
+yolGbXe1Jqx66p9fznIcecSwar8+ACa356Wk74Nt1PlBOfCMqaJnYLOLaFJa29Vy
+u5ClZVzKd5AVXl7yFVd4XfLv/WECgYEAwFMMtFoasdF92c0d31rZ1uoPOtFz6xq6
+uQggdm5zzkhnfwUAGqppS/u1CHcJ7T/74++jLbFTsaohGr4jEzWSGvJpomEUChy3
+r25YofMclUhJ5pCEStsLtqiCR1Am6LlI8HMdBEP1QDgEC5q8bQW4+UHuew1E1zxz
+osZOhe09WuMCgYEA0G9aFCnwjUqIFjQiDFP7gi8BLqTFs4uE3Wvs4W11whV42i+B
+ms90nxuTjchFT3jMDOT1+mOO0wdudLRr3xEI8SIF/u6ydGaJG+j21huEXehtxIJE
+aDdNFcfbDbqo+3y1ATK7MMBPMvSrsoY0hdJq127WqasNgr3sO1DIuima3SECgYEA
+nkM5TyhekzlbIOHD1UsDu/D7+2DkzPE/+oePfyXBMl0unb3VqhvVbmuBO6gJiSx/
+8b//PdiQkMD5YPJaFrKcuoQFHVRZk0CyfzCEyzAts0K7XXpLAvZiGztriZeRjSz7
+srJnjF0H8oKmAY6hw+1Tm/n/b08p+RyL48TgVSE2vhUCgYA3BWpkD4PlCcn/FZsq
+OrLFyFXI6jIaxskFtsRW1IxxIlAdZmxfB26P/2gx6VjLdxJI/RRPkJyEN2dP7CbR
+BDjb565dy1O9D6+UrY70Iuwjz+OcALRBBGTaiF2pLn6IhSzNI2sy/tXX8q8dBlg9
+OFCrqT/emes3KytTPfa5NZtYeQ==
+-----END PRIVATE KEY-----`),
+ },
+ want: want{
+ key: &rsa.PrivateKey{},
+ algorithm: jose.RS256,
+ err: nil,
+ },
+ },
+ {
+ name: "PKCS#8 ECDSA",
+ args: args{
+ key: []byte(`-----BEGIN PRIVATE KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgwwOZSU4GlP7ps/Wp
+V6o0qRwxultdfYo/uUuj48QZjSuhRANCAATMiI2Han+ABKmrk5CNlxRAGC61w4d3
+G4TAeuBpyzqJ7x/6NjCxoQzJzZHtNjIfjVATI59XFZWF59GhtSZbShAr
+-----END PRIVATE KEY-----`),
+ },
+ want: want{
+ key: &ecdsa.PrivateKey{},
+ algorithm: jose.ES256,
+ err: nil,
+ },
+ },
+ {
+ name: "PKCS#8 ED25519",
+ args: args{
+ key: []byte(`-----BEGIN PRIVATE KEY-----
+MC4CAQAwBQYDK2VwBCIEIHu6ZtDsjjauMasBxnS9Fg87UJwKfcT/oiq6S0ktbky8
+-----END PRIVATE KEY-----`),
+ },
+ want: want{
+ key: ed25519.PrivateKey{},
+ algorithm: jose.EdDSA,
+ err: nil,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ key, algorithm, err := zcrypto.BytesToPrivateKey(tt.args.key)
+ assert.IsType(t, tt.want.key, key)
+ assert.Equal(t, tt.want.algorithm, algorithm)
+ assert.ErrorIs(t, tt.want.err, err)
+ })
+
+ }
+}
diff --git a/pkg/crypto/sign.go b/pkg/crypto/sign.go
new file mode 100644
index 0000000..937a846
--- /dev/null
+++ b/pkg/crypto/sign.go
@@ -0,0 +1,27 @@
+package crypto
+
+import (
+ "encoding/json"
+ "errors"
+
+ jose "github.com/go-jose/go-jose/v4"
+)
+
+func Sign(object any, signer jose.Signer) (string, error) {
+ payload, err := json.Marshal(object)
+ if err != nil {
+ return "", err
+ }
+ return SignPayload(payload, signer)
+}
+
+func SignPayload(payload []byte, signer jose.Signer) (string, error) {
+ if signer == nil {
+ return "", errors.New("missing signer")
+ }
+ result, err := signer.Sign(payload)
+ if err != nil {
+ return "", err
+ }
+ return result.CompactSerialize()
+}
diff --git a/pkg/utils/cookie.go b/pkg/http/cookie.go
similarity index 91%
rename from pkg/utils/cookie.go
rename to pkg/http/cookie.go
index 9e73e08..1ebc9e2 100644
--- a/pkg/utils/cookie.go
+++ b/pkg/http/cookie.go
@@ -1,4 +1,4 @@
-package utils
+package http
import (
"errors"
@@ -13,6 +13,7 @@ type CookieHandler struct {
sameSite http.SameSite
maxAge int
domain string
+ path string
}
func NewCookieHandler(hashKey, encryptKey []byte, opts ...CookieHandlerOpt) *CookieHandler {
@@ -20,6 +21,7 @@ func NewCookieHandler(hashKey, encryptKey []byte, opts ...CookieHandlerOpt) *Coo
securecookie: securecookie.New(hashKey, encryptKey),
secureOnly: true,
sameSite: http.SameSiteLaxMode,
+ path: "/",
}
for _, opt := range opts {
@@ -55,6 +57,12 @@ func WithDomain(domain string) CookieHandlerOpt {
}
}
+func WithPath(path string) CookieHandlerOpt {
+ return func(c *CookieHandler) {
+ c.path = path
+ }
+}
+
func (c *CookieHandler) CheckCookie(r *http.Request, name string) (string, error) {
cookie, err := r.Cookie(name)
if err != nil {
@@ -87,7 +95,7 @@ func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) err
Name: name,
Value: encoded,
Domain: c.domain,
- Path: "/",
+ Path: c.path,
MaxAge: c.maxAge,
HttpOnly: true,
Secure: c.secureOnly,
@@ -101,7 +109,7 @@ func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
Name: name,
Value: "",
Domain: c.domain,
- Path: "/",
+ Path: c.path,
MaxAge: -1,
HttpOnly: true,
Secure: c.secureOnly,
diff --git a/pkg/http/http.go b/pkg/http/http.go
new file mode 100644
index 0000000..aa0ff6f
--- /dev/null
+++ b/pkg/http/http.go
@@ -0,0 +1,112 @@
+package http
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+var DefaultHTTPClient = &http.Client{
+ Timeout: 30 * time.Second,
+}
+
+type Decoder interface {
+ Decode(dst any, src map[string][]string) error
+}
+
+type Encoder interface {
+ Encode(src any, dst map[string][]string) error
+}
+
+type FormAuthorization func(url.Values)
+type RequestAuthorization func(*http.Request)
+
+func AuthorizeBasic(user, password string) RequestAuthorization {
+ return func(req *http.Request) {
+ req.SetBasicAuth(url.QueryEscape(user), url.QueryEscape(password))
+ }
+}
+
+func FormRequest(ctx context.Context, endpoint string, request any, encoder Encoder, authFn any) (*http.Request, error) {
+ form := url.Values{}
+ if err := encoder.Encode(request, form); err != nil {
+ return nil, err
+ }
+ if fn, ok := authFn.(FormAuthorization); ok {
+ fn(form)
+ }
+ body := strings.NewReader(form.Encode())
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
+ if err != nil {
+ return nil, err
+ }
+ if fn, ok := authFn.(RequestAuthorization); ok {
+ fn(req)
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ return req, nil
+}
+
+func HttpRequest(client *http.Client, req *http.Request, response any) error {
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("unable to read response body: %v", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ var oidcErr oidc.Error
+ err = json.Unmarshal(body, &oidcErr)
+ if err != nil || oidcErr.ErrorType == "" {
+ return fmt.Errorf("http status not ok: %s %s", resp.Status, body)
+ }
+ return &oidcErr
+ }
+
+ err = json.Unmarshal(body, response)
+ if err != nil {
+ return fmt.Errorf("failed to unmarshal response: %v %s", err, body)
+ }
+ return nil
+}
+
+func URLEncodeParams(resp any, encoder Encoder) (url.Values, error) {
+ values := make(map[string][]string)
+ err := encoder.Encode(resp, values)
+ if err != nil {
+ return nil, err
+ }
+ return values, nil
+}
+
+func StartServer(ctx context.Context, port string) {
+ server := &http.Server{Addr: port}
+ go func() {
+ if err := server.ListenAndServe(); err != http.ErrServerClosed {
+ log.Fatalf("ListenAndServe(): %v", err)
+ }
+ }()
+
+ go func() {
+ <-ctx.Done()
+ ctxShutdown, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancelShutdown()
+ err := server.Shutdown(ctxShutdown)
+ if err != nil {
+ log.Fatalf("Shutdown(): %v", err)
+ }
+ }()
+}
diff --git a/pkg/http/marshal.go b/pkg/http/marshal.go
new file mode 100644
index 0000000..71ed2c2
--- /dev/null
+++ b/pkg/http/marshal.go
@@ -0,0 +1,45 @@
+package http
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "reflect"
+)
+
+func MarshalJSON(w http.ResponseWriter, i any) {
+ MarshalJSONWithStatus(w, i, http.StatusOK)
+}
+
+func MarshalJSONWithStatus(w http.ResponseWriter, i any, status int) {
+ w.Header().Set("content-type", "application/json")
+ w.WriteHeader(status)
+ if i == nil || (reflect.ValueOf(i).Kind() == reflect.Ptr && reflect.ValueOf(i).IsNil()) {
+ return
+ }
+ err := json.NewEncoder(w).Encode(i)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+}
+
+func ConcatenateJSON(first, second []byte) ([]byte, error) {
+ if !bytes.HasSuffix(first, []byte{'}'}) {
+ return nil, fmt.Errorf("jws: invalid JSON %s", first)
+ }
+ if !bytes.HasPrefix(second, []byte{'{'}) {
+ return nil, fmt.Errorf("jws: invalid JSON %s", second)
+ }
+ // check empty
+ if len(first) == 2 {
+ return second, nil
+ }
+ if len(second) == 2 {
+ return first, nil
+ }
+
+ first[len(first)-1] = ','
+ first = append(first, second[1:]...)
+ return first, nil
+}
diff --git a/pkg/http/marshal_test.go b/pkg/http/marshal_test.go
new file mode 100644
index 0000000..dcc7fdd
--- /dev/null
+++ b/pkg/http/marshal_test.go
@@ -0,0 +1,156 @@
+package http
+
+import (
+ "bytes"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestConcatenateJSON(t *testing.T) {
+ type args struct {
+ first []byte
+ second []byte
+ }
+ tests := []struct {
+ name string
+ args args
+ want []byte
+ wantErr bool
+ }{
+ {
+ "invalid first part, error",
+ args{
+ []byte(`invalid`),
+ []byte(`{"some": "thing"}`),
+ },
+ nil,
+ true,
+ },
+ {
+ "invalid second part, error",
+ args{
+ []byte(`{"some": "thing"}`),
+ []byte(`invalid`),
+ },
+ nil,
+ true,
+ },
+ {
+ "both valid, merged",
+ args{
+ []byte(`{"some": "thing"}`),
+ []byte(`{"another": "thing"}`),
+ },
+
+ []byte(`{"some": "thing","another": "thing"}`),
+ false,
+ },
+ {
+ "first empty",
+ args{
+ []byte(`{}`),
+ []byte(`{"some": "thing"}`),
+ },
+
+ []byte(`{"some": "thing"}`),
+ false,
+ },
+ {
+ "second empty",
+ args{
+ []byte(`{"some": "thing"}`),
+ []byte(`{}`),
+ },
+
+ []byte(`{"some": "thing"}`),
+ false,
+ },
+ {
+ "both empty",
+ args{
+ []byte(`{}`),
+ []byte(`{}`),
+ },
+
+ []byte(`{}`),
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := ConcatenateJSON(tt.args.first, tt.args.second)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("ConcatenateJSON() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if !bytes.Equal(got, tt.want) {
+ t.Errorf("ConcatenateJSON() got = %v, want %v", string(got), tt.want)
+ }
+ })
+ }
+}
+
+func TestMarshalJSONWithStatus(t *testing.T) {
+ type args struct {
+ i any
+ status int
+ }
+ type res struct {
+ statusCode int
+ body string
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ "empty ok",
+ args{
+ nil,
+ 200,
+ },
+ res{
+ 200,
+ "",
+ },
+ },
+ {
+ "string ok",
+ args{
+ "ok",
+ 200,
+ },
+ res{
+ 200,
+ `"ok"
+`,
+ },
+ },
+ {
+ "struct ok",
+ args{
+ struct {
+ Test string `json:"test"`
+ }{"ok"},
+ 200,
+ },
+ res{
+ 200,
+ `{"test":"ok"}
+`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ w := httptest.NewRecorder()
+ MarshalJSONWithStatus(w, tt.args.i, tt.args.status)
+ assert.Equal(t, tt.res.statusCode, w.Result().StatusCode)
+ assert.Equal(t, "application/json", w.Header().Get("content-type"))
+ assert.Equal(t, tt.res.body, w.Body.String())
+ })
+ }
+}
diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go
index 02c5603..fa37dbf 100644
--- a/pkg/oidc/authorization.go
+++ b/pkg/oidc/authorization.go
@@ -1,17 +1,44 @@
package oidc
import (
- "errors"
- "strings"
-
- "golang.org/x/text/language"
+ "log/slog"
)
const (
+ // ScopeOpenID defines the scope `openid`
+ // OpenID Connect requests MUST contain the `openid` scope value
ScopeOpenID = "openid"
- ResponseTypeCode ResponseType = "code"
- ResponseTypeIDToken ResponseType = "id_token token"
+ // ScopeProfile defines the scope `profile`
+ // This (optional) scope value requests access to the End-User's default profile Claims,
+ // which are: name, family_name, given_name, middle_name, nickname, preferred_username,
+ // profile, picture, website, gender, birthdate, zoneinfo, locale, and updated_at.
+ ScopeProfile = "profile"
+
+ // ScopeEmail defines the scope `email`
+ // This (optional) scope value requests access to the email and email_verified Claims.
+ ScopeEmail = "email"
+
+ // ScopeAddress defines the scope `address`
+ // This (optional) scope value requests access to the address Claim.
+ ScopeAddress = "address"
+
+ // ScopePhone defines the scope `phone`
+ // This (optional) scope value requests access to the phone_number and phone_number_verified Claims.
+ ScopePhone = "phone"
+
+ // ScopeOfflineAccess defines the scope `offline_access`
+ // This (optional) scope value requests that an OAuth 2.0 Refresh Token be issued that can be used to obtain an Access Token
+ // that grants access to the End-User's UserInfo Endpoint even when the End-User is not present (not logged in).
+ ScopeOfflineAccess = "offline_access"
+
+ // ResponseTypeCode for the Authorization Code Flow returning a code from the Authorization Server
+ ResponseTypeCode ResponseType = "code"
+
+ // ResponseTypeIDToken for the Implicit Flow returning id and access tokens directly from the Authorization Server
+ ResponseTypeIDToken ResponseType = "id_token token"
+
+ // ResponseTypeIDTokenOnly for the Implicit Flow returning only id token directly from the Authorization Server
ResponseTypeIDTokenOnly ResponseType = "id_token"
DisplayPage Display = "page"
@@ -19,133 +46,76 @@ const (
DisplayTouch Display = "touch"
DisplayWAP Display = "wap"
- PromptNone Prompt = "none"
- PromptLogin Prompt = "login"
- PromptConsent Prompt = "consent"
- PromptSelectAccount Prompt = "select_account"
+ ResponseModeQuery ResponseMode = "query"
+ ResponseModeFragment ResponseMode = "fragment"
+ ResponseModeFormPost ResponseMode = "form_post"
- GrantTypeCode GrantType = "authorization_code"
+ // PromptNone (`none`) disallows the Authorization Server to display any authentication or consent user interface pages.
+ // An error (login_required, interaction_required, ...) will be returned if the user is not already authenticated or consent is needed
+ PromptNone = "none"
- BearerToken = "Bearer"
+ // PromptLogin (`login`) directs the Authorization Server to prompt the End-User for reauthentication.
+ PromptLogin = "login"
+
+ // PromptConsent (`consent`) directs the Authorization Server to prompt the End-User for consent (of sharing information).
+ PromptConsent = "consent"
+
+ // PromptSelectAccount (`select_account `) directs the Authorization Server to prompt the End-User to select a user account (to enable multi user / session switching)
+ PromptSelectAccount = "select_account"
)
-var displayValues = map[string]Display{
- "page": DisplayPage,
- "popup": DisplayPopup,
- "touch": DisplayTouch,
- "wap": DisplayWAP,
-}
-
-//AuthRequest according to:
-//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
-//
+// AuthRequest according to:
+// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
type AuthRequest struct {
- ID string
- Scopes Scopes `schema:"scope"`
- ResponseType ResponseType `schema:"response_type"`
- ClientID string `schema:"client_id"`
- RedirectURI string `schema:"redirect_uri"` //TODO: type
+ Scopes SpaceDelimitedArray `json:"scope" schema:"scope"`
+ ResponseType ResponseType `json:"response_type" schema:"response_type"`
+ ClientID string `json:"client_id" schema:"client_id"`
+ RedirectURI string `json:"redirect_uri" schema:"redirect_uri"`
- State string `schema:"state"`
+ State string `json:"state" schema:"state"`
+ Nonce string `json:"nonce" schema:"nonce"`
- // ResponseMode TODO: ?
+ ResponseMode ResponseMode `json:"response_mode" schema:"response_mode"`
+ Display Display `json:"display" schema:"display"`
+ Prompt SpaceDelimitedArray `json:"prompt" schema:"prompt"`
+ MaxAge *uint `json:"max_age" schema:"max_age"`
+ UILocales Locales `json:"ui_locales" schema:"ui_locales"`
+ IDTokenHint string `json:"id_token_hint" schema:"id_token_hint"`
+ LoginHint string `json:"login_hint" schema:"login_hint"`
+ ACRValues SpaceDelimitedArray `json:"acr_values" schema:"acr_values"`
- Nonce string `schema:"nonce"`
- Display Display `schema:"display"`
- Prompt Prompt `schema:"prompt"`
- MaxAge uint32 `schema:"max_age"`
- UILocales Locales `schema:"ui_locales"`
- IDTokenHint string `schema:"id_token_hint"`
- LoginHint string `schema:"login_hint"`
- ACRValues []string `schema:"acr_values"`
+ CodeChallenge string `json:"code_challenge" schema:"code_challenge"`
+ CodeChallengeMethod CodeChallengeMethod `json:"code_challenge_method" schema:"code_challenge_method"`
- CodeChallenge string `schema:"code_challenge"`
- CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"`
+ // RequestParam enables OIDC requests to be passed in a single, self-contained parameter (as JWT, called Request Object)
+ RequestParam string `schema:"request"`
}
+func (a *AuthRequest) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.Any("scopes", a.Scopes),
+ slog.String("response_type", string(a.ResponseType)),
+ slog.String("client_id", a.ClientID),
+ slog.String("redirect_uri", a.RedirectURI),
+ )
+}
+
+// GetRedirectURI returns the redirect_uri value for the ErrAuthRequest interface
func (a *AuthRequest) GetRedirectURI() string {
return a.RedirectURI
}
+
+// GetResponseType returns the response_type value for the ErrAuthRequest interface
func (a *AuthRequest) GetResponseType() ResponseType {
return a.ResponseType
}
+
+// GetState returns the optional state value for the ErrAuthRequest interface
func (a *AuthRequest) GetState() string {
return a.State
}
-type TokenRequest interface {
- // GrantType GrantType `schema:"grant_type"`
- GrantType() GrantType
+// GetResponseMode returns the optional ResponseMode
+func (a *AuthRequest) GetResponseMode() ResponseMode {
+ return a.ResponseMode
}
-
-type TokenRequestType GrantType
-
-type AccessTokenRequest struct {
- Code string `schema:"code"`
- RedirectURI string `schema:"redirect_uri"`
- ClientID string `schema:"client_id"`
- ClientSecret string `schema:"client_secret"`
- CodeVerifier string `schema:"code_verifier"`
-}
-
-func (a *AccessTokenRequest) GrantType() GrantType {
- return GrantTypeCode
-}
-
-type AccessTokenResponse struct {
- AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
- TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
- RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
- ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
- IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
-}
-
-type TokenExchangeRequest struct {
- subjectToken string `schema:"subject_token"`
- subjectTokenType string `schema:"subject_token_type"`
- actorToken string `schema:"actor_token"`
- actorTokenType string `schema:"actor_token_type"`
- resource []string `schema:"resource"`
- audience []string `schema:"audience"`
- Scope []string `schema:"scope"`
- requestedTokenType string `schema:"requested_token_type"`
-}
-
-type Scopes []string
-
-func (s *Scopes) UnmarshalText(text []byte) error {
- scopes := strings.Split(string(text), " ")
- *s = Scopes(scopes)
- return nil
-}
-
-type ResponseType string
-
-type Display string
-
-func (d *Display) UnmarshalText(text []byte) error {
- var ok bool
- display := string(text)
- *d, ok = displayValues[display]
- if !ok {
- return errors.New("")
- }
- return nil
-}
-
-type Prompt string
-
-type Locales []language.Tag
-
-func (l *Locales) UnmarshalText(text []byte) error {
- locales := strings.Split(string(text), " ")
- for _, locale := range locales {
- tag, err := language.Parse(locale)
- if err == nil && !tag.IsRoot() {
- *l = append(*l, tag)
- }
- }
- return nil
-}
-
-type GrantType string
diff --git a/pkg/oidc/authorization_test.go b/pkg/oidc/authorization_test.go
new file mode 100644
index 0000000..1446efa
--- /dev/null
+++ b/pkg/oidc/authorization_test.go
@@ -0,0 +1,27 @@
+//go:build go1.20
+
+package oidc
+
+import (
+ "log/slog"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestAuthRequest_LogValue(t *testing.T) {
+ a := &AuthRequest{
+ Scopes: SpaceDelimitedArray{"a", "b"},
+ ResponseType: "respType",
+ ClientID: "123",
+ RedirectURI: "http://example.com/callback",
+ }
+ want := slog.GroupValue(
+ slog.Any("scopes", SpaceDelimitedArray{"a", "b"}),
+ slog.String("response_type", "respType"),
+ slog.String("client_id", "123"),
+ slog.String("redirect_uri", "http://example.com/callback"),
+ )
+ got := a.LogValue()
+ assert.Equal(t, want, got)
+}
diff --git a/pkg/oidc/code_challenge.go b/pkg/oidc/code_challenge.go
index 44a0499..0c593df 100644
--- a/pkg/oidc/code_challenge.go
+++ b/pkg/oidc/code_challenge.go
@@ -3,7 +3,7 @@ package oidc
import (
"crypto/sha256"
- "github.com/caos/oidc/pkg/utils"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto"
)
const (
@@ -19,12 +19,12 @@ type CodeChallenge struct {
}
func NewSHACodeChallenge(code string) string {
- return utils.HashString(sha256.New(), code, false)
+ return crypto.HashString(sha256.New(), code, false)
}
func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool {
if c == nil {
- return false //TODO: ?
+ return false
}
if c.Method == CodeChallengeMethodS256 {
codeVerifier = NewSHACodeChallenge(codeVerifier)
diff --git a/pkg/oidc/device_authorization.go b/pkg/oidc/device_authorization.go
new file mode 100644
index 0000000..a6417ba
--- /dev/null
+++ b/pkg/oidc/device_authorization.go
@@ -0,0 +1,51 @@
+package oidc
+
+import "encoding/json"
+
+// DeviceAuthorizationRequest implements
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.1,
+// 3.1 Device Authorization Request.
+type DeviceAuthorizationRequest struct {
+ Scopes SpaceDelimitedArray `schema:"scope"`
+ ClientID string `schema:"client_id"`
+}
+
+// DeviceAuthorizationResponse implements
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.2
+// 3.2. Device Authorization Response.
+type DeviceAuthorizationResponse struct {
+ DeviceCode string `json:"device_code"`
+ UserCode string `json:"user_code"`
+ VerificationURI string `json:"verification_uri"`
+ VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
+ ExpiresIn int `json:"expires_in"`
+ Interval int `json:"interval,omitempty"`
+}
+
+func (resp *DeviceAuthorizationResponse) UnmarshalJSON(data []byte) error {
+ type Alias DeviceAuthorizationResponse
+ aux := &struct {
+ // workaround misspelling of verification_uri
+ // https://stackoverflow.com/q/76696956/5690223
+ // https://developers.google.com/identity/protocols/oauth2/limited-input-device?hl=fr#success-response
+ VerificationURL string `json:"verification_url"`
+ *Alias
+ }{
+ Alias: (*Alias)(resp),
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+ if resp.VerificationURI == "" {
+ resp.VerificationURI = aux.VerificationURL
+ }
+ return nil
+}
+
+// DeviceAccessTokenRequest implements
+// https://www.rfc-editor.org/rfc/rfc8628#section-3.4,
+// Device Access Token Request.
+type DeviceAccessTokenRequest struct {
+ GrantType GrantType `json:"grant_type" schema:"grant_type"`
+ DeviceCode string `json:"device_code" schema:"device_code"`
+}
diff --git a/pkg/oidc/device_authorization_test.go b/pkg/oidc/device_authorization_test.go
new file mode 100644
index 0000000..c4c6637
--- /dev/null
+++ b/pkg/oidc/device_authorization_test.go
@@ -0,0 +1,30 @@
+package oidc
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDeviceAuthorizationResponse_UnmarshalJSON(t *testing.T) {
+ jsonStr := `{
+ "device_code": "deviceCode",
+ "user_code": "userCode",
+ "verification_url": "http://example.com/verify",
+ "expires_in": 3600,
+ "interval": 5
+ }`
+
+ expected := &DeviceAuthorizationResponse{
+ DeviceCode: "deviceCode",
+ UserCode: "userCode",
+ VerificationURI: "http://example.com/verify",
+ ExpiresIn: 3600,
+ Interval: 5,
+ }
+
+ var resp DeviceAuthorizationResponse
+ err := resp.UnmarshalJSON([]byte(jsonStr))
+ assert.NoError(t, err)
+ assert.Equal(t, expected, &resp)
+}
diff --git a/pkg/oidc/discovery.go b/pkg/oidc/discovery.go
index 5d2875e..62288d1 100644
--- a/pkg/oidc/discovery.go
+++ b/pkg/oidc/discovery.go
@@ -5,20 +5,165 @@ const (
)
type DiscoveryConfiguration struct {
- Issuer string `json:"issuer,omitempty"`
- AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"`
- TokenEndpoint string `json:"token_endpoint,omitempty"`
- IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"`
- UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
- EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
- CheckSessionIframe string `json:"check_session_iframe,omitempty"`
- JwksURI string `json:"jwks_uri,omitempty"`
- ScopesSupported []string `json:"scopes_supported,omitempty"`
- ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
- ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
- GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
- SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
- IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"`
- TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
- ClaimsSupported []string `json:"claims_supported,omitempty"`
+ // Issuer is the identifier of the OP and is used in the tokens as `iss` claim.
+ Issuer string `json:"issuer,omitempty"`
+
+ // AuthorizationEndpoint is the URL of the OAuth 2.0 Authorization Endpoint where all user interactive login start
+ AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"`
+
+ // TokenEndpoint is the URL of the OAuth 2.0 Token Endpoint where all tokens are issued, except when using Implicit Flow
+ TokenEndpoint string `json:"token_endpoint,omitempty"`
+
+ // IntrospectionEndpoint is the URL of the OAuth 2.0 Introspection Endpoint.
+ IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"`
+
+ // UserinfoEndpoint is the URL where an access_token can be used to retrieve the Userinfo.
+ UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
+
+ // RevocationEndpoint is the URL of the OAuth 2.0 Revocation Endpoint.
+ RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
+
+ // EndSessionEndpoint is a URL where the RP can perform a redirect to request that the End-User be logged out at the OP.
+ EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
+
+ DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"`
+
+ // CheckSessionIframe is a URL where the OP provides an iframe that support cross-origin communications for session state information with the RP Client.
+ CheckSessionIframe string `json:"check_session_iframe,omitempty"`
+
+ // JwksURI is the URL of the JSON Web Key Set. This site contains the signing keys that RPs can use to validate the signature.
+ // It may also contain the OP's encryption keys that RPs can use to encrypt request to the OP.
+ JwksURI string `json:"jwks_uri,omitempty"`
+
+ // RegistrationEndpoint is the URL for the Dynamic Client Registration.
+ RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
+
+ // ScopesSupported lists an array of supported scopes. This list must not include every supported scope by the OP.
+ ScopesSupported []string `json:"scopes_supported,omitempty"`
+
+ // ResponseTypesSupported contains a list of the OAuth 2.0 response_type values that the OP supports (code, id_token, token id_token, ...).
+ ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
+
+ // ResponseModesSupported contains a list of the OAuth 2.0 response_mode values that the OP supports. If omitted, the default value is ["query", "fragment"].
+ ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
+
+ // GrantTypesSupported contains a list of the OAuth 2.0 grant_type values that the OP supports. If omitted, the default value is ["authorization_code", "implicit"].
+ GrantTypesSupported []GrantType `json:"grant_types_supported,omitempty"`
+
+ // ACRValuesSupported contains a list of Authentication Context Class References that the OP supports.
+ ACRValuesSupported []string `json:"acr_values_supported,omitempty"`
+
+ // SubjectTypesSupported contains a list of Subject Identifier types that the OP supports (pairwise, public).
+ SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
+
+ // IDTokenSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the OP for the ID Token.
+ IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"`
+
+ // IDTokenEncryptionAlgValuesSupported contains a list of JWE encryption algorithms (alg values) supported by the OP for the ID Token.
+ IDTokenEncryptionAlgValuesSupported []string `json:"id_token_encryption_alg_values_supported,omitempty"`
+
+ // IDTokenEncryptionEncValuesSupported contains a list of JWE encryption algorithms (enc values) supported by the OP for the ID Token.
+ IDTokenEncryptionEncValuesSupported []string `json:"id_token_encryption_enc_values_supported,omitempty"`
+
+ // UserinfoSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the OP for UserInfo Endpoint.
+ UserinfoSigningAlgValuesSupported []string `json:"userinfo_signing_alg_values_supported,omitempty"`
+
+ // UserinfoEncryptionAlgValuesSupported contains a list of JWE encryption algorithms (alg values) supported by the OP for the UserInfo Endpoint.
+ UserinfoEncryptionAlgValuesSupported []string `json:"userinfo_encryption_alg_values_supported,omitempty"`
+
+ // UserinfoEncryptionEncValuesSupported contains a list of JWE encryption algorithms (enc values) supported by the OP for the UserInfo Endpoint.
+ UserinfoEncryptionEncValuesSupported []string `json:"userinfo_encryption_enc_values_supported,omitempty"`
+
+ // RequestObjectSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the OP for Request Objects.
+ // These algorithms are used both then the Request Object is passed by value (using the request parameter) and when it is passed by reference (using the request_uri parameter).
+ RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported,omitempty"`
+
+ // RequestObjectEncryptionAlgValuesSupported contains a list of JWE encryption algorithms (alg values) supported by the OP for Request Objects.
+ // These algorithms are used both when the Request Object is passed by value and by reference.
+ RequestObjectEncryptionAlgValuesSupported []string `json:"request_object_encryption_alg_values_supported,omitempty"`
+
+ // RequestObjectEncryptionEncValuesSupported contains a list of JWE encryption algorithms (enc values) supported by the OP for Request Objects.
+ // These algorithms are used both when the Request Object is passed by value and by reference.
+ RequestObjectEncryptionEncValuesSupported []string `json:"request_object_encryption_enc_values_supported,omitempty"`
+
+ // TokenEndpointAuthMethodsSupported contains a list of Client Authentication methods supported by the Token Endpoint. If omitted, the default is client_secret_basic.
+ TokenEndpointAuthMethodsSupported []AuthMethod `json:"token_endpoint_auth_methods_supported,omitempty"`
+
+ // TokenEndpointAuthSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the Token Endpoint
+ // for the signature of the JWT used to authenticate the Client by private_key_jwt and client_secret_jwt.
+ TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"`
+
+ // RevocationEndpointAuthMethodsSupported contains a list of Client Authentication methods supported by the Revocation Endpoint. If omitted, the default is client_secret_basic.
+ RevocationEndpointAuthMethodsSupported []AuthMethod `json:"revocation_endpoint_auth_methods_supported,omitempty"`
+
+ // RevocationEndpointAuthSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the Revocation Endpoint
+ // for the signature of the JWT used to authenticate the Client by private_key_jwt and client_secret_jwt.
+ RevocationEndpointAuthSigningAlgValuesSupported []string `json:"revocation_endpoint_auth_signing_alg_values_supported,omitempty"`
+
+ // IntrospectionEndpointAuthMethodsSupported contains a list of Client Authentication methods supported by the Introspection Endpoint.
+ IntrospectionEndpointAuthMethodsSupported []AuthMethod `json:"introspection_endpoint_auth_methods_supported,omitempty"`
+
+ // IntrospectionEndpointAuthSigningAlgValuesSupported contains a list of JWS signing algorithms (alg values) supported by the Revocation Endpoint
+ // for the signature of the JWT used to authenticate the Client by private_key_jwt and client_secret_jwt.
+ IntrospectionEndpointAuthSigningAlgValuesSupported []string `json:"introspection_endpoint_auth_signing_alg_values_supported,omitempty"`
+
+ // DisplayValuesSupported contains a list of display parameter values that the OP supports (page, popup, touch, wap).
+ DisplayValuesSupported []Display `json:"display_values_supported,omitempty"`
+
+ // ClaimTypesSupported contains a list of Claim Types that the OP supports (normal, aggregated, distributed). If omitted, the default is normal Claims.
+ ClaimTypesSupported []string `json:"claim_types_supported,omitempty"`
+
+ // ClaimsSupported contains a list of Claim Names the OP may be able to supply values for. This list might not be exhaustive.
+ ClaimsSupported []string `json:"claims_supported,omitempty"`
+
+ // ClaimsParameterSupported specifies whether the OP supports use of the `claims` parameter. If omitted, the default is false.
+ ClaimsParameterSupported bool `json:"claims_parameter_supported,omitempty"`
+
+ // CodeChallengeMethodsSupported contains a list of Proof Key for Code Exchange (PKCE) code challenge methods supported by the OP.
+ CodeChallengeMethodsSupported []CodeChallengeMethod `json:"code_challenge_methods_supported,omitempty"`
+
+ // ServiceDocumentation is a URL where developers can get information about the OP and its usage.
+ ServiceDocumentation string `json:"service_documentation,omitempty"`
+
+ // ClaimsLocalesSupported contains a list of BCP47 language tag values that the OP supports for values of Claims returned.
+ ClaimsLocalesSupported Locales `json:"claims_locales_supported,omitempty"`
+
+ // UILocalesSupported contains a list of BCP47 language tag values that the OP supports for the user interface.
+ UILocalesSupported Locales `json:"ui_locales_supported,omitempty"`
+
+ // RequestParameterSupported specifies whether the OP supports use of the `request` parameter. If omitted, the default value is false.
+ RequestParameterSupported bool `json:"request_parameter_supported,omitempty"`
+
+ // RequestURIParameterSupported specifies whether the OP supports use of the `request_uri` parameter. If omitted, the default value is true. (therefore no omitempty)
+ RequestURIParameterSupported bool `json:"request_uri_parameter_supported"`
+
+ // RequireRequestURIRegistration specifies whether the OP requires any `request_uri` to be pre-registered using the request_uris registration parameter. If omitted, the default value is false.
+ RequireRequestURIRegistration bool `json:"require_request_uri_registration,omitempty"`
+
+ // OPPolicyURI is a URL the OP provides to the person registering the Client to read about the OP's requirements on how the RP can use the data provided by the OP.
+ OPPolicyURI string `json:"op_policy_uri,omitempty"`
+
+ // OPTermsOfServiceURI is a URL the OpenID Provider provides to the person registering the Client to read about OpenID Provider's terms of service.
+ OPTermsOfServiceURI string `json:"op_tos_uri,omitempty"`
+
+ // BackChannelLogoutSupported specifies whether the OP supports back-channel logout (https://openid.net/specs/openid-connect-backchannel-1_0.html),
+ // with true indicating support. If omitted, the default value is false.
+ BackChannelLogoutSupported bool `json:"backchannel_logout_supported,omitempty"`
+
+ // BackChannelLogoutSessionSupported specifies whether the OP can pass a sid (session ID) Claim in the Logout Token to identify the RP session with the OP.
+ // If supported, the sid Claim is also included in ID Tokens issued by the OP. If omitted, the default value is false.
+ BackChannelLogoutSessionSupported bool `json:"backchannel_logout_session_supported,omitempty"`
+}
+
+type AuthMethod string
+
+const (
+ AuthMethodBasic AuthMethod = "client_secret_basic"
+ AuthMethodPost AuthMethod = "client_secret_post"
+ AuthMethodNone AuthMethod = "none"
+ AuthMethodPrivateKeyJWT AuthMethod = "private_key_jwt"
+)
+
+var AllAuthMethods = []AuthMethod{
+ AuthMethodBasic, AuthMethodPost, AuthMethodNone, AuthMethodPrivateKeyJWT,
}
diff --git a/pkg/oidc/error.go b/pkg/oidc/error.go
new file mode 100644
index 0000000..d93cf44
--- /dev/null
+++ b/pkg/oidc/error.go
@@ -0,0 +1,256 @@
+package oidc
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+)
+
+type errorType string
+
+const (
+ InvalidRequest errorType = "invalid_request"
+ InvalidScope errorType = "invalid_scope"
+ InvalidClient errorType = "invalid_client"
+ InvalidGrant errorType = "invalid_grant"
+ UnauthorizedClient errorType = "unauthorized_client"
+ UnsupportedGrantType errorType = "unsupported_grant_type"
+ ServerError errorType = "server_error"
+ InteractionRequired errorType = "interaction_required"
+ LoginRequired errorType = "login_required"
+ RequestNotSupported errorType = "request_not_supported"
+
+ // Additional error codes as defined in
+ // https://www.rfc-editor.org/rfc/rfc8628#section-3.5
+ // Device Access Token Response
+ AuthorizationPending errorType = "authorization_pending"
+ SlowDown errorType = "slow_down"
+ AccessDenied errorType = "access_denied"
+ ExpiredToken errorType = "expired_token"
+
+ // InvalidTarget error is returned by Token Exchange if
+ // the requested target or audience is invalid.
+ // [RFC 8693, Section 2.2.2: Error Response](https://www.rfc-editor.org/rfc/rfc8693#section-2.2.2)
+ InvalidTarget errorType = "invalid_target"
+)
+
+var (
+ ErrInvalidRequest = func() *Error {
+ return &Error{
+ ErrorType: InvalidRequest,
+ }
+ }
+ ErrInvalidRequestRedirectURI = func() *Error {
+ return &Error{
+ ErrorType: InvalidRequest,
+ redirectDisabled: true,
+ }
+ }
+ ErrInvalidScope = func() *Error {
+ return &Error{
+ ErrorType: InvalidScope,
+ }
+ }
+ ErrInvalidClient = func() *Error {
+ return &Error{
+ ErrorType: InvalidClient,
+ }
+ }
+ ErrInvalidGrant = func() *Error {
+ return &Error{
+ ErrorType: InvalidGrant,
+ }
+ }
+ ErrUnauthorizedClient = func() *Error {
+ return &Error{
+ ErrorType: UnauthorizedClient,
+ }
+ }
+ ErrUnsupportedGrantType = func() *Error {
+ return &Error{
+ ErrorType: UnsupportedGrantType,
+ }
+ }
+ ErrServerError = func() *Error {
+ return &Error{
+ ErrorType: ServerError,
+ }
+ }
+ ErrInteractionRequired = func() *Error {
+ return &Error{
+ ErrorType: InteractionRequired,
+ }
+ }
+ ErrLoginRequired = func() *Error {
+ return &Error{
+ ErrorType: LoginRequired,
+ }
+ }
+ ErrRequestNotSupported = func() *Error {
+ return &Error{
+ ErrorType: RequestNotSupported,
+ }
+ }
+
+ // Device Access Token errors:
+ ErrAuthorizationPending = func() *Error {
+ return &Error{
+ ErrorType: AuthorizationPending,
+ Description: "The client SHOULD repeat the access token request to the token endpoint, after interval from device authorization response.",
+ }
+ }
+ ErrSlowDown = func() *Error {
+ return &Error{
+ ErrorType: SlowDown,
+ Description: "Polling should continue, but the interval MUST be increased by 5 seconds for this and all subsequent requests.",
+ }
+ }
+ ErrAccessDenied = func() *Error {
+ return &Error{
+ ErrorType: AccessDenied,
+ Description: "The authorization request was denied.",
+ }
+ }
+ ErrExpiredDeviceCode = func() *Error {
+ return &Error{
+ ErrorType: ExpiredToken,
+ Description: "The \"device_code\" has expired.",
+ }
+ }
+
+ // Token exchange error
+ ErrInvalidTarget = func() *Error {
+ return &Error{
+ ErrorType: InvalidTarget,
+ Description: "The requested audience or target is invalid.",
+ }
+ }
+)
+
+type Error struct {
+ Parent error `json:"-" schema:"-"`
+ ErrorType errorType `json:"error" schema:"error"`
+ Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
+ State string `json:"state,omitempty" schema:"state,omitempty"`
+ SessionState string `json:"session_state,omitempty" schema:"session_state,omitempty"`
+ redirectDisabled bool `schema:"-"`
+ returnParent bool `schema:"-"`
+}
+
+func (e *Error) MarshalJSON() ([]byte, error) {
+ m := struct {
+ Error errorType `json:"error"`
+ ErrorDescription string `json:"error_description,omitempty"`
+ State string `json:"state,omitempty"`
+ SessionState string `json:"session_state,omitempty"`
+ Parent string `json:"parent,omitempty"`
+ }{
+ Error: e.ErrorType,
+ ErrorDescription: e.Description,
+ State: e.State,
+ SessionState: e.SessionState,
+ }
+ if e.returnParent {
+ m.Parent = e.Parent.Error()
+ }
+ return json.Marshal(m)
+}
+
+func (e *Error) Error() string {
+ message := "ErrorType=" + string(e.ErrorType)
+ if e.Description != "" {
+ message += " Description=" + e.Description
+ }
+ if e.Parent != nil {
+ message += " Parent=" + e.Parent.Error()
+ }
+ return message
+}
+
+func (e *Error) Unwrap() error {
+ return e.Parent
+}
+
+func (e *Error) Is(target error) bool {
+ t, ok := target.(*Error)
+ if !ok {
+ return false
+ }
+ return e.ErrorType == t.ErrorType &&
+ (e.Description == t.Description || t.Description == "") &&
+ (e.State == t.State || t.State == "") &&
+ (e.SessionState == t.SessionState || t.SessionState == "")
+}
+
+func (e *Error) WithParent(err error) *Error {
+ e.Parent = err
+ return e
+}
+
+// WithReturnParentToClient allows returning the set parent error to the HTTP client.
+// Currently it only supports setting the parent inside JSON responses, not redirect URLs.
+// As Go errors don't unmarshal well, only the marshaller is implemented for the moment.
+//
+// Warning: parent errors may contain sensitive data or unwanted details about the server status.
+// Also, the `parent` field is not a standard error field and might confuse certain clients
+// that require fully compliant responses.
+func (e *Error) WithReturnParentToClient(b bool) *Error {
+ e.returnParent = b
+ return e
+}
+
+func (e *Error) WithDescription(desc string, args ...any) *Error {
+ e.Description = fmt.Sprintf(desc, args...)
+ return e
+}
+
+func (e *Error) IsRedirectDisabled() bool {
+ return e.redirectDisabled
+}
+
+// DefaultToServerError checks if the error is an Error
+// if not the provided error will be wrapped into a ServerError
+func DefaultToServerError(err error, description string) *Error {
+ oauth := new(Error)
+ if ok := errors.As(err, &oauth); !ok {
+ oauth.ErrorType = ServerError
+ oauth.Description = description
+ oauth.Parent = err
+ }
+ return oauth
+}
+
+func (e *Error) LogLevel() slog.Level {
+ level := slog.LevelWarn
+ if e.ErrorType == ServerError {
+ level = slog.LevelError
+ }
+ if e.ErrorType == AuthorizationPending {
+ level = slog.LevelInfo
+ }
+ return level
+}
+
+func (e *Error) LogValue() slog.Value {
+ attrs := make([]slog.Attr, 0, 5)
+ if e.Parent != nil {
+ attrs = append(attrs, slog.Any("parent", e.Parent))
+ }
+ if e.Description != "" {
+ attrs = append(attrs, slog.String("description", e.Description))
+ }
+ if e.ErrorType != "" {
+ attrs = append(attrs, slog.String("type", string(e.ErrorType)))
+ }
+ if e.State != "" {
+ attrs = append(attrs, slog.String("state", e.State))
+ }
+ if e.SessionState != "" {
+ attrs = append(attrs, slog.String("session_state", e.SessionState))
+ }
+ if e.redirectDisabled {
+ attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled))
+ }
+ return slog.GroupValue(attrs...)
+}
diff --git a/pkg/oidc/error_test.go b/pkg/oidc/error_test.go
new file mode 100644
index 0000000..40d30b1
--- /dev/null
+++ b/pkg/oidc/error_test.go
@@ -0,0 +1,192 @@
+package oidc
+
+import (
+ "encoding/json"
+ "errors"
+ "io"
+ "log/slog"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestDefaultToServerError(t *testing.T) {
+ type args struct {
+ err error
+ description string
+ }
+ tests := []struct {
+ name string
+ args args
+ want *Error
+ }{
+ {
+ name: "default",
+ args: args{
+ err: io.ErrClosedPipe,
+ description: "oops",
+ },
+ want: &Error{
+ ErrorType: ServerError,
+ Description: "oops",
+ Parent: io.ErrClosedPipe,
+ },
+ },
+ {
+ name: "our Error",
+ args: args{
+ err: ErrAccessDenied(),
+ description: "oops",
+ },
+ want: &Error{
+ ErrorType: AccessDenied,
+ Description: "The authorization request was denied.",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := DefaultToServerError(tt.args.err, tt.args.description)
+ assert.ErrorIs(t, got, tt.want)
+ })
+ }
+}
+
+func TestError_LogLevel(t *testing.T) {
+ tests := []struct {
+ name string
+ err *Error
+ want slog.Level
+ }{
+ {
+ name: "server error",
+ err: ErrServerError(),
+ want: slog.LevelError,
+ },
+ {
+ name: "authorization pending",
+ err: ErrAuthorizationPending(),
+ want: slog.LevelInfo,
+ },
+ {
+ name: "some other error",
+ err: ErrAccessDenied(),
+ want: slog.LevelWarn,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.err.LogLevel()
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestError_LogValue(t *testing.T) {
+ type fields struct {
+ Parent error
+ ErrorType errorType
+ Description string
+ State string
+ redirectDisabled bool
+ }
+ tests := []struct {
+ name string
+ fields fields
+ want slog.Value
+ }{
+ {
+ name: "parent",
+ fields: fields{
+ Parent: io.EOF,
+ },
+ want: slog.GroupValue(slog.Any("parent", io.EOF)),
+ },
+ {
+ name: "description",
+ fields: fields{
+ Description: "oops",
+ },
+ want: slog.GroupValue(slog.String("description", "oops")),
+ },
+ {
+ name: "errorType",
+ fields: fields{
+ ErrorType: ExpiredToken,
+ },
+ want: slog.GroupValue(slog.String("type", string(ExpiredToken))),
+ },
+ {
+ name: "state",
+ fields: fields{
+ State: "123",
+ },
+ want: slog.GroupValue(slog.String("state", "123")),
+ },
+ {
+ name: "all fields",
+ fields: fields{
+ Parent: io.EOF,
+ Description: "oops",
+ ErrorType: ExpiredToken,
+ State: "123",
+ },
+ want: slog.GroupValue(
+ slog.Any("parent", io.EOF),
+ slog.String("description", "oops"),
+ slog.String("type", string(ExpiredToken)),
+ slog.String("state", "123"),
+ ),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ e := &Error{
+ Parent: tt.fields.Parent,
+ ErrorType: tt.fields.ErrorType,
+ Description: tt.fields.Description,
+ State: tt.fields.State,
+ redirectDisabled: tt.fields.redirectDisabled,
+ }
+ got := e.LogValue()
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestError_MarshalJSON(t *testing.T) {
+ tests := []struct {
+ name string
+ e *Error
+ want string
+ }{
+ {
+ name: "simple error",
+ e: ErrAccessDenied(),
+ want: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
+ },
+ {
+ name: "with description",
+ e: ErrAccessDenied().WithDescription("oops"),
+ want: `{"error":"access_denied","error_description":"oops"}`,
+ },
+ {
+ name: "with parent",
+ e: ErrServerError().WithParent(errors.New("oops")),
+ want: `{"error":"server_error"}`,
+ },
+ {
+ name: "with return parent",
+ e: ErrServerError().WithParent(errors.New("oops")).WithReturnParentToClient(true),
+ want: `{"error":"server_error","parent":"oops"}`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := json.Marshal(tt.e)
+ require.NoError(t, err)
+ assert.JSONEq(t, tt.want, string(got))
+ })
+ }
+}
diff --git a/pkg/oidc/grants/client_credentials.go b/pkg/oidc/grants/client_credentials.go
index 998dda1..8cb8425 100644
--- a/pkg/oidc/grants/client_credentials.go
+++ b/pkg/oidc/grants/client_credentials.go
@@ -13,8 +13,8 @@ type clientCredentialsGrant struct {
clientSecret string `schema:"client_secret"`
}
-//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant
-//sneding client_id and client_secret as basic auth header
+// ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant
+// sending client_id and client_secret as basic auth header
func ClientCredentialsGrantBasic(scopes ...string) *clientCredentialsGrantBasic {
return &clientCredentialsGrantBasic{
grantType: "client_credentials",
@@ -22,8 +22,8 @@ func ClientCredentialsGrantBasic(scopes ...string) *clientCredentialsGrantBasic
}
}
-//ClientCredentialsGrantValues creates an oauth2 `Client Credentials` Grant
-//sneding client_id and client_secret as form values
+// ClientCredentialsGrantValues creates an oauth2 `Client Credentials` Grant
+// sending client_id and client_secret as form values
func ClientCredentialsGrantValues(clientID, clientSecret string, scopes ...string) *clientCredentialsGrant {
return &clientCredentialsGrant{
clientCredentialsGrantBasic: ClientCredentialsGrantBasic(scopes...),
diff --git a/pkg/oidc/introspection.go b/pkg/oidc/introspection.go
new file mode 100644
index 0000000..1a200eb
--- /dev/null
+++ b/pkg/oidc/introspection.go
@@ -0,0 +1,79 @@
+package oidc
+
+import "github.com/muhlemmer/gu"
+
+type IntrospectionRequest struct {
+ Token string `schema:"token"`
+}
+
+type ClientAssertionParams struct {
+ ClientAssertion string `schema:"client_assertion"`
+ ClientAssertionType string `schema:"client_assertion_type"`
+}
+
+// IntrospectionResponse implements RFC 7662, section 2.2 and
+// OpenID Connect Core 1.0, section 5.1 (UserInfo).
+// https://www.rfc-editor.org/rfc/rfc7662.html#section-2.2.
+// https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims.
+type IntrospectionResponse struct {
+ Active bool `json:"active"`
+ Scope SpaceDelimitedArray `json:"scope,omitempty"`
+ ClientID string `json:"client_id,omitempty"`
+ TokenType string `json:"token_type,omitempty"`
+ Expiration Time `json:"exp,omitempty"`
+ IssuedAt Time `json:"iat,omitempty"`
+ AuthTime Time `json:"auth_time,omitempty"`
+ NotBefore Time `json:"nbf,omitempty"`
+ Subject string `json:"sub,omitempty"`
+ Audience Audience `json:"aud,omitempty"`
+ AuthenticationMethodsReferences []string `json:"amr,omitempty"`
+ Issuer string `json:"iss,omitempty"`
+ JWTID string `json:"jti,omitempty"`
+ Username string `json:"username,omitempty"`
+ Actor *ActorClaims `json:"act,omitempty"`
+ UserInfoProfile
+ UserInfoEmail
+ UserInfoPhone
+
+ Address *UserInfoAddress `json:"address,omitempty"`
+ Claims map[string]any `json:"-"`
+}
+
+// SetUserInfo copies all relevant fields from UserInfo
+// into the IntroSpectionResponse.
+func (i *IntrospectionResponse) SetUserInfo(u *UserInfo) {
+ i.Subject = u.Subject
+ i.Username = u.PreferredUsername
+ i.Address = gu.PtrCopy(u.Address)
+ i.UserInfoProfile = u.UserInfoProfile
+ i.UserInfoEmail = u.UserInfoEmail
+ i.UserInfoPhone = u.UserInfoPhone
+ if i.Claims == nil {
+ i.Claims = gu.MapCopy(u.Claims)
+ } else {
+ gu.MapMerge(u.Claims, i.Claims)
+ }
+}
+
+// GetAddress is a safe getter that takes
+// care of a possible nil value.
+func (i *IntrospectionResponse) GetAddress() *UserInfoAddress {
+ if i.Address == nil {
+ return new(UserInfoAddress)
+ }
+ return i.Address
+}
+
+// introspectionResponseAlias prevents loops on the JSON methods
+type introspectionResponseAlias IntrospectionResponse
+
+func (i *IntrospectionResponse) MarshalJSON() ([]byte, error) {
+ if i.Username == "" {
+ i.Username = i.PreferredUsername
+ }
+ return mergeAndMarshalClaims((*introspectionResponseAlias)(i), i.Claims)
+}
+
+func (i *IntrospectionResponse) UnmarshalJSON(data []byte) error {
+ return unmarshalJSONMulti(data, (*introspectionResponseAlias)(i), &i.Claims)
+}
diff --git a/pkg/oidc/introspection_test.go b/pkg/oidc/introspection_test.go
new file mode 100644
index 0000000..60cf8a4
--- /dev/null
+++ b/pkg/oidc/introspection_test.go
@@ -0,0 +1,79 @@
+package oidc
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/muhlemmer/gu"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestIntrospectionResponse_SetUserInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ start *IntrospectionResponse
+ want *IntrospectionResponse
+ }{
+ {
+
+ name: "nil claims",
+ start: &IntrospectionResponse{},
+ want: &IntrospectionResponse{
+ Subject: userInfoData.Subject,
+ Username: userInfoData.PreferredUsername,
+ Address: userInfoData.Address,
+ UserInfoProfile: userInfoData.UserInfoProfile,
+ UserInfoEmail: userInfoData.UserInfoEmail,
+ UserInfoPhone: userInfoData.UserInfoPhone,
+ Claims: gu.MapCopy(userInfoData.Claims),
+ },
+ },
+ {
+
+ name: "merge claims",
+ start: &IntrospectionResponse{
+ Claims: map[string]any{
+ "hello": "world",
+ },
+ },
+ want: &IntrospectionResponse{
+ Subject: userInfoData.Subject,
+ Username: userInfoData.PreferredUsername,
+ Address: userInfoData.Address,
+ UserInfoProfile: userInfoData.UserInfoProfile,
+ UserInfoEmail: userInfoData.UserInfoEmail,
+ UserInfoPhone: userInfoData.UserInfoPhone,
+ Claims: map[string]any{
+ "foo": "bar",
+ "hello": "world",
+ },
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tt.start.SetUserInfo(userInfoData)
+ assert.Equal(t, tt.want, tt.start)
+ })
+ }
+}
+
+func TestIntrospectionResponse_GetAddress(t *testing.T) {
+ // nil address
+ i := new(IntrospectionResponse)
+ assert.Equal(t, &UserInfoAddress{}, i.GetAddress())
+
+ i.Address = &UserInfoAddress{PostalCode: "1234"}
+ assert.Equal(t, i.Address, i.GetAddress())
+}
+
+func TestIntrospectionResponse_MarshalJSON(t *testing.T) {
+ got, err := json.Marshal(&IntrospectionResponse{
+ UserInfoProfile: UserInfoProfile{
+ PreferredUsername: "muhlemmer",
+ },
+ })
+ require.NoError(t, err)
+ assert.Equal(t, string(got), `{"active":false,"username":"muhlemmer","preferred_username":"muhlemmer"}`)
+}
diff --git a/pkg/oidc/jwt_profile.go b/pkg/oidc/jwt_profile.go
new file mode 100644
index 0000000..66fa3aa
--- /dev/null
+++ b/pkg/oidc/jwt_profile.go
@@ -0,0 +1,18 @@
+package oidc
+
+type JWTProfileGrantRequest struct {
+ Assertion string `schema:"assertion"`
+ Scope SpaceDelimitedArray `schema:"scope"`
+ GrantType GrantType `schema:"grant_type"`
+}
+
+// NewJWTProfileGrantRequest creates an oauth2 `JSON Web Token (JWT) Profile` Grant
+//`urn:ietf:params:oauth:grant-type:jwt-bearer`
+// sending a self-signed jwt as assertion
+func NewJWTProfileGrantRequest(assertion string, scopes ...string) *JWTProfileGrantRequest {
+ return &JWTProfileGrantRequest{
+ GrantType: GrantTypeBearer,
+ Assertion: assertion,
+ Scope: scopes,
+ }
+}
diff --git a/pkg/oidc/keyset.go b/pkg/oidc/keyset.go
index f9bed2f..a8b89b0 100644
--- a/pkg/oidc/keyset.go
+++ b/pkg/oidc/keyset.go
@@ -2,21 +2,108 @@ package oidc
import (
"context"
+ "crypto/ecdsa"
+ "crypto/ed25519"
+ "crypto/rsa"
+ "errors"
+ "strings"
- "gopkg.in/square/go-jose.v2"
+ jose "github.com/go-jose/go-jose/v4"
)
-// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
-// of JSON web tokens. This is expected to be backed by a remote key set through
-// provider metadata discovery or an in-memory set of keys delivered out-of-band.
+const (
+ KeyUseSignature = "sig"
+)
+
+var (
+ ErrKeyMultiple = errors.New("multiple possible keys match")
+ ErrKeyNone = errors.New("no possible keys matches")
+)
+
+// KeySet represents a set of JSON Web Keys
+// - remotely fetch via discovery and jwks_uri -> `remoteKeySet`
+// - held by the OP itself in storage -> `openIDKeySet`
+// - dynamically aggregated by request for OAuth JWT Profile Assertion -> `jwtProfileKeySet`
type KeySet interface {
- // VerifySignature parses the JSON web token, verifies the signature, and returns
- // the raw payload. Header and claim fields are validated by other parts of the
- // package. For example, the KeySet does not need to check values such as signature
- // algorithm, issuer, and audience since the IDTokenVerifier validates these values
- // independently.
- //
- // If VerifySignature makes HTTP requests to verify the token, it's expected to
- // use any HTTP client associated with the context through ClientContext.
+ // VerifySignature verifies the signature with the given keyset and returns the raw payload
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
}
+
+// GetKeyIDAndAlg returns the `kid` and `alg` claim from the JWS header
+func GetKeyIDAndAlg(jws *jose.JSONWebSignature) (string, string) {
+ keyID := ""
+ alg := ""
+ for _, sig := range jws.Signatures {
+ keyID = sig.Header.KeyID
+ alg = sig.Header.Algorithm
+ break
+ }
+ return keyID, alg
+}
+
+// FindKey searches the given JSON Web Keys for the requested key ID, usage and key type
+//
+// will return the key immediately if matches exact (id, usage, type)
+//
+// will return false none or multiple match
+//
+// deprecated: use FindMatchingKey which will return an error (more specific) instead of just a bool
+// moved implementation already to FindMatchingKey
+func FindKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (jose.JSONWebKey, bool) {
+ key, err := FindMatchingKey(keyID, use, expectedAlg, keys...)
+ return key, err == nil
+}
+
+// FindMatchingKey searches the given JSON Web Keys for the requested key ID, usage and alg type
+//
+// will return the key immediately if matches exact (id, usage, type)
+//
+// will return a specific error if none (ErrKeyNone) or multiple (ErrKeyMultiple) match
+func FindMatchingKey(keyID, use, expectedAlg string, keys ...jose.JSONWebKey) (key jose.JSONWebKey, err error) {
+ var validKeys []jose.JSONWebKey
+ for _, k := range keys {
+ // ignore all keys with wrong use (let empty use of published key pass)
+ if k.Use != use && k.Use != "" {
+ continue
+ }
+ // ignore all keys with wrong algorithm type
+ if !algToKeyType(k.Key, expectedAlg) {
+ continue
+ }
+ // if we get here, use and alg match, so an equal (not empty) keyID is an exact match
+ if k.KeyID == keyID && keyID != "" {
+ return k, nil
+ }
+ // keyIDs did not match or at least one was empty (if later, then it could be a match)
+ if k.KeyID == "" || keyID == "" {
+ validKeys = append(validKeys, k)
+ }
+ }
+ // if we get here, no match was possible at all (use / alg) or no exact match due to
+ // the signed JWT and / or the published keys didn't have a kid
+ // if later applies and only one key could be found, we'll return it
+ // otherwise a corresponding error will be thrown
+ if len(validKeys) == 1 {
+ return validKeys[0], nil
+ }
+ if len(validKeys) > 1 {
+ return key, ErrKeyMultiple
+ }
+ return key, ErrKeyNone
+}
+
+func algToKeyType(key any, alg string) bool {
+ if strings.HasPrefix(alg, "RS") || strings.HasPrefix(alg, "PS") {
+ _, ok := key.(*rsa.PublicKey)
+ return ok
+ }
+ if strings.HasPrefix(alg, "ES") {
+ _, ok := key.(*ecdsa.PublicKey)
+ return ok
+ }
+ if alg == string(jose.EdDSA) {
+ _, ok := key.(ed25519.PublicKey)
+ return ok
+ }
+ return false
+}
diff --git a/pkg/oidc/keyset_test.go b/pkg/oidc/keyset_test.go
new file mode 100644
index 0000000..e01074e
--- /dev/null
+++ b/pkg/oidc/keyset_test.go
@@ -0,0 +1,429 @@
+package oidc
+
+import (
+ "crypto/ecdsa"
+ "crypto/rsa"
+ "errors"
+ "reflect"
+ "testing"
+
+ jose "github.com/go-jose/go-jose/v4"
+)
+
+func TestFindKey(t *testing.T) {
+ type args struct {
+ keyID string
+ use string
+ expectedAlg string
+ keys []jose.JSONWebKey
+ }
+ type res struct {
+ key jose.JSONWebKey
+ err error
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ "no keys, ErrKeyNone",
+ args{
+ keyID: "",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: nil,
+ },
+ res{
+ key: jose.JSONWebKey{},
+ err: ErrKeyNone,
+ },
+ },
+ {
+ "single key enc, ErrKeyNone",
+ args{
+ keyID: "",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "enc",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{},
+ err: ErrKeyNone,
+ },
+ },
+ {
+ "single key wrong algorithm, ErrKeyNone",
+ args{
+ keyID: "",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ Key: &rsa.PrivateKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{},
+ err: ErrKeyNone,
+ },
+ },
+ {
+ "single key no kid, no jwt kid, match",
+ args{
+ keyID: "",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ err: nil,
+ },
+ },
+ {
+ "single key kid, jwt no kid, match",
+ args{
+ keyID: "",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ KeyID: "id",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{
+ Use: "sig",
+ KeyID: "id",
+ Key: &rsa.PublicKey{},
+ },
+ err: nil,
+ },
+ },
+ {
+ "single key no kid, jwt with kid, match",
+ args{
+ keyID: "id",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ err: nil,
+ },
+ },
+ {
+ "single key no use, jwt with kid, match",
+ args{
+ keyID: "id",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ KeyID: "id",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{
+ KeyID: "id",
+ Key: &rsa.PublicKey{},
+ },
+ err: nil,
+ },
+ },
+ {
+ "single key wrong kid, ErrKeyNone",
+ args{
+ keyID: "id",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ KeyID: "id2",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{},
+ err: ErrKeyNone,
+ },
+ },
+ {
+ "multiple keys no kid, jwt no kid, ErrKeyMultiple",
+ args{
+ keyID: "",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ {
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{},
+ err: ErrKeyMultiple,
+ },
+ },
+ {
+ "multiple keys with kid, jwt no kid, ErrKeyMultiple",
+ args{
+ keyID: "",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ KeyID: "id1",
+ Key: &rsa.PublicKey{},
+ },
+ {
+ Use: "sig",
+ KeyID: "id2",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{},
+ err: ErrKeyMultiple,
+ },
+ },
+ {
+ "multiple keys, single sig key, jwt no kid, match",
+ args{
+ keyID: "",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ {
+ Use: "enc",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ err: nil,
+ },
+ },
+ {
+ "multiple keys no kid, jwt with kid, ErrKeyMultiple",
+ args{
+ keyID: "id",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ {
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{},
+ err: ErrKeyMultiple,
+ },
+ },
+ {
+ "multiple keys with kid, jwt with kid, match",
+ args{
+ keyID: "id1",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ KeyID: "id1",
+ Key: &rsa.PublicKey{},
+ },
+ {
+ Use: "sig",
+ KeyID: "id2",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{
+ Use: "sig",
+ KeyID: "id1",
+ Key: &rsa.PublicKey{},
+ },
+ err: nil,
+ },
+ },
+ {
+ "multiple keys, single sig key, jwt with kid, match",
+ args{
+ keyID: "id1",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ {
+ Use: "enc",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{
+ Use: "sig",
+ Key: &rsa.PublicKey{},
+ },
+ err: nil,
+ },
+ },
+ {
+ "multiple keys, no use, jwt with kid, match",
+ args{
+ keyID: "id1",
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ KeyID: "id1",
+ Key: &rsa.PublicKey{},
+ },
+ {
+ KeyID: "id2",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{
+ KeyID: "id1",
+ Key: &rsa.PublicKey{},
+ },
+ err: nil,
+ },
+ },
+ {
+ "multiple keys, no use, jwt without kid, ErrKeyMultiple",
+ args{
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keys: []jose.JSONWebKey{
+ {
+ KeyID: "id1",
+ Key: &rsa.PublicKey{},
+ },
+ {
+ KeyID: "id2",
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{},
+ err: ErrKeyMultiple,
+ },
+ },
+ {
+ "multiple keys, no use or id, jwt with kid, ErrKeyMultiple",
+ args{
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keyID: "id1",
+ keys: []jose.JSONWebKey{
+ {
+ Key: &rsa.PublicKey{},
+ },
+ {
+ Key: &rsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{},
+ err: ErrKeyMultiple,
+ },
+ },
+ {
+ "multiple keys (only one matching alg), jwt with kid, match",
+ args{
+ use: KeyUseSignature,
+ expectedAlg: "RS256",
+ keyID: "id1",
+ keys: []jose.JSONWebKey{
+ {
+ Key: &rsa.PublicKey{},
+ },
+ {
+ Key: &ecdsa.PublicKey{},
+ },
+ },
+ },
+ res{
+ key: jose.JSONWebKey{
+ Key: &rsa.PublicKey{},
+ },
+ err: nil,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := FindMatchingKey(tt.args.keyID, tt.args.use, tt.args.expectedAlg, tt.args.keys...)
+ if (tt.res.err != nil && !errors.Is(err, tt.res.err)) || (tt.res.err == nil && err != nil) {
+ t.Errorf("FindKey() error, got = %v, want = %v", err, tt.res.err)
+ }
+ if !reflect.DeepEqual(got, tt.res.key) {
+ t.Errorf("FindKey() got = %v, want %v", got, tt.res.key)
+ }
+ })
+ }
+}
diff --git a/pkg/oidc/regression_assert_test.go b/pkg/oidc/regression_assert_test.go
new file mode 100644
index 0000000..dd9f5ad
--- /dev/null
+++ b/pkg/oidc/regression_assert_test.go
@@ -0,0 +1,53 @@
+//go:build !create_regression_data
+
+package oidc
+
+import (
+ "encoding/json"
+ "io"
+ "os"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// Test_assert_regression verifies current output from
+// json.Marshal to stored regression data.
+// These tests are only ran when the create_regression_data
+// tag is NOT set.
+func Test_assert_regression(t *testing.T) {
+ buf := new(strings.Builder)
+
+ for _, obj := range regressionData {
+ name := jsonFilename(obj)
+ t.Run(name, func(t *testing.T) {
+ file, err := os.Open(name)
+ require.NoError(t, err)
+ defer file.Close()
+
+ _, err = io.Copy(buf, file)
+ require.NoError(t, err)
+ want := buf.String()
+ buf.Reset()
+
+ encodeJSON(t, buf, obj)
+ first := buf.String()
+ buf.Reset()
+
+ assert.JSONEq(t, want, first)
+
+ target := reflect.New(reflect.TypeOf(obj).Elem()).Interface()
+
+ require.NoError(t,
+ json.Unmarshal([]byte(first), target),
+ )
+ second, err := json.Marshal(target)
+ require.NoError(t, err)
+
+ assert.JSONEq(t, want, string(second))
+ })
+ }
+}
diff --git a/pkg/oidc/regression_create_test.go b/pkg/oidc/regression_create_test.go
new file mode 100644
index 0000000..809fe60
--- /dev/null
+++ b/pkg/oidc/regression_create_test.go
@@ -0,0 +1,24 @@
+//go:build create_regression_data
+
+package oidc
+
+import (
+ "os"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// Test_create_regression generates the regression data.
+// It is excluded from regular testing, unless
+// called with the create_regression_data tag:
+// go test -tags="create_regression_data" ./pkg/oidc
+func Test_create_regression(t *testing.T) {
+ for _, obj := range regressionData {
+ file, err := os.Create(jsonFilename(obj))
+ require.NoError(t, err)
+ defer file.Close()
+
+ encodeJSON(t, file, obj)
+ }
+}
diff --git a/pkg/oidc/regression_data/oidc.AccessTokenClaims.json b/pkg/oidc/regression_data/oidc.AccessTokenClaims.json
new file mode 100644
index 0000000..b63bf30
--- /dev/null
+++ b/pkg/oidc/regression_data/oidc.AccessTokenClaims.json
@@ -0,0 +1,23 @@
+{
+ "iss": "zitadel",
+ "sub": "hello@me.com",
+ "aud": [
+ "foo",
+ "bar"
+ ],
+ "jti": "900",
+ "azp": "just@me.com",
+ "nonce": "6969",
+ "acr": "something",
+ "amr": [
+ "some",
+ "methods"
+ ],
+ "scope": "email phone",
+ "client_id": "777",
+ "exp": 12345,
+ "iat": 12000,
+ "nbf": 12000,
+ "auth_time": 12000,
+ "foo": "bar"
+}
diff --git a/pkg/oidc/regression_data/oidc.IDTokenClaims.json b/pkg/oidc/regression_data/oidc.IDTokenClaims.json
new file mode 100644
index 0000000..af503fb
--- /dev/null
+++ b/pkg/oidc/regression_data/oidc.IDTokenClaims.json
@@ -0,0 +1,51 @@
+{
+ "iss": "zitadel",
+ "aud": [
+ "foo",
+ "bar"
+ ],
+ "jti": "900",
+ "azp": "just@me.com",
+ "nonce": "6969",
+ "at_hash": "acthashhash",
+ "c_hash": "hashhash",
+ "acr": "something",
+ "amr": [
+ "some",
+ "methods"
+ ],
+ "sid": "666",
+ "client_id": "777",
+ "exp": 12345,
+ "iat": 12000,
+ "nbf": 12000,
+ "auth_time": 12000,
+ "address": {
+ "country": "Moon",
+ "formatted": "Sesame street 666\n666-666, Smallvile\nMoon",
+ "locality": "Smallvile",
+ "postal_code": "666-666",
+ "region": "Outer space",
+ "street_address": "Sesame street 666"
+ },
+ "birthdate": "1st of April",
+ "email": "tim@zitadel.com",
+ "email_verified": true,
+ "family_name": "MÃļhlmann",
+ "foo": "bar",
+ "gender": "male",
+ "given_name": "Tim",
+ "locale": "nl",
+ "middle_name": "Danger",
+ "name": "Tim MÃļhlmann",
+ "nickname": "muhlemmer",
+ "phone_number": "+1234567890",
+ "phone_number_verified": true,
+ "picture": "https://avatars.githubusercontent.com/u/5411563?v=4",
+ "preferred_username": "muhlemmer",
+ "profile": "https://github.com/muhlemmer",
+ "sub": "hello@me.com",
+ "updated_at": 1,
+ "website": "https://zitadel.com",
+ "zoneinfo": "Europe/Amsterdam"
+}
diff --git a/pkg/oidc/regression_data/oidc.IntrospectionResponse.json b/pkg/oidc/regression_data/oidc.IntrospectionResponse.json
new file mode 100644
index 0000000..e0c21a2
--- /dev/null
+++ b/pkg/oidc/regression_data/oidc.IntrospectionResponse.json
@@ -0,0 +1,44 @@
+{
+ "active": true,
+ "address": {
+ "country": "Moon",
+ "formatted": "Sesame street 666\n666-666, Smallvile\nMoon",
+ "locality": "Smallvile",
+ "postal_code": "666-666",
+ "region": "Outer space",
+ "street_address": "Sesame street 666"
+ },
+ "aud": [
+ "foo",
+ "bar"
+ ],
+ "birthdate": "1st of April",
+ "client_id": "777",
+ "email": "tim@zitadel.com",
+ "email_verified": true,
+ "exp": 12345,
+ "family_name": "MÃļhlmann",
+ "foo": "bar",
+ "gender": "male",
+ "given_name": "Tim",
+ "iat": 12000,
+ "iss": "zitadel",
+ "jti": "900",
+ "locale": "nl",
+ "middle_name": "Danger",
+ "name": "Tim MÃļhlmann",
+ "nbf": 12000,
+ "nickname": "muhlemmer",
+ "phone_number": "+1234567890",
+ "phone_number_verified": true,
+ "picture": "https://avatars.githubusercontent.com/u/5411563?v=4",
+ "preferred_username": "muhlemmer",
+ "profile": "https://github.com/muhlemmer",
+ "scope": "email phone",
+ "sub": "hello@me.com",
+ "token_type": "idtoken",
+ "updated_at": 1,
+ "username": "muhlemmer",
+ "website": "https://zitadel.com",
+ "zoneinfo": "Europe/Amsterdam"
+}
diff --git a/pkg/oidc/regression_data/oidc.JWTProfileAssertionClaims.json b/pkg/oidc/regression_data/oidc.JWTProfileAssertionClaims.json
new file mode 100644
index 0000000..4ece780
--- /dev/null
+++ b/pkg/oidc/regression_data/oidc.JWTProfileAssertionClaims.json
@@ -0,0 +1,11 @@
+{
+ "aud": [
+ "foo",
+ "bar"
+ ],
+ "exp": 12345,
+ "foo": "bar",
+ "iat": 12000,
+ "iss": "zitadel",
+ "sub": "hello@me.com"
+}
diff --git a/pkg/oidc/regression_data/oidc.UserInfo.json b/pkg/oidc/regression_data/oidc.UserInfo.json
new file mode 100644
index 0000000..d7795e7
--- /dev/null
+++ b/pkg/oidc/regression_data/oidc.UserInfo.json
@@ -0,0 +1,30 @@
+{
+ "address": {
+ "country": "Moon",
+ "formatted": "Sesame street 666\n666-666, Smallvile\nMoon",
+ "locality": "Smallvile",
+ "postal_code": "666-666",
+ "region": "Outer space",
+ "street_address": "Sesame street 666"
+ },
+ "birthdate": "1st of April",
+ "email": "tim@zitadel.com",
+ "email_verified": true,
+ "family_name": "MÃļhlmann",
+ "foo": "bar",
+ "gender": "male",
+ "given_name": "Tim",
+ "locale": "nl",
+ "middle_name": "Danger",
+ "name": "Tim MÃļhlmann",
+ "nickname": "muhlemmer",
+ "phone_number": "+1234567890",
+ "phone_number_verified": true,
+ "picture": "https://avatars.githubusercontent.com/u/5411563?v=4",
+ "preferred_username": "muhlemmer",
+ "profile": "https://github.com/muhlemmer",
+ "sub": "hello@me.com",
+ "updated_at": 1,
+ "website": "https://zitadel.com",
+ "zoneinfo": "Europe/Amsterdam"
+}
diff --git a/pkg/oidc/regression_test.go b/pkg/oidc/regression_test.go
new file mode 100644
index 0000000..9cb3ff9
--- /dev/null
+++ b/pkg/oidc/regression_test.go
@@ -0,0 +1,40 @@
+package oidc
+
+// This file contains common functions and data for regression testing
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "path"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+const dataDir = "regression_data"
+
+// jsonFilename builds a filename for the regression testdata.
+// dataDir/.json
+func jsonFilename(obj any) string {
+ name := fmt.Sprintf("%T.json", obj)
+ return path.Join(
+ dataDir,
+ strings.TrimPrefix(name, "*"),
+ )
+}
+
+func encodeJSON(t *testing.T, w io.Writer, obj any) {
+ enc := json.NewEncoder(w)
+ enc.SetIndent("", "\t")
+ require.NoError(t, enc.Encode(obj))
+}
+
+var regressionData = []any{
+ accessTokenData,
+ idTokenData,
+ introspectionResponseData,
+ userInfoData,
+ jwtProfileAssertionData,
+}
diff --git a/pkg/oidc/revocation.go b/pkg/oidc/revocation.go
new file mode 100644
index 0000000..0a56c61
--- /dev/null
+++ b/pkg/oidc/revocation.go
@@ -0,0 +1,6 @@
+package oidc
+
+type RevocationRequest struct {
+ Token string `schema:"token"`
+ TokenTypeHint string `schema:"token_type_hint"`
+}
diff --git a/pkg/oidc/session.go b/pkg/oidc/session.go
index 418439e..39f9f08 100644
--- a/pkg/oidc/session.go
+++ b/pkg/oidc/session.go
@@ -1,7 +1,12 @@
package oidc
+// EndSessionRequest for the RP-Initiated Logout according to:
+// https://openid.net/specs/openid-connect-rpinitiated-1_0.html#RPLogout
type EndSessionRequest struct {
- IdTokenHint string `schema:"id_token_hint"`
- PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"`
- State string `schema:"state"`
+ IdTokenHint string `schema:"id_token_hint"`
+ LogoutHint string `schema:"logout_hint"`
+ ClientID string `schema:"client_id"`
+ PostLogoutRedirectURI string `schema:"post_logout_redirect_uri"`
+ State string `schema:"state"`
+ UILocales Locales `schema:"ui_locales"`
}
diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go
index 8f2afc2..4b43dcb 100644
--- a/pkg/oidc/token.go
+++ b/pkg/oidc/token.go
@@ -2,248 +2,421 @@ package oidc
import (
"encoding/json"
- "strings"
+ "os"
"time"
- "github.com/caos/oidc/pkg/utils"
+ jose "github.com/go-jose/go-jose/v4"
"golang.org/x/oauth2"
- "golang.org/x/text/language"
- "gopkg.in/square/go-jose.v2"
+
+ "github.com/muhlemmer/gu"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto"
)
-type Tokens struct {
+const (
+ // BearerToken defines the token_type `Bearer`, which is returned in a successful token response
+ BearerToken = "Bearer"
+
+ PrefixBearer = BearerToken + " "
+)
+
+type Tokens[C IDClaims] struct {
*oauth2.Token
- IDTokenClaims *IDTokenClaims
+ IDTokenClaims C
IDToken string
}
+// TokenClaims contains the base Claims used all tokens.
+// It implements OpenID Connect Core 1.0, section 2.
+// https://openid.net/specs/openid-connect-core-1_0.html#IDToken
+// And RFC 9068: JSON Web Token (JWT) Profile for OAuth 2.0 Access Tokens,
+// section 2.2. https://datatracker.ietf.org/doc/html/rfc9068#name-data-structure
+//
+// TokenClaims implements the Claims interface,
+// and can be used to extend larger claim types by embedding.
+type TokenClaims struct {
+ Issuer string `json:"iss,omitempty"`
+ Subject string `json:"sub,omitempty"`
+ Audience Audience `json:"aud,omitempty"`
+ Expiration Time `json:"exp,omitempty"`
+ IssuedAt Time `json:"iat,omitempty"`
+ AuthTime Time `json:"auth_time,omitempty"`
+ NotBefore Time `json:"nbf,omitempty"`
+ Nonce string `json:"nonce,omitempty"`
+ AuthenticationContextClassReference string `json:"acr,omitempty"`
+ AuthenticationMethodsReferences []string `json:"amr,omitempty"`
+ AuthorizedParty string `json:"azp,omitempty"`
+ ClientID string `json:"client_id,omitempty"`
+ JWTID string `json:"jti,omitempty"`
+ Actor *ActorClaims `json:"act,omitempty"`
+
+ // Additional information set by this framework
+ SignatureAlg jose.SignatureAlgorithm `json:"-"`
+}
+
+func (c *TokenClaims) GetIssuer() string {
+ return c.Issuer
+}
+
+func (c *TokenClaims) GetSubject() string {
+ return c.Subject
+}
+
+func (c *TokenClaims) GetAudience() []string {
+ return c.Audience
+}
+
+func (c *TokenClaims) GetExpiration() time.Time {
+ return c.Expiration.AsTime()
+}
+
+func (c *TokenClaims) GetIssuedAt() time.Time {
+ return c.IssuedAt.AsTime()
+}
+
+func (c *TokenClaims) GetNonce() string {
+ return c.Nonce
+}
+
+func (c *TokenClaims) GetAuthTime() time.Time {
+ return c.AuthTime.AsTime()
+}
+
+func (c *TokenClaims) GetAuthorizedParty() string {
+ return c.AuthorizedParty
+}
+
+func (c *TokenClaims) GetSignatureAlgorithm() jose.SignatureAlgorithm {
+ return c.SignatureAlg
+}
+
+func (c *TokenClaims) GetAuthenticationContextClassReference() string {
+ return c.AuthenticationContextClassReference
+}
+
+func (c *TokenClaims) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {
+ c.SignatureAlg = algorithm
+}
+
type AccessTokenClaims struct {
- Issuer string
- Subject string
- Audiences []string
- Expiration time.Time
- IssuedAt time.Time
- NotBefore time.Time
- JWTID string
- AuthorizedParty string
- Nonce string
- AuthTime time.Time
- CodeHash string
- AuthenticationContextClassReference string
- AuthenticationMethodsReferences []string
- SessionID string
- Scopes []string
- ClientID string
- AccessTokenUseNumber int
+ TokenClaims
+ Scopes SpaceDelimitedArray `json:"scope,omitempty"`
+ Claims map[string]any `json:"-"`
}
+func NewAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, jwtid, clientID string, skew time.Duration) *AccessTokenClaims {
+ now := time.Now().UTC().Add(-skew)
+ if len(audience) == 0 {
+ audience = append(audience, clientID)
+ }
+ return &AccessTokenClaims{
+ TokenClaims: TokenClaims{
+ Issuer: issuer,
+ Subject: subject,
+ Audience: audience,
+ Expiration: FromTime(expiration),
+ IssuedAt: FromTime(now),
+ NotBefore: FromTime(now),
+ ClientID: clientID,
+ JWTID: jwtid,
+ },
+ }
+}
+
+type atcAlias AccessTokenClaims
+
+func (a *AccessTokenClaims) MarshalJSON() ([]byte, error) {
+ return mergeAndMarshalClaims((*atcAlias)(a), a.Claims)
+}
+
+func (a *AccessTokenClaims) UnmarshalJSON(data []byte) error {
+ return unmarshalJSONMulti(data, (*atcAlias)(a), &a.Claims)
+}
+
+// IDTokenClaims extends TokenClaims by further implementing
+// OpenID Connect Core 1.0, sections 3.1.3.6 (Code flow),
+// 3.2.2.10 (implicit), 3.3.2.11 (Hybrid) and 5.1 (UserInfo).
+// https://openid.net/specs/openid-connect-core-1_0.html#toc
type IDTokenClaims struct {
- Issuer string
- Audiences []string
- Expiration time.Time
- NotBefore time.Time
- IssuedAt time.Time
- JWTID string
- UpdatedAt time.Time
- AuthorizedParty string
- Nonce string
- AuthTime time.Time
- AccessTokenHash string
- CodeHash string
- AuthenticationContextClassReference string
- AuthenticationMethodsReferences []string
- ClientID string
- Userinfo
-
- Signature jose.SignatureAlgorithm //TODO: ???
+ TokenClaims
+ NotBefore Time `json:"nbf,omitempty"`
+ AccessTokenHash string `json:"at_hash,omitempty"`
+ CodeHash string `json:"c_hash,omitempty"`
+ SessionID string `json:"sid,omitempty"`
+ UserInfoProfile
+ UserInfoEmail
+ UserInfoPhone
+ Address *UserInfoAddress `json:"address,omitempty"`
+ Claims map[string]any `json:"-"`
}
-type jsonToken struct {
- Issuer string `json:"iss,omitempty"`
- Subject string `json:"sub,omitempty"`
- Audiences []string `json:"aud,omitempty"`
- Expiration int64 `json:"exp,omitempty"`
- NotBefore int64 `json:"nbf,omitempty"`
- IssuedAt int64 `json:"iat,omitempty"`
- JWTID string `json:"jti,omitempty"`
- AuthorizedParty string `json:"azp,omitempty"`
- Nonce string `json:"nonce,omitempty"`
- AuthTime int64 `json:"auth_time,omitempty"`
- AccessTokenHash string `json:"at_hash,omitempty"`
- CodeHash string `json:"c_hash,omitempty"`
- AuthenticationContextClassReference string `json:"acr,omitempty"`
- AuthenticationMethodsReferences []string `json:"amr,omitempty"`
- SessionID string `json:"sid,omitempty"`
- Actor interface{} `json:"act,omitempty"` //TODO: impl
- Scopes string `json:"scope,omitempty"`
- ClientID string `json:"client_id,omitempty"`
- AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
- AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
- jsonUserinfo
+// GetAccessTokenHash implements the IDTokenClaims interface
+func (t *IDTokenClaims) GetAccessTokenHash() string {
+ return t.AccessTokenHash
}
-func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) {
- j := jsonToken{
- Issuer: t.Issuer,
- Subject: t.Subject,
- Audiences: t.Audiences,
- Expiration: timeToJSON(t.Expiration),
- NotBefore: timeToJSON(t.NotBefore),
- IssuedAt: timeToJSON(t.IssuedAt),
- JWTID: t.JWTID,
- AuthorizedParty: t.AuthorizedParty,
- Nonce: t.Nonce,
- AuthTime: timeToJSON(t.AuthTime),
- CodeHash: t.CodeHash,
- AuthenticationContextClassReference: t.AuthenticationContextClassReference,
- AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
- SessionID: t.SessionID,
- Scopes: strings.Join(t.Scopes, " "),
- ClientID: t.ClientID,
- AccessTokenUseNumber: t.AccessTokenUseNumber,
- }
- return json.Marshal(j)
-}
-
-func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error {
- var j jsonToken
- if err := json.Unmarshal(b, &j); err != nil {
- return err
- }
- audience := j.Audiences
- if len(audience) == 1 {
- audience = strings.Split(audience[0], " ")
- }
- t.Issuer = j.Issuer
- t.Subject = j.Subject
- t.Audiences = audience
- t.Expiration = time.Unix(j.Expiration, 0).UTC()
- t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
- t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
- t.JWTID = j.JWTID
- t.AuthorizedParty = j.AuthorizedParty
- t.Nonce = j.Nonce
- t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
- t.CodeHash = j.CodeHash
- t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
- t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
- t.SessionID = j.SessionID
- t.Scopes = strings.Split(j.Scopes, " ")
- t.ClientID = j.ClientID
- t.AccessTokenUseNumber = j.AccessTokenUseNumber
- return nil
-}
-
-func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
- j := jsonToken{
- Issuer: t.Issuer,
- Subject: t.Subject,
- Audiences: t.Audiences,
- Expiration: timeToJSON(t.Expiration),
- NotBefore: timeToJSON(t.NotBefore),
- IssuedAt: timeToJSON(t.IssuedAt),
- JWTID: t.JWTID,
- AuthorizedParty: t.AuthorizedParty,
- Nonce: t.Nonce,
- AuthTime: timeToJSON(t.AuthTime),
- AccessTokenHash: t.AccessTokenHash,
- CodeHash: t.CodeHash,
- AuthenticationContextClassReference: t.AuthenticationContextClassReference,
- AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
- ClientID: t.ClientID,
- }
- j.setUserinfo(t.Userinfo)
- return json.Marshal(j)
-}
-
-func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
- var i jsonToken
- if err := json.Unmarshal(b, &i); err != nil {
- return err
- }
- audience := i.Audiences
- if len(audience) == 1 {
- audience = strings.Split(audience[0], " ")
- }
- t.Issuer = i.Issuer
+func (t *IDTokenClaims) SetUserInfo(i *UserInfo) {
t.Subject = i.Subject
- t.Audiences = audience
- t.Expiration = time.Unix(i.Expiration, 0).UTC()
- t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC()
- t.AuthTime = time.Unix(i.AuthTime, 0).UTC()
- t.Nonce = i.Nonce
- t.AuthenticationContextClassReference = i.AuthenticationContextClassReference
- t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
- t.AuthorizedParty = i.AuthorizedParty
- t.AccessTokenHash = i.AccessTokenHash
- t.CodeHash = i.CodeHash
- t.UserinfoProfile = i.UnmarshalUserinfoProfile()
- t.UserinfoEmail = i.UnmarshalUserinfoEmail()
- t.UserinfoPhone = i.UnmarshalUserinfoPhone()
- t.Address = i.UnmarshalUserinfoAddress()
- return nil
+ t.UserInfoProfile = i.UserInfoProfile
+ t.UserInfoEmail = i.UserInfoEmail
+ t.UserInfoPhone = i.UserInfoPhone
+ t.Address = i.Address
+ if t.Claims == nil {
+ t.Claims = make(map[string]any, len(t.Claims))
+ }
+ gu.MapMerge(i.Claims, t.Claims)
}
-func (j *jsonToken) UnmarshalUserinfoProfile() UserinfoProfile {
- locale, _ := language.Parse(j.Locale)
- return UserinfoProfile{
- Name: j.Name,
- GivenName: j.GivenName,
- FamilyName: j.FamilyName,
- MiddleName: j.MiddleName,
- Nickname: j.Nickname,
- Profile: j.Profile,
- Picture: j.Picture,
- Website: j.Website,
- Gender: Gender(j.Gender),
- Birthdate: j.Birthdate,
- Zoneinfo: j.Zoneinfo,
- Locale: locale,
- UpdatedAt: time.Unix(j.UpdatedAt, 0).UTC(),
- PreferredUsername: j.PreferredUsername,
+func (t *IDTokenClaims) GetUserInfo() *UserInfo {
+ return &UserInfo{
+ Subject: t.Subject,
+ UserInfoProfile: t.UserInfoProfile,
+ UserInfoEmail: t.UserInfoEmail,
+ UserInfoPhone: t.UserInfoPhone,
+ Address: t.Address,
+ Claims: gu.MapCopy(t.Claims),
}
}
-func (j *jsonToken) UnmarshalUserinfoEmail() UserinfoEmail {
- return UserinfoEmail{
- Email: j.Email,
- EmailVerified: j.EmailVerified,
+func NewIDTokenClaims(issuer, subject string, audience []string, expiration, authTime time.Time, nonce string, acr string, amr []string, clientID string, skew time.Duration) *IDTokenClaims {
+ audience = AppendClientIDToAudience(clientID, audience)
+ return &IDTokenClaims{
+ TokenClaims: TokenClaims{
+ Issuer: issuer,
+ Subject: subject,
+ Audience: audience,
+ Expiration: FromTime(expiration),
+ IssuedAt: FromTime(time.Now().Add(-skew)),
+ AuthTime: FromTime(authTime.Add(-skew)),
+ Nonce: nonce,
+ AuthenticationContextClassReference: acr,
+ AuthenticationMethodsReferences: amr,
+ AuthorizedParty: clientID,
+ ClientID: clientID,
+ },
}
}
-func (j *jsonToken) UnmarshalUserinfoPhone() UserinfoPhone {
- return UserinfoPhone{
- PhoneNumber: j.Phone,
- PhoneNumberVerified: j.PhoneVerified,
+type itcAlias IDTokenClaims
+
+func (i *IDTokenClaims) MarshalJSON() ([]byte, error) {
+ return mergeAndMarshalClaims((*itcAlias)(i), i.Claims)
+}
+
+func (i *IDTokenClaims) UnmarshalJSON(data []byte) error {
+ return unmarshalJSONMulti(data, (*itcAlias)(i), &i.Claims)
+}
+
+// ActorClaims provides the `act` claims used for impersonation or delegation Token Exchange.
+//
+// An actor can be nested in case an obtained token is used as actor token to obtain impersonation or delegation.
+// This allows creating a chain of actors.
+// See [RFC 8693, section 4.1](https://www.rfc-editor.org/rfc/rfc8693#name-act-actor-claim).
+type ActorClaims struct {
+ Actor *ActorClaims `json:"act,omitempty"`
+ Issuer string `json:"iss,omitempty"`
+ Subject string `json:"sub,omitempty"`
+ Claims map[string]any `json:"-"`
+}
+
+type acAlias ActorClaims
+
+func (c *ActorClaims) MarshalJSON() ([]byte, error) {
+ return mergeAndMarshalClaims((*acAlias)(c), c.Claims)
+}
+
+func (c *ActorClaims) UnmarshalJSON(data []byte) error {
+ return unmarshalJSONMulti(data, (*acAlias)(c), &c.Claims)
+}
+
+type AccessTokenResponse struct {
+ AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
+ TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
+ ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
+ IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
+ State string `json:"state,omitempty" schema:"state,omitempty"`
+ Scope SpaceDelimitedArray `json:"scope,omitempty" schema:"scope,omitempty"`
+}
+
+type JWTProfileAssertionClaims struct {
+ PrivateKeyID string `json:"-"`
+ PrivateKey []byte `json:"-"`
+ Issuer string `json:"iss"`
+ Subject string `json:"sub"`
+ Audience Audience `json:"aud"`
+ Expiration Time `json:"exp"`
+ IssuedAt Time `json:"iat"`
+
+ Claims map[string]any `json:"-"`
+}
+
+type jpaAlias JWTProfileAssertionClaims
+
+func (j *JWTProfileAssertionClaims) MarshalJSON() ([]byte, error) {
+ return mergeAndMarshalClaims((*jpaAlias)(j), j.Claims)
+}
+
+func (j *JWTProfileAssertionClaims) UnmarshalJSON(data []byte) error {
+ return unmarshalJSONMulti(data, (*jpaAlias)(j), &j.Claims)
+}
+
+func NewJWTProfileAssertionFromKeyJSON(filename string, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) {
+ data, err := os.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ return NewJWTProfileAssertionFromFileData(data, audience, opts...)
+}
+
+func NewJWTProfileAssertionStringFromFileData(data []byte, audience []string, opts ...AssertionOption) (string, error) {
+ keyData := new(struct {
+ KeyID string `json:"keyId"`
+ Key string `json:"key"`
+ UserID string `json:"userId"`
+ })
+ err := json.Unmarshal(data, keyData)
+ if err != nil {
+ return "", err
+ }
+ return GenerateJWTProfileToken(NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...))
+}
+
+func JWTProfileDelegatedSubject(sub string) func(*JWTProfileAssertionClaims) {
+ return func(j *JWTProfileAssertionClaims) {
+ j.Subject = sub
}
}
-func (j *jsonToken) UnmarshalUserinfoAddress() *UserinfoAddress {
- if j.JsonUserinfoAddress == nil {
- return nil
+func JWTProfileCustomClaim(key string, value any) func(*JWTProfileAssertionClaims) {
+ return func(j *JWTProfileAssertionClaims) {
+ j.Claims[key] = value
}
- return &UserinfoAddress{
- Country: j.JsonUserinfoAddress.Country,
- Formatted: j.JsonUserinfoAddress.Formatted,
- Locality: j.JsonUserinfoAddress.Locality,
- PostalCode: j.JsonUserinfoAddress.PostalCode,
- Region: j.JsonUserinfoAddress.Region,
- StreetAddress: j.JsonUserinfoAddress.StreetAddress,
+}
+
+func NewJWTProfileAssertionFromFileData(data []byte, audience []string, opts ...AssertionOption) (*JWTProfileAssertionClaims, error) {
+ keyData := new(struct {
+ KeyID string `json:"keyId"`
+ Key string `json:"key"`
+ UserID string `json:"userId"`
+ })
+ err := json.Unmarshal(data, keyData)
+ if err != nil {
+ return nil, err
}
+ return NewJWTProfileAssertion(keyData.UserID, keyData.KeyID, audience, []byte(keyData.Key), opts...), nil
+}
+
+type AssertionOption func(*JWTProfileAssertionClaims)
+
+func NewJWTProfileAssertion(userID, keyID string, audience []string, key []byte, opts ...AssertionOption) *JWTProfileAssertionClaims {
+ j := &JWTProfileAssertionClaims{
+ PrivateKey: key,
+ PrivateKeyID: keyID,
+ Issuer: userID,
+ Subject: userID,
+ IssuedAt: FromTime(time.Now().UTC()),
+ Expiration: FromTime(time.Now().Add(1 * time.Hour).UTC()),
+ Audience: audience,
+ Claims: make(map[string]any),
+ }
+
+ for _, opt := range opts {
+ opt(j)
+ }
+
+ return j
}
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
- hash, err := utils.GetHashAlgorithm(sigAlgorithm)
+ hash, err := crypto.GetHashAlgorithm(sigAlgorithm)
if err != nil {
return "", err
}
- return utils.HashString(hash, claim, true), nil
+ return crypto.HashString(hash, claim, true), nil
}
-func timeToJSON(t time.Time) int64 {
- if t.IsZero() {
- return 0
+func AppendClientIDToAudience(clientID string, audience []string) []string {
+ for _, aud := range audience {
+ if aud == clientID {
+ return audience
+ }
+ }
+ return append(audience, clientID)
+}
+
+func GenerateJWTProfileToken(assertion *JWTProfileAssertionClaims) (string, error) {
+ privateKey, algorithm, err := crypto.BytesToPrivateKey(assertion.PrivateKey)
+ if err != nil {
+ return "", err
+ }
+ key := jose.SigningKey{
+ Algorithm: algorithm,
+ Key: &jose.JSONWebKey{Key: privateKey, KeyID: assertion.PrivateKeyID},
+ }
+ signer, err := jose.NewSigner(key, &jose.SignerOptions{})
+ if err != nil {
+ return "", err
+ }
+
+ marshalledAssertion, err := json.Marshal(assertion)
+ if err != nil {
+ return "", err
+ }
+ signedAssertion, err := signer.Sign(marshalledAssertion)
+ if err != nil {
+ return "", err
+ }
+ return signedAssertion.CompactSerialize()
+}
+
+type TokenExchangeResponse struct {
+ AccessToken string `json:"access_token"` // Can be access token or ID token
+ IssuedTokenType TokenType `json:"issued_token_type"`
+ TokenType string `json:"token_type"`
+ ExpiresIn uint64 `json:"expires_in,omitempty"`
+ Scopes SpaceDelimitedArray `json:"scope,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty"`
+
+ // IDToken field allows returning an additional ID token
+ // if the requested_token_type was Access Token and scope contained openid.
+ IDToken string `json:"id_token,omitempty"`
+}
+
+type LogoutTokenClaims struct {
+ Issuer string `json:"iss,omitempty"`
+ Subject string `json:"sub,omitempty"`
+ Audience Audience `json:"aud,omitempty"`
+ IssuedAt Time `json:"iat,omitempty"`
+ Expiration Time `json:"exp,omitempty"`
+ JWTID string `json:"jti,omitempty"`
+ Events map[string]any `json:"events,omitempty"`
+ SessionID string `json:"sid,omitempty"`
+ Claims map[string]any `json:"-"`
+}
+
+type ltcAlias LogoutTokenClaims
+
+func (i *LogoutTokenClaims) MarshalJSON() ([]byte, error) {
+ return mergeAndMarshalClaims((*ltcAlias)(i), i.Claims)
+}
+
+func (i *LogoutTokenClaims) UnmarshalJSON(data []byte) error {
+ return unmarshalJSONMulti(data, (*ltcAlias)(i), &i.Claims)
+}
+
+func NewLogoutTokenClaims(issuer, subject string, audience Audience, expiration time.Time, jwtID, sessionID string, skew time.Duration) *LogoutTokenClaims {
+ return &LogoutTokenClaims{
+ Issuer: issuer,
+ Subject: subject,
+ Audience: audience,
+ IssuedAt: FromTime(time.Now().Add(-skew)),
+ Expiration: FromTime(expiration),
+ JWTID: jwtID,
+ Events: map[string]any{
+ "http://schemas.openid.net/event/backchannel-logout": struct{}{},
+ },
+ SessionID: sessionID,
}
- return t.Unix()
}
diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go
new file mode 100644
index 0000000..dadb205
--- /dev/null
+++ b/pkg/oidc/token_request.go
@@ -0,0 +1,245 @@
+package oidc
+
+import (
+ "encoding/json"
+ "fmt"
+ "slices"
+ "time"
+
+ jose "github.com/go-jose/go-jose/v4"
+)
+
+const (
+ // GrantTypeCode defines the grant_type `authorization_code` used for the Token Request in the Authorization Code Flow
+ GrantTypeCode GrantType = "authorization_code"
+
+ // GrantTypeRefreshToken defines the grant_type `refresh_token` used for the Token Request in the Refresh Token Flow
+ GrantTypeRefreshToken GrantType = "refresh_token"
+
+ // GrantTypeClientCredentials defines the grant_type `client_credentials` used for the Token Request in the Client Credentials Token Flow
+ GrantTypeClientCredentials GrantType = "client_credentials"
+
+ // GrantTypeBearer defines the grant_type `urn:ietf:params:oauth:grant-type:jwt-bearer` used for the JWT Authorization Grant
+ GrantTypeBearer GrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
+
+ // GrantTypeTokenExchange defines the grant_type `urn:ietf:params:oauth:grant-type:token-exchange` used for the OAuth Token Exchange Grant
+ GrantTypeTokenExchange GrantType = "urn:ietf:params:oauth:grant-type:token-exchange"
+
+ // GrantTypeImplicit defines the grant type `implicit` used for implicit flows that skip the generation and exchange of an Authorization Code
+ GrantTypeImplicit GrantType = "implicit"
+
+ // GrantTypeDeviceCode
+ GrantTypeDeviceCode GrantType = "urn:ietf:params:oauth:grant-type:device_code"
+
+ // ClientAssertionTypeJWTAssertion defines the client_assertion_type `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`
+ // used for the OAuth JWT Profile Client Authentication
+ ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
+)
+
+var AllGrantTypes = []GrantType{
+ GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials,
+ GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit,
+ GrantTypeDeviceCode, ClientAssertionTypeJWTAssertion,
+}
+
+type GrantType string
+
+const (
+ AccessTokenType TokenType = "urn:ietf:params:oauth:token-type:access_token"
+ RefreshTokenType TokenType = "urn:ietf:params:oauth:token-type:refresh_token"
+ IDTokenType TokenType = "urn:ietf:params:oauth:token-type:id_token"
+ JWTTokenType TokenType = "urn:ietf:params:oauth:token-type:jwt"
+)
+
+var AllTokenTypes = []TokenType{
+ AccessTokenType, RefreshTokenType, IDTokenType, JWTTokenType,
+}
+
+type TokenType string
+
+func (t TokenType) IsSupported() bool {
+ return slices.Contains(AllTokenTypes, t)
+}
+
+type TokenRequest interface {
+ // GrantType GrantType `schema:"grant_type"`
+ GrantType() GrantType
+}
+
+type TokenRequestType GrantType
+
+type AccessTokenRequest struct {
+ Code string `schema:"code"`
+ RedirectURI string `schema:"redirect_uri"`
+ ClientID string `schema:"client_id"`
+ ClientSecret string `schema:"client_secret,omitempty"`
+ CodeVerifier string `schema:"code_verifier,omitempty"`
+ ClientAssertion string `schema:"client_assertion,omitempty"`
+ ClientAssertionType string `schema:"client_assertion_type,omitempty"`
+}
+
+func (a *AccessTokenRequest) GrantType() GrantType {
+ return GrantTypeCode
+}
+
+// SetClientID implements op.AuthenticatedTokenRequest
+func (a *AccessTokenRequest) SetClientID(clientID string) {
+ a.ClientID = clientID
+}
+
+// SetClientSecret implements op.AuthenticatedTokenRequest
+func (a *AccessTokenRequest) SetClientSecret(clientSecret string) {
+ a.ClientSecret = clientSecret
+}
+
+// RefreshTokenRequest is not useful for making refresh requests because the
+// grant_type is not included explicitly but rather implied.
+type RefreshTokenRequest struct {
+ RefreshToken string `schema:"refresh_token"`
+ Scopes SpaceDelimitedArray `schema:"scope"`
+ ClientID string `schema:"client_id"`
+ ClientSecret string `schema:"client_secret"`
+ ClientAssertion string `schema:"client_assertion"`
+ ClientAssertionType string `schema:"client_assertion_type"`
+}
+
+func (a *RefreshTokenRequest) GrantType() GrantType {
+ return GrantTypeRefreshToken
+}
+
+// SetClientID implements op.AuthenticatedTokenRequest
+func (a *RefreshTokenRequest) SetClientID(clientID string) {
+ a.ClientID = clientID
+}
+
+// SetClientSecret implements op.AuthenticatedTokenRequest
+func (a *RefreshTokenRequest) SetClientSecret(clientSecret string) {
+ a.ClientSecret = clientSecret
+}
+
+type JWTTokenRequest struct {
+ Issuer string `json:"iss"`
+ Subject string `json:"sub"`
+ Scopes SpaceDelimitedArray `json:"-"`
+ Audience Audience `json:"aud"`
+ IssuedAt Time `json:"iat"`
+ ExpiresAt Time `json:"exp"`
+
+ private map[string]any
+}
+
+func (j *JWTTokenRequest) MarshalJSON() ([]byte, error) {
+ type Alias JWTTokenRequest
+ a := (*Alias)(j)
+
+ b, err := json.Marshal(a)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(j.private) == 0 {
+ return b, nil
+ }
+
+ err = json.Unmarshal(b, &j.private)
+ if err != nil {
+ return nil, fmt.Errorf("jws: invalid map of custom claims %v", j.private)
+ }
+
+ return json.Marshal(j.private)
+}
+
+func (j *JWTTokenRequest) UnmarshalJSON(data []byte) error {
+ type Alias JWTTokenRequest
+ a := (*Alias)(j)
+
+ err := json.Unmarshal(data, a)
+ if err != nil {
+ return err
+ }
+
+ err = json.Unmarshal(data, &j.private)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (j *JWTTokenRequest) GetCustomClaim(key string) any {
+ return j.private[key]
+}
+
+// GetIssuer implements the Claims interface
+func (j *JWTTokenRequest) GetIssuer() string {
+ return j.Issuer
+}
+
+// GetAudience implements the Claims and TokenRequest interfaces
+func (j *JWTTokenRequest) GetAudience() []string {
+ return j.Audience
+}
+
+// GetExpiration implements the Claims interface
+func (j *JWTTokenRequest) GetExpiration() time.Time {
+ return j.ExpiresAt.AsTime()
+}
+
+// GetIssuedAt implements the Claims interface
+func (j *JWTTokenRequest) GetIssuedAt() time.Time {
+ return j.IssuedAt.AsTime()
+}
+
+// GetNonce implements the Claims interface
+func (j *JWTTokenRequest) GetNonce() string {
+ return ""
+}
+
+// GetAuthenticationContextClassReference implements the Claims interface
+func (j *JWTTokenRequest) GetAuthenticationContextClassReference() string {
+ return ""
+}
+
+// GetAuthTime implements the Claims interface
+func (j *JWTTokenRequest) GetAuthTime() time.Time {
+ return time.Time{}
+}
+
+// GetAuthorizedParty implements the Claims interface
+func (j *JWTTokenRequest) GetAuthorizedParty() string {
+ return ""
+}
+
+// SetSignatureAlgorithm implements the Claims interface
+func (j *JWTTokenRequest) SetSignatureAlgorithm(_ jose.SignatureAlgorithm) {}
+
+// GetSubject implements the TokenRequest interface
+func (j *JWTTokenRequest) GetSubject() string {
+ return j.Subject
+}
+
+// GetScopes implements the TokenRequest interface
+func (j *JWTTokenRequest) GetScopes() []string {
+ return j.Scopes
+}
+
+type TokenExchangeRequest struct {
+ GrantType GrantType `schema:"grant_type"`
+ SubjectToken string `schema:"subject_token"`
+ SubjectTokenType TokenType `schema:"subject_token_type"`
+ ActorToken string `schema:"actor_token"`
+ ActorTokenType TokenType `schema:"actor_token_type"`
+ Resource []string `schema:"resource"`
+ Audience Audience `schema:"audience"`
+ Scopes SpaceDelimitedArray `schema:"scope"`
+ RequestedTokenType TokenType `schema:"requested_token_type"`
+}
+
+type ClientCredentialsRequest struct {
+ GrantType GrantType `schema:"grant_type,omitempty"`
+ Scope SpaceDelimitedArray `schema:"scope"`
+ ClientID string `schema:"client_id"`
+ ClientSecret string `schema:"client_secret"`
+ ClientAssertion string `schema:"client_assertion"`
+ ClientAssertionType string `schema:"client_assertion_type"`
+}
diff --git a/pkg/oidc/token_test.go b/pkg/oidc/token_test.go
new file mode 100644
index 0000000..621cdbc
--- /dev/null
+++ b/pkg/oidc/token_test.go
@@ -0,0 +1,280 @@
+package oidc
+
+import (
+ "testing"
+ "time"
+
+ jose "github.com/go-jose/go-jose/v4"
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/text/language"
+)
+
+var (
+ tokenClaimsData = TokenClaims{
+ Issuer: "zitadel",
+ Subject: "hello@me.com",
+ Audience: Audience{"foo", "bar"},
+ Expiration: 12345,
+ IssuedAt: 12000,
+ JWTID: "900",
+ AuthorizedParty: "just@me.com",
+ Nonce: "6969",
+ AuthTime: 12000,
+ NotBefore: 12000,
+ AuthenticationContextClassReference: "something",
+ AuthenticationMethodsReferences: []string{"some", "methods"},
+ ClientID: "777",
+ SignatureAlg: jose.ES256,
+ }
+ accessTokenData = &AccessTokenClaims{
+ TokenClaims: tokenClaimsData,
+ Scopes: []string{"email", "phone"},
+ Claims: map[string]any{
+ "foo": "bar",
+ },
+ }
+ idTokenData = &IDTokenClaims{
+ TokenClaims: tokenClaimsData,
+ NotBefore: 12000,
+ AccessTokenHash: "acthashhash",
+ CodeHash: "hashhash",
+ SessionID: "666",
+ UserInfoProfile: userInfoData.UserInfoProfile,
+ UserInfoEmail: userInfoData.UserInfoEmail,
+ UserInfoPhone: userInfoData.UserInfoPhone,
+ Address: userInfoData.Address,
+ Claims: map[string]any{
+ "foo": "bar",
+ },
+ }
+ introspectionResponseData = &IntrospectionResponse{
+ Active: true,
+ Scope: SpaceDelimitedArray{"email", "phone"},
+ ClientID: "777",
+ TokenType: "idtoken",
+ Expiration: 12345,
+ IssuedAt: 12000,
+ NotBefore: 12000,
+ Subject: "hello@me.com",
+ Audience: Audience{"foo", "bar"},
+ Issuer: "zitadel",
+ JWTID: "900",
+ Username: "muhlemmer",
+ UserInfoProfile: userInfoData.UserInfoProfile,
+ UserInfoEmail: userInfoData.UserInfoEmail,
+ UserInfoPhone: userInfoData.UserInfoPhone,
+ Address: userInfoData.Address,
+ Claims: map[string]any{
+ "foo": "bar",
+ },
+ }
+ userInfoData = &UserInfo{
+ Subject: "hello@me.com",
+ UserInfoProfile: UserInfoProfile{
+ Name: "Tim MÃļhlmann",
+ GivenName: "Tim",
+ FamilyName: "MÃļhlmann",
+ MiddleName: "Danger",
+ Nickname: "muhlemmer",
+ Profile: "https://github.com/muhlemmer",
+ Picture: "https://avatars.githubusercontent.com/u/5411563?v=4",
+ Website: "https://zitadel.com",
+ Gender: "male",
+ Birthdate: "1st of April",
+ Zoneinfo: "Europe/Amsterdam",
+ Locale: NewLocale(language.Dutch),
+ UpdatedAt: 1,
+ PreferredUsername: "muhlemmer",
+ },
+ UserInfoEmail: UserInfoEmail{
+ Email: "tim@zitadel.com",
+ EmailVerified: true,
+ },
+ UserInfoPhone: UserInfoPhone{
+ PhoneNumber: "+1234567890",
+ PhoneNumberVerified: true,
+ },
+ Address: &UserInfoAddress{
+ Formatted: "Sesame street 666\n666-666, Smallvile\nMoon",
+ StreetAddress: "Sesame street 666",
+ Locality: "Smallvile",
+ Region: "Outer space",
+ PostalCode: "666-666",
+ Country: "Moon",
+ },
+ Claims: map[string]any{
+ "foo": "bar",
+ },
+ }
+ jwtProfileAssertionData = &JWTProfileAssertionClaims{
+ PrivateKeyID: "8888",
+ PrivateKey: []byte("qwerty"),
+ Issuer: "zitadel",
+ Subject: "hello@me.com",
+ Audience: Audience{"foo", "bar"},
+ Expiration: 12345,
+ IssuedAt: 12000,
+ Claims: map[string]any{
+ "foo": "bar",
+ },
+ }
+)
+
+func TestTokenClaims(t *testing.T) {
+ claims := tokenClaimsData
+
+ assert.Equal(t, claims.Issuer, tokenClaimsData.GetIssuer())
+ assert.Equal(t, claims.Subject, tokenClaimsData.GetSubject())
+ assert.Equal(t, []string(claims.Audience), tokenClaimsData.GetAudience())
+ assert.Equal(t, claims.Expiration.AsTime(), tokenClaimsData.GetExpiration())
+ assert.Equal(t, claims.IssuedAt.AsTime(), tokenClaimsData.GetIssuedAt())
+ assert.Equal(t, claims.Nonce, tokenClaimsData.GetNonce())
+ assert.Equal(t, claims.AuthTime.AsTime(), tokenClaimsData.GetAuthTime())
+ assert.Equal(t, claims.AuthorizedParty, tokenClaimsData.GetAuthorizedParty())
+ assert.Equal(t, claims.SignatureAlg, tokenClaimsData.GetSignatureAlgorithm())
+ assert.Equal(t, claims.AuthenticationContextClassReference, tokenClaimsData.GetAuthenticationContextClassReference())
+
+ claims.SetSignatureAlgorithm(jose.ES384)
+ assert.Equal(t, jose.ES384, claims.SignatureAlg)
+}
+
+func TestNewAccessTokenClaims(t *testing.T) {
+ want := &AccessTokenClaims{
+ TokenClaims: TokenClaims{
+ Issuer: "zitadel",
+ Subject: "hello@me.com",
+ Audience: Audience{"foo"},
+ Expiration: 12345,
+ ClientID: "foo",
+ JWTID: "900",
+ },
+ }
+
+ got := NewAccessTokenClaims(
+ want.Issuer, want.Subject, nil,
+ want.Expiration.AsTime(), want.JWTID, "foo", time.Second,
+ )
+
+ // test if the dynamic timestamps are around now,
+ // allowing for a delta of 1, just in case we flip on
+ // either side of a second boundry.
+ nowMinusSkew := NowTime() - 1
+ assert.InDelta(t, int64(nowMinusSkew), int64(got.IssuedAt), 1)
+ assert.InDelta(t, int64(nowMinusSkew), int64(got.NotBefore), 1)
+
+ // Make equal not fail on dynamic timestamp
+ got.IssuedAt = 0
+ got.NotBefore = 0
+
+ assert.Equal(t, want, got)
+}
+
+func TestIDTokenClaims_GetAccessTokenHash(t *testing.T) {
+ assert.Equal(t, idTokenData.AccessTokenHash, idTokenData.GetAccessTokenHash())
+}
+
+func TestIDTokenClaims_SetUserInfo(t *testing.T) {
+ want := IDTokenClaims{
+ TokenClaims: TokenClaims{
+ Subject: userInfoData.Subject,
+ },
+ UserInfoProfile: userInfoData.UserInfoProfile,
+ UserInfoEmail: userInfoData.UserInfoEmail,
+ UserInfoPhone: userInfoData.UserInfoPhone,
+ Address: userInfoData.Address,
+ Claims: map[string]any{
+ "foo": "bar",
+ },
+ }
+
+ var got IDTokenClaims
+ got.SetUserInfo(userInfoData)
+
+ assert.Equal(t, want, got)
+}
+
+func TestNewIDTokenClaims(t *testing.T) {
+ want := &IDTokenClaims{
+ TokenClaims: TokenClaims{
+ Issuer: "zitadel",
+ Subject: "hello@me.com",
+ Audience: Audience{"foo", "just@me.com"},
+ Expiration: 12345,
+ AuthTime: 12000,
+ Nonce: "6969",
+ AuthenticationContextClassReference: "something",
+ AuthenticationMethodsReferences: []string{"some", "methods"},
+ AuthorizedParty: "just@me.com",
+ ClientID: "just@me.com",
+ },
+ }
+
+ got := NewIDTokenClaims(
+ want.Issuer, want.Subject, want.Audience,
+ want.Expiration.AsTime(),
+ want.AuthTime.AsTime().Add(time.Second),
+ want.Nonce, want.AuthenticationContextClassReference,
+ want.AuthenticationMethodsReferences, want.AuthorizedParty,
+ time.Second,
+ )
+
+ // test if the dynamic timestamp is around now,
+ // allowing for a delta of 1, just in case we flip on
+ // either side of a second boundry.
+ nowMinusSkew := NowTime() - 1
+ assert.InDelta(t, int64(nowMinusSkew), int64(got.IssuedAt), 1)
+
+ // Make equal not fail on dynamic timestamp
+ got.IssuedAt = 0
+
+ assert.Equal(t, want, got)
+}
+
+func TestIDTokenClaims_GetUserInfo(t *testing.T) {
+ want := &UserInfo{
+ Subject: idTokenData.Subject,
+ UserInfoProfile: idTokenData.UserInfoProfile,
+ UserInfoEmail: idTokenData.UserInfoEmail,
+ UserInfoPhone: idTokenData.UserInfoPhone,
+ Address: idTokenData.Address,
+ Claims: idTokenData.Claims,
+ }
+ got := idTokenData.GetUserInfo()
+ assert.Equal(t, want, got)
+}
+
+func TestNewLogoutTokenClaims(t *testing.T) {
+ want := &LogoutTokenClaims{
+ Issuer: "zitadel",
+ Subject: "hello@me.com",
+ Audience: Audience{"foo", "just@me.com"},
+ Expiration: 12345,
+ JWTID: "jwtID",
+ Events: map[string]any{
+ "http://schemas.openid.net/event/backchannel-logout": struct{}{},
+ },
+ SessionID: "sessionID",
+ Claims: nil,
+ }
+
+ got := NewLogoutTokenClaims(
+ want.Issuer,
+ want.Subject,
+ want.Audience,
+ want.Expiration.AsTime(),
+ want.JWTID,
+ want.SessionID,
+ 1*time.Second,
+ )
+
+ // test if the dynamic timestamp is around now,
+ // allowing for a delta of 1, just in case we flip on
+ // either side of a second boundry.
+ nowMinusSkew := NowTime() - 1
+ assert.InDelta(t, int64(nowMinusSkew), int64(got.IssuedAt), 1)
+
+ // Make equal not fail on dynamic timestamp
+ got.IssuedAt = 0
+
+ assert.Equal(t, want, got)
+}
diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go
new file mode 100644
index 0000000..5d063b1
--- /dev/null
+++ b/pkg/oidc/types.go
@@ -0,0 +1,313 @@
+package oidc
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "reflect"
+ "strings"
+ "time"
+
+ jose "github.com/go-jose/go-jose/v4"
+ "github.com/muhlemmer/gu"
+ "github.com/zitadel/schema"
+ "golang.org/x/text/language"
+)
+
+type Audience []string
+
+func (a *Audience) UnmarshalJSON(text []byte) error {
+ var i any
+ err := json.Unmarshal(text, &i)
+ if err != nil {
+ return err
+ }
+ switch aud := i.(type) {
+ case []any:
+ *a = make([]string, len(aud))
+ for i, audience := range aud {
+ (*a)[i] = audience.(string)
+ }
+ case string:
+ *a = []string{aud}
+ }
+ return nil
+}
+
+func (a *Audience) MarshalJSON() ([]byte, error) {
+ len := len(*a)
+ if len > 1 {
+ return json.Marshal(*a)
+ } else if len == 1 {
+ return json.Marshal((*a)[0])
+ }
+
+ return nil, errors.New("aud is empty")
+}
+
+type Display string
+
+func (d *Display) UnmarshalText(text []byte) error {
+ display := Display(text)
+ switch display {
+ case DisplayPage, DisplayPopup, DisplayTouch, DisplayWAP:
+ *d = display
+ }
+ return nil
+}
+
+type Gender string
+
+type Locale struct {
+ tag language.Tag
+}
+
+func NewLocale(tag language.Tag) *Locale {
+ return &Locale{tag: tag}
+}
+
+func (l *Locale) Tag() language.Tag {
+ if l == nil {
+ return language.Und
+ }
+
+ return l.tag
+}
+
+func (l *Locale) String() string {
+ return l.Tag().String()
+}
+
+func (l *Locale) MarshalJSON() ([]byte, error) {
+ tag := l.Tag()
+ if tag.IsRoot() {
+ return []byte("null"), nil
+ }
+
+ return json.Marshal(tag)
+}
+
+// UnmarshalJSON implements json.Unmarshaler.
+// When [language.ValueError] is encountered, the containing tag will be set
+// to an empty value (language "und") and no error will be returned.
+// This state can be checked with the `l.Tag().IsRoot()` method.
+func (l *Locale) UnmarshalJSON(data []byte) error {
+ if len(data) == 0 || string(data) == "\"\"" {
+ return nil
+ }
+ err := json.Unmarshal(data, &l.tag)
+ if err == nil {
+ return nil
+ }
+
+ // catch "well-formed but unknown" errors
+ var target language.ValueError
+ if errors.As(err, &target) {
+ l.tag = language.Tag{}
+ return nil
+ }
+ return err
+}
+
+type Locales []language.Tag
+
+// ParseLocales parses a slice of strings into Locales.
+// If an entry causes a parse error or is undefined,
+// it is ignored and not set to Locales.
+func ParseLocales(locales []string) Locales {
+ out := make(Locales, 0, len(locales))
+ for _, locale := range locales {
+ tag, err := language.Parse(locale)
+ if err == nil && !tag.IsRoot() {
+ out = append(out, tag)
+ }
+ }
+ return out
+}
+
+func (l Locales) String() string {
+ tags := make([]string, len(l))
+ for i, tag := range l {
+ tags[i] = tag.String()
+ }
+ return strings.Join(tags, " ")
+}
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+// It decodes an unquoted space seperated string into Locales.
+// Undefined language tags in the input are ignored and ommited from
+// the resulting Locales.
+func (l *Locales) UnmarshalText(text []byte) error {
+ *l = ParseLocales(
+ strings.Split(string(text), " "),
+ )
+ return nil
+}
+
+// UnmarshalJSON implements the [json.Unmarshaler] interface.
+// It decodes a json array or a space seperated string into Locales.
+// Undefined language tags in the input are ignored and ommited from
+// the resulting Locales.
+func (l *Locales) UnmarshalJSON(data []byte) error {
+ var dst any
+ if err := json.Unmarshal(data, &dst); err != nil {
+ return fmt.Errorf("oidc locales: %w", err)
+ }
+
+ // We catch the posibility of a space seperated string here,
+ // because UnmarshalText might have been implicetely called
+ // by the json library before we added UnmarshalJSON.
+ switch v := dst.(type) {
+ case nil:
+ *l = nil
+ case string:
+ *l = ParseLocales(strings.Split(v, " "))
+ case []any:
+ locales, err := gu.AssertInterfaces[string](v)
+ if err != nil {
+ return fmt.Errorf("oidc locales: %w", err)
+ }
+ *l = ParseLocales(locales)
+ default:
+ return fmt.Errorf("oidc locales: unsupported type: %T", v)
+ }
+ return nil
+}
+
+type MaxAge *uint
+
+func NewMaxAge(i uint) MaxAge {
+ return &i
+}
+
+type SpaceDelimitedArray []string
+
+type Prompt SpaceDelimitedArray
+
+type ResponseType string
+
+type ResponseMode string
+
+func (s SpaceDelimitedArray) String() string {
+ return strings.Join(s, " ")
+}
+
+func (s *SpaceDelimitedArray) UnmarshalText(text []byte) error {
+ *s = strings.Split(string(text), " ")
+ return nil
+}
+
+func (s SpaceDelimitedArray) MarshalText() ([]byte, error) {
+ return []byte(s.String()), nil
+}
+
+func (s SpaceDelimitedArray) MarshalJSON() ([]byte, error) {
+ return json.Marshal((s).String())
+}
+
+func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error {
+ var str string
+ if err := json.Unmarshal(data, &str); err != nil {
+ return err
+ }
+ *s = strings.Split(str, " ")
+ return nil
+}
+
+func (s *SpaceDelimitedArray) Scan(src any) error {
+ if src == nil {
+ *s = nil
+ return nil
+ }
+ switch v := src.(type) {
+ case string:
+ if len(v) == 0 {
+ *s = SpaceDelimitedArray{}
+ return nil
+ }
+ *s = strings.Split(v, " ")
+ case []byte:
+ if len(v) == 0 {
+ *s = SpaceDelimitedArray{}
+ return nil
+ }
+ *s = strings.Split(string(v), " ")
+ default:
+ return fmt.Errorf("cannot convert %T to SpaceDelimitedArray", src)
+ }
+ return nil
+}
+
+func (s SpaceDelimitedArray) Value() (driver.Value, error) {
+ return strings.Join(s, " "), nil
+}
+
+// NewEncoder returns a schema Encoder with
+// a registered encoder for SpaceDelimitedArray.
+func NewEncoder() *schema.Encoder {
+ e := schema.NewEncoder()
+ e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string {
+ return value.Interface().(SpaceDelimitedArray).String()
+ })
+ e.RegisterEncoder(Locales{}, func(value reflect.Value) string {
+ return value.Interface().(Locales).String()
+ })
+ return e
+}
+
+type Time int64
+
+func (ts Time) AsTime() time.Time {
+ if ts == 0 {
+ return time.Time{}
+ }
+ return time.Unix(int64(ts), 0)
+}
+
+func FromTime(tt time.Time) Time {
+ if tt.IsZero() {
+ return 0
+ }
+ return Time(tt.Unix())
+}
+
+func NowTime() Time {
+ return FromTime(time.Now())
+}
+
+func (ts *Time) UnmarshalJSON(data []byte) error {
+ var v any
+ if err := json.Unmarshal(data, &v); err != nil {
+ return fmt.Errorf("oidc.Time: %w", err)
+ }
+ switch x := v.(type) {
+ case float64:
+ *ts = Time(x)
+ case string:
+ // Compatibility with Auth0:
+ // https://github.com/zitadel/oidc/issues/292
+ tt, err := time.Parse(time.RFC3339, x)
+ if err != nil {
+ return fmt.Errorf("oidc.Time: %w", err)
+ }
+ *ts = FromTime(tt)
+ case nil:
+ *ts = 0
+ default:
+ return fmt.Errorf("oidc.Time: unable to parse type %T with value %v", x, x)
+ }
+ return nil
+}
+
+type RequestObject struct {
+ Issuer string `json:"iss"`
+ Audience Audience `json:"aud"`
+ AuthRequest
+}
+
+func (r *RequestObject) GetIssuer() string {
+ return r.Issuer
+}
+
+func (*RequestObject) SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm) {}
diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go
new file mode 100644
index 0000000..53a9779
--- /dev/null
+++ b/pkg/oidc/types_test.go
@@ -0,0 +1,705 @@
+package oidc
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/url"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/zitadel/schema"
+ "golang.org/x/text/language"
+)
+
+func TestAudience_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ audience Audience
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "invalid value",
+ args{
+ []byte(`{"aud": {"a": }}}`),
+ },
+ res{},
+ true,
+ },
+ {
+ "single audience",
+ args{
+ []byte(`{"aud": "single audience"}`),
+ },
+ res{
+ []string{"single audience"},
+ },
+ false,
+ },
+ {
+ "multiple audience",
+ args{
+ []byte(`{"aud": ["multiple", "audience"]}`),
+ },
+ res{
+ []string{"multiple", "audience"},
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := new(struct {
+ Audience Audience `json:"aud"`
+ })
+ if err := json.Unmarshal(tt.args.text, &a); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, a.Audience, tt.res.audience)
+ })
+ }
+}
+
+func TestDisplay_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ display Display
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{},
+ false,
+ },
+ {
+ "page",
+ args{
+ []byte("page"),
+ },
+ res{DisplayPage},
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var d Display
+ if err := d.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if d != tt.res.display {
+ t.Errorf("Display is not correct is = %v, want %v", d, tt.res.display)
+ }
+ })
+ }
+}
+
+func TestLocale_Tag(t *testing.T) {
+ tests := []struct {
+ name string
+ l *Locale
+ want language.Tag
+ }{
+ {
+ name: "nil",
+ l: nil,
+ want: language.Und,
+ },
+ {
+ name: "Und",
+ l: NewLocale(language.Und),
+ want: language.Und,
+ },
+ {
+ name: "language",
+ l: NewLocale(language.Afrikaans),
+ want: language.Afrikaans,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.want, tt.l.Tag())
+ })
+ }
+}
+
+func TestLocale_String(t *testing.T) {
+ tests := []struct {
+ name string
+ l *Locale
+ want language.Tag
+ }{
+ {
+ name: "nil",
+ l: nil,
+ want: language.Und,
+ },
+ {
+ name: "Und",
+ l: NewLocale(language.Und),
+ want: language.Und,
+ },
+ {
+ name: "language",
+ l: NewLocale(language.Afrikaans),
+ want: language.Afrikaans,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.want.String(), tt.l.String())
+ })
+ }
+}
+
+func TestLocale_MarshalJSON(t *testing.T) {
+ tests := []struct {
+ name string
+ l *Locale
+ want string
+ wantErr bool
+ }{
+ {
+ name: "nil",
+ l: nil,
+ want: "null",
+ },
+ {
+ name: "und",
+ l: NewLocale(language.Und),
+ want: "null",
+ },
+ {
+ name: "language",
+ l: NewLocale(language.Afrikaans),
+ want: `"af"`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := json.Marshal(tt.l)
+ if tt.wantErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.want, string(got))
+ })
+ }
+}
+
+func TestLocale_UnmarshalJSON(t *testing.T) {
+ type dst struct {
+ Locale *Locale `json:"locale,omitempty"`
+ }
+ tests := []struct {
+ name string
+ input string
+ want dst
+ wantErr bool
+ }{
+ {
+ name: "value not present",
+ input: `{}`,
+ wantErr: false,
+ want: dst{
+ Locale: nil,
+ },
+ },
+ {
+ name: "null",
+ input: `{"locale": null}`,
+ wantErr: false,
+ want: dst{
+ Locale: nil,
+ },
+ },
+ {
+ name: "empty, ignored",
+ input: `{"locale": ""}`,
+ wantErr: false,
+ want: dst{
+ Locale: &Locale{},
+ },
+ },
+ {
+ name: "afrikaans, ok",
+ input: `{"locale": "af"}`,
+ want: dst{
+ Locale: NewLocale(language.Afrikaans),
+ },
+ },
+ {
+ name: "gb, ignored",
+ input: `{"locale": "gb"}`,
+ want: dst{
+ Locale: &Locale{},
+ },
+ },
+ {
+ name: "bad form, error",
+ input: `{"locale": "g!!!!!"}`,
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var got dst
+ err := json.Unmarshal([]byte(tt.input), &got)
+ if tt.wantErr {
+ require.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestParseLocales(t *testing.T) {
+ in := []string{language.Afrikaans.String(), language.Danish.String(), "foobar", language.Und.String()}
+ want := Locales{language.Afrikaans, language.Danish}
+ got := ParseLocales(in)
+ assert.ElementsMatch(t, want, got)
+}
+
+func TestLocales_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ tags []language.Tag
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{},
+ false,
+ },
+ {
+ "undefined",
+ args{
+ []byte("und"),
+ },
+ res{},
+ false,
+ },
+ {
+ "single language",
+ args{
+ []byte("de"),
+ },
+ res{[]language.Tag{language.German}},
+ false,
+ },
+ {
+ "multiple languages",
+ args{
+ []byte("de en"),
+ },
+ res{[]language.Tag{language.German, language.English}},
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var locales Locales
+ if err := locales.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, locales, tt.res.tags)
+ })
+ }
+}
+
+func TestLocales_UnmarshalJSON(t *testing.T) {
+ in := []string{language.Afrikaans.String(), language.Danish.String(), "foobar", language.Und.String()}
+ spaceSepStr := strconv.Quote(strings.Join(in, " "))
+ jsonArray, err := json.Marshal(in)
+ require.NoError(t, err)
+
+ out := Locales{language.Afrikaans, language.Danish}
+
+ type args struct {
+ data []byte
+ }
+ tests := []struct {
+ name string
+ args args
+ want Locales
+ wantErr bool
+ }{
+ {
+ name: "invalid JSON",
+ args: args{
+ data: []byte("~~~"),
+ },
+ wantErr: true,
+ },
+ {
+ name: "null",
+ args: args{
+ data: []byte("null"),
+ },
+ want: nil,
+ },
+ {
+ name: "space seperated string",
+ args: args{
+ data: []byte(spaceSepStr),
+ },
+ want: out,
+ },
+ {
+ name: "json string array",
+ args: args{
+ data: jsonArray,
+ },
+ want: out,
+ },
+ {
+ name: "json invalid array",
+ args: args{
+ data: []byte(`[1,2,3]`),
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid type (float64)",
+ args: args{
+ data: []byte("22"),
+ },
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var got Locales
+ err := got.UnmarshalJSON([]byte(tt.args.data))
+ if tt.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestScopes_UnmarshalText(t *testing.T) {
+ type args struct {
+ text []byte
+ }
+ type res struct {
+ scopes []string
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ []byte("unknown"),
+ },
+ res{
+ []string{"unknown"},
+ },
+ false,
+ },
+ {
+ "struct",
+ args{
+ []byte(`{"unknown":"value"}`),
+ },
+ res{
+ []string{`{"unknown":"value"}`},
+ },
+ false,
+ },
+ {
+ "openid",
+ args{
+ []byte("openid"),
+ },
+ res{
+ []string{"openid"},
+ },
+ false,
+ },
+ {
+ "multiple scopes",
+ args{
+ []byte("openid email custom:scope"),
+ },
+ res{
+ []string{"openid", "email", "custom:scope"},
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var scopes SpaceDelimitedArray
+ if err := scopes.UnmarshalText(tt.args.text); (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.ElementsMatch(t, scopes, tt.res.scopes)
+ })
+ }
+}
+
+func TestScopes_MarshalText(t *testing.T) {
+ type args struct {
+ scopes SpaceDelimitedArray
+ }
+ type res struct {
+ scopes []byte
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ wantErr bool
+ }{
+ {
+ "unknown value",
+ args{
+ SpaceDelimitedArray{"unknown"},
+ },
+ res{
+ []byte("unknown"),
+ },
+ false,
+ },
+ {
+ "struct",
+ args{
+ SpaceDelimitedArray{`{"unknown":"value"}`},
+ },
+ res{
+ []byte(`{"unknown":"value"}`),
+ },
+ false,
+ },
+ {
+ "openid",
+ args{
+ SpaceDelimitedArray{"openid"},
+ },
+ res{
+ []byte("openid"),
+ },
+ false,
+ },
+ {
+ "multiple scopes",
+ args{
+ SpaceDelimitedArray{"openid", "email", "custom:scope"},
+ },
+ res{
+ []byte("openid email custom:scope"),
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ text, err := tt.args.scopes.MarshalText()
+ if (err != nil) != tt.wantErr {
+ t.Errorf("MarshalText() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if !bytes.Equal(text, tt.res.scopes) {
+ t.Errorf("MarshalText() is = %q, want %q", text, tt.res.scopes)
+ }
+ })
+ }
+}
+
+func TestSpaceDelimitatedArray_ValuerNotNil(t *testing.T) {
+ inputs := [][]string{
+ {"two", "elements"},
+ {"one"},
+ { /*zero*/ },
+ }
+ for _, input := range inputs {
+ t.Run(strconv.Itoa(len(input))+strings.Join(input, "_"), func(t *testing.T) {
+ sda := SpaceDelimitedArray(input)
+ dbValue, err := sda.Value()
+ if !assert.NoError(t, err, "Value") {
+ return
+ }
+ var reversed SpaceDelimitedArray
+ err = reversed.Scan(dbValue)
+ if assert.NoError(t, err, "Scan string") {
+ assert.Equal(t, sda, reversed, "scan string")
+ }
+ reversed = nil
+ dbValueString, ok := dbValue.(string)
+ if assert.True(t, ok, "dbValue is string") {
+ err = reversed.Scan([]byte(dbValueString))
+ if assert.NoError(t, err, "Scan bytes") {
+ assert.Equal(t, sda, reversed, "scan bytes")
+ }
+ }
+ })
+ }
+}
+
+func TestSpaceDelimitatedArray_ValuerNil(t *testing.T) {
+ var reversed SpaceDelimitedArray
+ err := reversed.Scan(nil)
+ if assert.NoError(t, err, "Scan nil") {
+ assert.Equal(t, SpaceDelimitedArray(nil), reversed, "scan nil")
+ }
+}
+
+func TestNewEncoder(t *testing.T) {
+ type request struct {
+ Scopes SpaceDelimitedArray `schema:"scope"`
+ }
+ a := request{
+ Scopes: SpaceDelimitedArray{"foo", "bar"},
+ }
+
+ values := make(url.Values)
+ NewEncoder().Encode(a, values)
+ assert.Equal(t, url.Values{"scope": []string{"foo bar"}}, values)
+
+ var b request
+ schema.NewDecoder().Decode(&b, values)
+ assert.Equal(t, a, b)
+}
+
+func TestTime_AsTime(t *testing.T) {
+ tests := []struct {
+ name string
+ ts Time
+ want time.Time
+ }{
+ {
+ name: "unset",
+ ts: 0,
+ want: time.Time{},
+ },
+ {
+ name: "set",
+ ts: 1,
+ want: time.Unix(1, 0),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.ts.AsTime()
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestTime_FromTime(t *testing.T) {
+ tests := []struct {
+ name string
+ tt time.Time
+ want Time
+ }{
+ {
+ name: "zero",
+ tt: time.Time{},
+ want: 0,
+ },
+ {
+ name: "set",
+ tt: time.Unix(1, 0),
+ want: 1,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := FromTime(tt.tt)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestTime_UnmarshalJSON(t *testing.T) {
+ type dst struct {
+ UpdatedAt Time `json:"updated_at"`
+ }
+ tests := []struct {
+ name string
+ json string
+ want dst
+ wantErr bool
+ }{
+ {
+ name: "RFC3339", // https://github.com/zitadel/oidc/issues/292
+ json: `{"updated_at": "2021-05-11T21:13:25.566Z"}`,
+ want: dst{UpdatedAt: 1620767605},
+ },
+ {
+ name: "int",
+ json: `{"updated_at":1620767605}`,
+ want: dst{UpdatedAt: 1620767605},
+ },
+ {
+ name: "time parse error",
+ json: `{"updated_at":"foo"}`,
+ wantErr: true,
+ },
+ {
+ name: "null",
+ json: `{"updated_at":null}`,
+ },
+ {
+ name: "invalid type",
+ json: `{"updated_at":["foo","bar"]}`,
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var got dst
+ err := json.Unmarshal([]byte(tt.json), &got)
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.want, got)
+ })
+ }
+ t.Run("syntax error", func(t *testing.T) {
+ var ts Time
+ err := ts.UnmarshalJSON([]byte{'~'})
+ assert.Error(t, err)
+ })
+}
diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go
index d0fe4a8..ef8ebe4 100644
--- a/pkg/oidc/userinfo.go
+++ b/pkg/oidc/userinfo.go
@@ -1,90 +1,91 @@
package oidc
-import (
- "encoding/json"
- "time"
+// UserInfo implements OpenID Connect Core 1.0, section 5.1.
+// https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims.
+type UserInfo struct {
+ Subject string `json:"sub,omitempty"`
+ UserInfoProfile
+ UserInfoEmail
+ UserInfoPhone
+ Address *UserInfoAddress `json:"address,omitempty"`
- "golang.org/x/text/language"
-)
-
-type Userinfo struct {
- Subject string
- UserinfoProfile
- UserinfoEmail
- UserinfoPhone
- Address *UserinfoAddress
-
- Authorizations []string
-
- claims map[string]interface{}
+ Claims map[string]any `json:"-"`
}
-type UserinfoProfile struct {
- Name string
- GivenName string
- FamilyName string
- MiddleName string
- Nickname string
- Profile string
- Picture string
- Website string
- Gender Gender
- Birthdate string
- Zoneinfo string
- Locale language.Tag
- UpdatedAt time.Time
- PreferredUsername string
+func (u *UserInfo) AppendClaims(k string, v any) {
+ if u.Claims == nil {
+ u.Claims = make(map[string]any)
+ }
+
+ u.Claims[k] = v
}
-type Gender string
-
-type UserinfoEmail struct {
- Email string
- EmailVerified bool
+// GetAddress is a safe getter that takes
+// care of a possible nil value.
+func (u *UserInfo) GetAddress() *UserInfoAddress {
+ if u.Address == nil {
+ return new(UserInfoAddress)
+ }
+ return u.Address
}
-type UserinfoPhone struct {
- PhoneNumber string
- PhoneNumberVerified bool
+// GetSubject implements [rp.SubjectGetter]
+func (u *UserInfo) GetSubject() string {
+ return u.Subject
}
-type UserinfoAddress struct {
- Formatted string
- StreetAddress string
- Locality string
- Region string
- PostalCode string
- Country string
+type uiAlias UserInfo
+
+func (u *UserInfo) MarshalJSON() ([]byte, error) {
+ return mergeAndMarshalClaims((*uiAlias)(u), u.Claims)
}
-type jsonUserinfoProfile struct {
- Name string `json:"name,omitempty"`
- GivenName string `json:"given_name,omitempty"`
- FamilyName string `json:"family_name,omitempty"`
- MiddleName string `json:"middle_name,omitempty"`
- Nickname string `json:"nickname,omitempty"`
- Profile string `json:"profile,omitempty"`
- Picture string `json:"picture,omitempty"`
- Website string `json:"website,omitempty"`
- Gender string `json:"gender,omitempty"`
- Birthdate string `json:"birthdate,omitempty"`
- Zoneinfo string `json:"zoneinfo,omitempty"`
- Locale string `json:"locale,omitempty"`
- UpdatedAt int64 `json:"updated_at,omitempty"`
- PreferredUsername string `json:"preferred_username,omitempty"`
+func (u *UserInfo) UnmarshalJSON(data []byte) error {
+ return unmarshalJSONMulti(data, (*uiAlias)(u), &u.Claims)
}
-type jsonUserinfoEmail struct {
- Email string `json:"email,omitempty"`
- EmailVerified bool `json:"email_verified,omitempty"`
+type UserInfoProfile struct {
+ Name string `json:"name,omitempty"`
+ GivenName string `json:"given_name,omitempty"`
+ FamilyName string `json:"family_name,omitempty"`
+ MiddleName string `json:"middle_name,omitempty"`
+ Nickname string `json:"nickname,omitempty"`
+ Profile string `json:"profile,omitempty"`
+ Picture string `json:"picture,omitempty"`
+ Website string `json:"website,omitempty"`
+ Gender Gender `json:"gender,omitempty"`
+ Birthdate string `json:"birthdate,omitempty"`
+ Zoneinfo string `json:"zoneinfo,omitempty"`
+ Locale *Locale `json:"locale,omitempty"`
+ UpdatedAt Time `json:"updated_at,omitempty"`
+ PreferredUsername string `json:"preferred_username,omitempty"`
}
-type jsonUserinfoPhone struct {
- Phone string `json:"phone_number,omitempty"`
- PhoneVerified bool `json:"phone_number_verified,omitempty"`
+type UserInfoEmail struct {
+ Email string `json:"email,omitempty"`
+
+ // Handle providers that return email_verified as a string
+ // https://forums.aws.amazon.com/thread.jspa?messageID=949441
+ // https://discuss.elastic.co/t/openid-error-after-authenticating-against-aws-cognito/206018/11
+ EmailVerified Bool `json:"email_verified,omitempty"`
}
-type jsonUserinfoAddress struct {
+type Bool bool
+
+func (bs *Bool) UnmarshalJSON(data []byte) error {
+ if string(data) == "true" || string(data) == `"true"` {
+ *bs = true
+ }
+
+ return nil
+}
+
+type UserInfoPhone struct {
+ PhoneNumber string `json:"phone_number,omitempty"`
+ PhoneNumberVerified bool `json:"phone_number_verified,omitempty"`
+}
+
+type UserInfoAddress struct {
Formatted string `json:"formatted,omitempty"`
StreetAddress string `json:"street_address,omitempty"`
Locality string `json:"locality,omitempty"`
@@ -93,78 +94,6 @@ type jsonUserinfoAddress struct {
Country string `json:"country,omitempty"`
}
-func (i *Userinfo) MarshalJSON() ([]byte, error) {
- j := new(jsonUserinfo)
- j.Subject = i.Subject
- j.setUserinfo(*i)
- j.Authorizations = i.Authorizations
- return json.Marshal(j)
-}
-
-func (i *Userinfo) UnmmarshalJSON(data []byte) error {
- if err := json.Unmarshal(data, i); err != nil {
- return err
- }
- return json.Unmarshal(data, i.claims)
-}
-
-type jsonUserinfo struct {
- Subject string `json:"sub,omitempty"`
- jsonUserinfoProfile
- jsonUserinfoEmail
- jsonUserinfoPhone
- JsonUserinfoAddress *jsonUserinfoAddress `json:"address,omitempty"`
- Authorizations []string `json:"authorizations,omitempty"`
-}
-
-func (j *jsonUserinfo) setUserinfo(i Userinfo) {
- j.setUserinfoProfile(i.UserinfoProfile)
- j.setUserinfoEmail(i.UserinfoEmail)
- j.setUserinfoPhone(i.UserinfoPhone)
- j.setUserinfoAddress(i.Address)
-}
-
-func (j *jsonUserinfo) setUserinfoProfile(i UserinfoProfile) {
- j.Name = i.Name
- j.GivenName = i.GivenName
- j.FamilyName = i.FamilyName
- j.MiddleName = i.MiddleName
- j.Nickname = i.Nickname
- j.Profile = i.Profile
- j.Picture = i.Picture
- j.Website = i.Website
- j.Gender = string(i.Gender)
- j.Birthdate = i.Birthdate
- j.Zoneinfo = i.Zoneinfo
- if i.Locale != language.Und {
- j.Locale = i.Locale.String()
- }
- j.UpdatedAt = timeToJSON(i.UpdatedAt)
- j.PreferredUsername = i.PreferredUsername
-}
-
-func (j *jsonUserinfo) setUserinfoEmail(i UserinfoEmail) {
- j.Email = i.Email
- j.EmailVerified = i.EmailVerified
-}
-
-func (j *jsonUserinfo) setUserinfoPhone(i UserinfoPhone) {
- j.Phone = i.PhoneNumber
- j.PhoneVerified = i.PhoneNumberVerified
-}
-
-func (j *jsonUserinfo) setUserinfoAddress(i *UserinfoAddress) {
- if i == nil {
- return
- }
- j.JsonUserinfoAddress.Country = i.Country
- j.JsonUserinfoAddress.Formatted = i.Formatted
- j.JsonUserinfoAddress.Locality = i.Locality
- j.JsonUserinfoAddress.PostalCode = i.PostalCode
- j.JsonUserinfoAddress.Region = i.Region
- j.JsonUserinfoAddress.StreetAddress = i.StreetAddress
-}
-
type UserInfoRequest struct {
AccessToken string `schema:"access_token"`
}
diff --git a/pkg/oidc/userinfo_test.go b/pkg/oidc/userinfo_test.go
new file mode 100644
index 0000000..a574366
--- /dev/null
+++ b/pkg/oidc/userinfo_test.go
@@ -0,0 +1,119 @@
+package oidc
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestUserInfo_AppendClaims(t *testing.T) {
+ u := new(UserInfo)
+ u.AppendClaims("a", "b")
+ want := map[string]any{"a": "b"}
+ assert.Equal(t, want, u.Claims)
+
+ u.AppendClaims("d", "e")
+ want["d"] = "e"
+ assert.Equal(t, want, u.Claims)
+}
+
+func TestUserInfo_GetAddress(t *testing.T) {
+ // nil address
+ u := new(UserInfo)
+ assert.Equal(t, &UserInfoAddress{}, u.GetAddress())
+
+ u.Address = &UserInfoAddress{PostalCode: "1234"}
+ assert.Equal(t, u.Address, u.GetAddress())
+}
+
+func TestUserInfoMarshal(t *testing.T) {
+ userinfo := &UserInfo{
+ Subject: "test",
+ Address: &UserInfoAddress{
+ StreetAddress: "Test 789\nPostfach 2",
+ },
+ UserInfoEmail: UserInfoEmail{
+ Email: "test",
+ EmailVerified: true,
+ },
+ UserInfoPhone: UserInfoPhone{
+ PhoneNumber: "0791234567",
+ PhoneNumberVerified: true,
+ },
+ UserInfoProfile: UserInfoProfile{
+ Name: "Test",
+ },
+ Claims: map[string]any{"private_claim": "test"},
+ }
+
+ marshal, err := json.Marshal(userinfo)
+ assert.NoError(t, err)
+
+ out := new(UserInfo)
+ assert.NoError(t, json.Unmarshal(marshal, out))
+ expected, err := json.Marshal(out)
+
+ assert.NoError(t, err)
+ assert.Equal(t, expected, marshal)
+
+ out2 := new(UserInfo)
+ assert.NoError(t, json.Unmarshal(expected, out2))
+ assert.Equal(t, out, out2)
+}
+
+func TestUserInfoEmailVerifiedUnmarshal(t *testing.T) {
+ t.Parallel()
+
+ t.Run("unmarshal email_verified from json bool true", func(t *testing.T) {
+ jsonBool := []byte(`{"email": "my@email.com", "email_verified": true}`)
+
+ var uie UserInfoEmail
+
+ err := json.Unmarshal(jsonBool, &uie)
+ assert.NoError(t, err)
+ assert.Equal(t, UserInfoEmail{
+ Email: "my@email.com",
+ EmailVerified: true,
+ }, uie)
+ })
+
+ t.Run("unmarshal email_verified from json string true", func(t *testing.T) {
+ jsonBool := []byte(`{"email": "my@email.com", "email_verified": "true"}`)
+
+ var uie UserInfoEmail
+
+ err := json.Unmarshal(jsonBool, &uie)
+ assert.NoError(t, err)
+ assert.Equal(t, UserInfoEmail{
+ Email: "my@email.com",
+ EmailVerified: true,
+ }, uie)
+ })
+
+ t.Run("unmarshal email_verified from json bool false", func(t *testing.T) {
+ jsonBool := []byte(`{"email": "my@email.com", "email_verified": false}`)
+
+ var uie UserInfoEmail
+
+ err := json.Unmarshal(jsonBool, &uie)
+ assert.NoError(t, err)
+ assert.Equal(t, UserInfoEmail{
+ Email: "my@email.com",
+ EmailVerified: false,
+ }, uie)
+ })
+
+ t.Run("unmarshal email_verified from json string false", func(t *testing.T) {
+ jsonBool := []byte(`{"email": "my@email.com", "email_verified": "false"}`)
+
+ var uie UserInfoEmail
+
+ err := json.Unmarshal(jsonBool, &uie)
+ assert.NoError(t, err)
+ assert.Equal(t, UserInfoEmail{
+ Email: "my@email.com",
+ EmailVerified: false,
+ }, uie)
+ })
+}
diff --git a/pkg/oidc/util.go b/pkg/oidc/util.go
new file mode 100644
index 0000000..462ea44
--- /dev/null
+++ b/pkg/oidc/util.go
@@ -0,0 +1,54 @@
+package oidc
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+)
+
+// mergeAndMarshalClaims merges registered and the custom
+// claims map into a single JSON object.
+// Registered fields overwrite custom claims.
+func mergeAndMarshalClaims(registered any, extraClaims map[string]any) ([]byte, error) {
+ // Use a buffer for memory re-use, instead off letting
+ // json allocate a new []byte for every step.
+ buf := new(bytes.Buffer)
+
+ // Marshal the registered claims into JSON
+ if err := json.NewEncoder(buf).Encode(registered); err != nil {
+ return nil, fmt.Errorf("oidc registered claims: %w", err)
+ }
+
+ if len(extraClaims) > 0 {
+ merged := make(map[string]any)
+ for k, v := range extraClaims {
+ merged[k] = v
+ }
+
+ // Merge JSON data into custom claims.
+ // The full-read action by the decoder resets the buffer
+ // to zero len, while retaining underlaying cap.
+ if err := json.NewDecoder(buf).Decode(&merged); err != nil {
+ return nil, fmt.Errorf("oidc registered claims: %w", err)
+ }
+
+ // Marshal the final result.
+ if err := json.NewEncoder(buf).Encode(merged); err != nil {
+ return nil, fmt.Errorf("oidc custom claims: %w", err)
+ }
+ }
+
+ return buf.Bytes(), nil
+}
+
+// unmarshalJSONMulti unmarshals the same JSON data into multiple destinations.
+// Each destination must be a pointer, as per json.Unmarshal rules.
+// Returns on the first error and destinations may be partly filled with data.
+func unmarshalJSONMulti(data []byte, destinations ...any) error {
+ for _, dst := range destinations {
+ if err := json.Unmarshal(data, dst); err != nil {
+ return fmt.Errorf("oidc: %w into %T", err, dst)
+ }
+ }
+ return nil
+}
diff --git a/pkg/oidc/util_test.go b/pkg/oidc/util_test.go
new file mode 100644
index 0000000..6363d83
--- /dev/null
+++ b/pkg/oidc/util_test.go
@@ -0,0 +1,147 @@
+package oidc
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+type jsonErrorTest struct{}
+
+func (jsonErrorTest) MarshalJSON() ([]byte, error) {
+ return nil, errors.New("test")
+}
+
+func Test_mergeAndMarshalClaims(t *testing.T) {
+ type args struct {
+ registered any
+ claims map[string]any
+ }
+ tests := []struct {
+ name string
+ args args
+ want string
+ wantErr bool
+ }{
+ {
+ name: "encoder error",
+ args: args{
+ registered: jsonErrorTest{},
+ },
+ wantErr: true,
+ },
+ {
+ name: "no claims",
+ args: args{
+ registered: struct {
+ Foo string `json:"foo,omitempty"`
+ }{
+ Foo: "bar",
+ },
+ },
+ want: "{\"foo\":\"bar\"}\n",
+ },
+ {
+ name: "with claims",
+ args: args{
+ registered: struct {
+ Foo string `json:"foo,omitempty"`
+ }{
+ Foo: "bar",
+ },
+ claims: map[string]any{
+ "bar": "foo",
+ },
+ },
+ want: "{\"bar\":\"foo\",\"foo\":\"bar\"}\n",
+ },
+ {
+ name: "registered overwrites custom",
+ args: args{
+ registered: struct {
+ Foo string `json:"foo,omitempty"`
+ }{
+ Foo: "bar",
+ },
+ claims: map[string]any{
+ "foo": "Hello, World!",
+ },
+ },
+ want: "{\"foo\":\"bar\"}\n",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := mergeAndMarshalClaims(tt.args.registered, tt.args.claims)
+ if tt.wantErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.want, string(got))
+ })
+ }
+}
+
+func Test_unmarshalJSONMulti(t *testing.T) {
+ type dst struct {
+ Foo string `json:"foo,omitempty"`
+ }
+
+ type args struct {
+ data string
+ destinations []any
+ }
+ tests := []struct {
+ name string
+ args args
+ want []any
+ wantErr bool
+ }{
+ {
+ name: "error",
+ args: args{
+ data: "~!~~",
+ destinations: []any{
+ &dst{},
+ &map[string]any{},
+ },
+ },
+ want: []any{
+ &dst{},
+ &map[string]any{},
+ },
+ wantErr: true,
+ },
+ {
+ name: "success",
+ args: args{
+ data: "{\"bar\":\"foo\",\"foo\":\"bar\"}\n",
+ destinations: []any{
+ &dst{},
+ &map[string]any{},
+ },
+ },
+ want: []any{
+ &dst{Foo: "bar"},
+ &map[string]any{
+ "foo": "bar",
+ "bar": "foo",
+ },
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := unmarshalJSONMulti([]byte(tt.args.data), tt.args.destinations...)
+ if tt.wantErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.want, tt.args.destinations)
+ })
+ }
+}
diff --git a/pkg/oidc/verifier.go b/pkg/oidc/verifier.go
new file mode 100644
index 0000000..d5e0213
--- /dev/null
+++ b/pkg/oidc/verifier.go
@@ -0,0 +1,250 @@
+package oidc
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "slices"
+ "strings"
+ "time"
+
+ jose "github.com/go-jose/go-jose/v4"
+)
+
+type Claims interface {
+ GetIssuer() string
+ GetSubject() string
+ GetAudience() []string
+ GetExpiration() time.Time
+ GetIssuedAt() time.Time
+ GetNonce() string
+ GetAuthenticationContextClassReference() string
+ GetAuthTime() time.Time
+ GetAuthorizedParty() string
+ ClaimsSignature
+}
+
+type ClaimsSignature interface {
+ SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
+}
+
+type IDClaims interface {
+ Claims
+ GetSignatureAlgorithm() jose.SignatureAlgorithm
+ GetAccessTokenHash() string
+}
+
+var (
+ ErrParse = errors.New("parsing of request failed")
+ ErrIssuerInvalid = errors.New("issuer does not match")
+ ErrDiscoveryFailed = errors.New("OpenID Provider Configuration Discovery has failed")
+ ErrSubjectMissing = errors.New("subject missing")
+ ErrAudience = errors.New("audience is not valid")
+ ErrAzpMissing = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty")
+ ErrAzpInvalid = errors.New("authorized party is not valid")
+ ErrSignatureMissing = errors.New("id_token does not contain a signature")
+ ErrSignatureMultiple = errors.New("id_token contains multiple signatures")
+ ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported")
+ ErrSignatureInvalidPayload = errors.New("signature does not match Payload")
+ ErrSignatureInvalid = errors.New("invalid signature")
+ ErrExpired = errors.New("token has expired")
+ ErrIatMissing = errors.New("issuedAt of token is missing")
+ ErrIatInFuture = errors.New("issuedAt of token is in the future")
+ ErrIatToOld = errors.New("issuedAt of token is to old")
+ ErrNonceInvalid = errors.New("nonce does not match")
+ ErrAcrInvalid = errors.New("acr is invalid")
+ ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing")
+ ErrAuthTimeToOld = errors.New("auth time of token is too old")
+ ErrAtHash = errors.New("at_hash does not correspond to access token")
+)
+
+// Verifier caries configuration for the various token verification
+// functions. Use package specific constructor functions to know
+// which values need to be set.
+type Verifier struct {
+ Issuer string
+ MaxAgeIAT time.Duration
+ Offset time.Duration
+ ClientID string
+ SupportedSignAlgs []string
+ MaxAge time.Duration
+ ACR ACRVerifier
+ KeySet KeySet
+ Nonce func(ctx context.Context) string
+}
+
+// ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
+type ACRVerifier func(string) error
+
+// DefaultACRVerifier implements `ACRVerifier` returning an error
+// if none of the provided values matches the acr claim
+func DefaultACRVerifier(possibleValues []string) ACRVerifier {
+ return func(acr string) error {
+ if !slices.Contains(possibleValues, acr) {
+ return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr)
+ }
+ return nil
+ }
+}
+
+func DecryptToken(tokenString string) (string, error) {
+ return tokenString, nil // TODO: impl
+}
+
+func ParseToken(tokenString string, claims any) ([]byte, error) {
+ parts := strings.Split(tokenString, ".")
+ if len(parts) != 3 {
+ return nil, fmt.Errorf("%w: token contains an invalid number of segments", ErrParse)
+ }
+ payload, err := base64.RawURLEncoding.DecodeString(parts[1])
+ if err != nil {
+ return nil, fmt.Errorf("%w: malformed jwt payload: %v", ErrParse, err)
+ }
+ err = json.Unmarshal(payload, claims)
+ return payload, err
+}
+
+func CheckSubject(claims Claims) error {
+ if claims.GetSubject() == "" {
+ return ErrSubjectMissing
+ }
+ return nil
+}
+
+func CheckIssuer(claims Claims, issuer string) error {
+ if claims.GetIssuer() != issuer {
+ return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, issuer, claims.GetIssuer())
+ }
+ return nil
+}
+
+func CheckAudience(claims Claims, clientID string) error {
+ if !slices.Contains(claims.GetAudience(), clientID) {
+ return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID)
+ }
+
+ // TODO: check aud trusted
+ return nil
+}
+
+// CheckAuthorizedParty checks azp (authorized party) claim requirements.
+//
+// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present.
+// If an azp Claim is present, the Client SHOULD verify that its client_id is the Claim Value.
+// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
+func CheckAuthorizedParty(claims Claims, clientID string) error {
+ if len(claims.GetAudience()) > 1 {
+ if claims.GetAuthorizedParty() == "" {
+ return ErrAzpMissing
+ }
+ }
+ if claims.GetAuthorizedParty() != "" && claims.GetAuthorizedParty() != clientID {
+ return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, claims.GetAuthorizedParty(), clientID)
+ }
+ return nil
+}
+
+func CheckSignature(ctx context.Context, token string, payload []byte, claims ClaimsSignature, supportedSigAlgs []string, set KeySet) error {
+ jws, err := jose.ParseSigned(token, toJoseSignatureAlgorithms(supportedSigAlgs))
+ if err != nil {
+ if strings.HasPrefix(err.Error(), "go-jose/go-jose: unexpected signature algorithm") {
+ // TODO(v4): we should wrap errors instead of returning static ones.
+ // This is a workaround so we keep returning the same error for now.
+ return ErrSignatureUnsupportedAlg
+ }
+ return ErrParse
+ }
+ if len(jws.Signatures) == 0 {
+ return ErrSignatureMissing
+ }
+ if len(jws.Signatures) > 1 {
+ return ErrSignatureMultiple
+ }
+ sig := jws.Signatures[0]
+
+ signedPayload, err := set.VerifySignature(ctx, jws)
+ if err != nil {
+ return fmt.Errorf("%w (%v)", ErrSignatureInvalid, err)
+ }
+
+ if !bytes.Equal(signedPayload, payload) {
+ return ErrSignatureInvalidPayload
+ }
+
+ claims.SetSignatureAlgorithm(jose.SignatureAlgorithm(sig.Header.Algorithm))
+
+ return nil
+}
+
+// TODO(v4): Use the new jose.SignatureAlgorithm type directly, instead of string.
+func toJoseSignatureAlgorithms(algorithms []string) []jose.SignatureAlgorithm {
+ out := make([]jose.SignatureAlgorithm, len(algorithms))
+ for i := range algorithms {
+ out[i] = jose.SignatureAlgorithm(algorithms[i])
+ }
+ if len(out) == 0 {
+ out = append(out, jose.RS256, jose.ES256, jose.PS256)
+ }
+ return out
+}
+
+func CheckExpiration(claims Claims, offset time.Duration) error {
+ expiration := claims.GetExpiration()
+ if !time.Now().Add(offset).Before(expiration) {
+ return ErrExpired
+ }
+ return nil
+}
+
+func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
+ issuedAt := claims.GetIssuedAt()
+ if issuedAt.IsZero() {
+ return ErrIatMissing
+ }
+ nowWithOffset := time.Now().Add(offset).Round(time.Second)
+ if issuedAt.After(nowWithOffset) {
+ return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, issuedAt, nowWithOffset)
+ }
+ if maxAgeIAT == 0 {
+ return nil
+ }
+ maxAge := time.Now().Add(-maxAgeIAT).Round(time.Second)
+ if issuedAt.Before(maxAge) {
+ return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, maxAge, issuedAt, maxAge.Sub(issuedAt))
+ }
+ return nil
+}
+
+func CheckNonce(claims Claims, nonce string) error {
+ if claims.GetNonce() != nonce {
+ return fmt.Errorf("%w: expected %q but was %q", ErrNonceInvalid, nonce, claims.GetNonce())
+ }
+ return nil
+}
+
+func CheckAuthorizationContextClassReference(claims Claims, acr ACRVerifier) error {
+ if acr != nil {
+ if err := acr(claims.GetAuthenticationContextClassReference()); err != nil {
+ return fmt.Errorf("%w: %v", ErrAcrInvalid, err)
+ }
+ }
+ return nil
+}
+
+func CheckAuthTime(claims Claims, maxAge time.Duration) error {
+ if maxAge == 0 {
+ return nil
+ }
+ if claims.GetAuthTime().IsZero() {
+ return ErrAuthTimeNotPresent
+ }
+ authTime := claims.GetAuthTime()
+ maxAuthTime := time.Now().Add(-maxAge).Round(time.Second)
+ if authTime.Before(maxAuthTime) {
+ return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAge, authTime, maxAuthTime.Sub(authTime))
+ }
+ return nil
+}
diff --git a/pkg/oidc/verifier_parse_test.go b/pkg/oidc/verifier_parse_test.go
new file mode 100644
index 0000000..9cf5c1e
--- /dev/null
+++ b/pkg/oidc/verifier_parse_test.go
@@ -0,0 +1,128 @@
+package oidc_test
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseToken(t *testing.T) {
+ token, wantClaims := tu.ValidIDToken()
+ wantClaims.SignatureAlg = "" // unset, because is not part of the JSON payload
+
+ wantPayload, err := json.Marshal(wantClaims)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ tokenString string
+ wantErr bool
+ }{
+ {
+ name: "split error",
+ tokenString: "nope",
+ wantErr: true,
+ },
+ {
+ name: "base64 error",
+ tokenString: "foo.~.bar",
+ wantErr: true,
+ },
+ {
+ name: "success",
+ tokenString: token,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotClaims := new(oidc.IDTokenClaims)
+ gotPayload, err := oidc.ParseToken(tt.tokenString, gotClaims)
+ if tt.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, wantClaims, gotClaims)
+ assert.JSONEq(t, string(wantPayload), string(gotPayload))
+ })
+ }
+}
+
+func TestCheckSignature(t *testing.T) {
+ errCtx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ token, _ := tu.ValidIDToken()
+ payload, err := oidc.ParseToken(token, &oidc.IDTokenClaims{})
+ require.NoError(t, err)
+
+ type args struct {
+ ctx context.Context
+ token string
+ payload []byte
+ supportedSigAlgs []string
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr error
+ }{
+ {
+ name: "parse error",
+ args: args{
+ ctx: context.Background(),
+ token: "~",
+ payload: payload,
+ },
+ wantErr: oidc.ErrParse,
+ },
+ {
+ name: "default sigAlg",
+ args: args{
+ ctx: context.Background(),
+ token: token,
+ payload: payload,
+ },
+ },
+ {
+ name: "unsupported sigAlg",
+ args: args{
+ ctx: context.Background(),
+ token: token,
+ payload: payload,
+ supportedSigAlgs: []string{"foo", "bar"},
+ },
+ wantErr: oidc.ErrSignatureUnsupportedAlg,
+ },
+ {
+ name: "verify error",
+ args: args{
+ ctx: errCtx,
+ token: token,
+ payload: payload,
+ },
+ wantErr: oidc.ErrSignatureInvalid,
+ },
+ {
+ name: "inequal payloads",
+ args: args{
+ ctx: context.Background(),
+ token: token,
+ payload: []byte{0, 1, 2},
+ },
+ wantErr: oidc.ErrSignatureInvalidPayload,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ claims := new(oidc.TokenClaims)
+ err := oidc.CheckSignature(tt.args.ctx, tt.args.token, tt.args.payload, claims, tt.args.supportedSigAlgs, tu.KeySet{})
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
diff --git a/pkg/oidc/verifier_test.go b/pkg/oidc/verifier_test.go
new file mode 100644
index 0000000..93e7157
--- /dev/null
+++ b/pkg/oidc/verifier_test.go
@@ -0,0 +1,374 @@
+package oidc
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestDecryptToken(t *testing.T) {
+ const tokenString = "ABC"
+ got, err := DecryptToken(tokenString)
+ require.NoError(t, err)
+ assert.Equal(t, tokenString, got)
+}
+
+func TestDefaultACRVerifier(t *testing.T) {
+ acrVerfier := DefaultACRVerifier([]string{"foo", "bar"})
+
+ tests := []struct {
+ name string
+ acr string
+ wantErr string
+ }{
+ {
+ name: "ok",
+ acr: "bar",
+ },
+ {
+ name: "error",
+ acr: "hello",
+ wantErr: "expected one of: [foo bar], got: \"hello\"",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := acrVerfier(tt.acr)
+ if tt.wantErr != "" {
+ assert.EqualError(t, err, tt.wantErr)
+ return
+ }
+ require.NoError(t, err)
+ })
+ }
+}
+
+func TestCheckSubject(t *testing.T) {
+ tests := []struct {
+ name string
+ claims Claims
+ wantErr error
+ }{
+ {
+ name: "missing",
+ claims: &TokenClaims{},
+ wantErr: ErrSubjectMissing,
+ },
+ {
+ name: "ok",
+ claims: &TokenClaims{
+ Subject: "foo",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := CheckSubject(tt.claims)
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
+
+func TestCheckIssuer(t *testing.T) {
+ const issuer = "foo.bar"
+ tests := []struct {
+ name string
+ claims Claims
+ wantErr error
+ }{
+ {
+ name: "missing",
+ claims: &TokenClaims{},
+ wantErr: ErrIssuerInvalid,
+ },
+ {
+ name: "wrong",
+ claims: &TokenClaims{
+ Issuer: "wrong",
+ },
+ wantErr: ErrIssuerInvalid,
+ },
+ {
+ name: "ok",
+ claims: &TokenClaims{
+ Issuer: issuer,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := CheckIssuer(tt.claims, issuer)
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
+
+func TestCheckAudience(t *testing.T) {
+ const clientID = "foo.bar"
+ tests := []struct {
+ name string
+ claims Claims
+ wantErr error
+ }{
+ {
+ name: "missing",
+ claims: &TokenClaims{},
+ wantErr: ErrAudience,
+ },
+ {
+ name: "wrong",
+ claims: &TokenClaims{
+ Audience: []string{"wrong"},
+ },
+ wantErr: ErrAudience,
+ },
+ {
+ name: "ok",
+ claims: &TokenClaims{
+ Audience: []string{clientID},
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := CheckAudience(tt.claims, clientID)
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
+
+func TestCheckAuthorizedParty(t *testing.T) {
+ const clientID = "foo.bar"
+ tests := []struct {
+ name string
+ claims Claims
+ wantErr error
+ }{
+ {
+ name: "single audience, no azp",
+ claims: &TokenClaims{
+ Audience: []string{clientID},
+ },
+ },
+ {
+ name: "multiple audience, no azp",
+ claims: &TokenClaims{
+ Audience: []string{clientID, "other"},
+ },
+ wantErr: ErrAzpMissing,
+ },
+ {
+ name: "single audience, with azp",
+ claims: &TokenClaims{
+ Audience: []string{clientID},
+ AuthorizedParty: clientID,
+ },
+ },
+ {
+ name: "multiple audience, with azp",
+ claims: &TokenClaims{
+ Audience: []string{clientID, "other"},
+ AuthorizedParty: clientID,
+ },
+ },
+ {
+ name: "wrong azp",
+ claims: &TokenClaims{
+ AuthorizedParty: "wrong",
+ },
+ wantErr: ErrAzpInvalid,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := CheckAuthorizedParty(tt.claims, clientID)
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
+
+func TestCheckExpiration(t *testing.T) {
+ const offset = time.Minute
+ tests := []struct {
+ name string
+ claims Claims
+ wantErr error
+ }{
+ {
+ name: "missing",
+ claims: &TokenClaims{},
+ wantErr: ErrExpired,
+ },
+ {
+ name: "expired",
+ claims: &TokenClaims{
+ Expiration: FromTime(time.Now().Add(-2 * offset)),
+ },
+ wantErr: ErrExpired,
+ },
+ {
+ name: "valid",
+ claims: &TokenClaims{
+ Expiration: FromTime(time.Now().Add(2 * offset)),
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := CheckExpiration(tt.claims, offset)
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
+
+func TestCheckIssuedAt(t *testing.T) {
+ const offset = time.Minute
+ tests := []struct {
+ name string
+ maxAgeIAT time.Duration
+ claims Claims
+ wantErr error
+ }{
+ {
+ name: "missing",
+ claims: &TokenClaims{},
+ wantErr: ErrIatMissing,
+ },
+ {
+ name: "future",
+ claims: &TokenClaims{
+ IssuedAt: FromTime(time.Now().Add(time.Hour)),
+ },
+ wantErr: ErrIatInFuture,
+ },
+ {
+ name: "no max",
+ claims: &TokenClaims{
+ IssuedAt: FromTime(time.Now()),
+ },
+ },
+ {
+ name: "past max",
+ maxAgeIAT: time.Minute,
+ claims: &TokenClaims{
+ IssuedAt: FromTime(time.Now().Add(-time.Hour)),
+ },
+ wantErr: ErrIatToOld,
+ },
+ {
+ name: "within max",
+ maxAgeIAT: time.Hour,
+ claims: &TokenClaims{
+ IssuedAt: FromTime(time.Now()),
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := CheckIssuedAt(tt.claims, tt.maxAgeIAT, offset)
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
+
+func TestCheckNonce(t *testing.T) {
+ const nonce = "123"
+ tests := []struct {
+ name string
+ claims Claims
+ wantErr error
+ }{
+ {
+ name: "missing",
+ claims: &TokenClaims{},
+ wantErr: ErrNonceInvalid,
+ },
+ {
+ name: "wrong",
+ claims: &TokenClaims{
+ Nonce: "wrong",
+ },
+ wantErr: ErrNonceInvalid,
+ },
+ {
+ name: "ok",
+ claims: &TokenClaims{
+ Nonce: nonce,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := CheckNonce(tt.claims, nonce)
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
+
+func TestCheckAuthorizationContextClassReference(t *testing.T) {
+ tests := []struct {
+ name string
+ acr ACRVerifier
+ wantErr error
+ }{
+ {
+ name: "error",
+ acr: func(s string) error { return errors.New("oops") },
+ wantErr: ErrAcrInvalid,
+ },
+ {
+ name: "ok",
+ acr: func(s string) error { return nil },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := CheckAuthorizationContextClassReference(&IDTokenClaims{}, tt.acr)
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
+
+func TestCheckAuthTime(t *testing.T) {
+ tests := []struct {
+ name string
+ claims Claims
+ maxAge time.Duration
+ wantErr error
+ }{
+ {
+ name: "no max age",
+ claims: &TokenClaims{},
+ },
+ {
+ name: "missing",
+ claims: &TokenClaims{},
+ maxAge: time.Minute,
+ wantErr: ErrAuthTimeNotPresent,
+ },
+ {
+ name: "expired",
+ maxAge: time.Minute,
+ claims: &TokenClaims{
+ AuthTime: FromTime(time.Now().Add(-time.Hour)),
+ },
+ wantErr: ErrAuthTimeToOld,
+ },
+ {
+ name: "ok",
+ maxAge: time.Minute,
+ claims: &TokenClaims{
+ AuthTime: NowTime(),
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := CheckAuthTime(tt.claims, tt.maxAge)
+ assert.ErrorIs(t, err, tt.wantErr)
+ })
+ }
+}
diff --git a/pkg/op/applicationtype_enumer.go b/pkg/op/applicationtype_enumer.go
new file mode 100644
index 0000000..7f0b1e0
--- /dev/null
+++ b/pkg/op/applicationtype_enumer.go
@@ -0,0 +1,342 @@
+// Code generated by "enumer -linecomment -sql -json -text -yaml -gqlgen -type=ApplicationType,AccessTokenType"; DO NOT EDIT.
+
+package op
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+)
+
+const _ApplicationTypeName = "webuser_agentnative"
+
+var _ApplicationTypeIndex = [...]uint8{0, 3, 13, 19}
+
+const _ApplicationTypeLowerName = "webuser_agentnative"
+
+func (i ApplicationType) String() string {
+ if i < 0 || i >= ApplicationType(len(_ApplicationTypeIndex)-1) {
+ return fmt.Sprintf("ApplicationType(%d)", i)
+ }
+ return _ApplicationTypeName[_ApplicationTypeIndex[i]:_ApplicationTypeIndex[i+1]]
+}
+
+// An "invalid array index" compiler error signifies that the constant values have changed.
+// Re-run the stringer command to generate them again.
+func _ApplicationTypeNoOp() {
+ var x [1]struct{}
+ _ = x[ApplicationTypeWeb-(0)]
+ _ = x[ApplicationTypeUserAgent-(1)]
+ _ = x[ApplicationTypeNative-(2)]
+}
+
+var _ApplicationTypeValues = []ApplicationType{ApplicationTypeWeb, ApplicationTypeUserAgent, ApplicationTypeNative}
+
+var _ApplicationTypeNameToValueMap = map[string]ApplicationType{
+ _ApplicationTypeName[0:3]: ApplicationTypeWeb,
+ _ApplicationTypeLowerName[0:3]: ApplicationTypeWeb,
+ _ApplicationTypeName[3:13]: ApplicationTypeUserAgent,
+ _ApplicationTypeLowerName[3:13]: ApplicationTypeUserAgent,
+ _ApplicationTypeName[13:19]: ApplicationTypeNative,
+ _ApplicationTypeLowerName[13:19]: ApplicationTypeNative,
+}
+
+var _ApplicationTypeNames = []string{
+ _ApplicationTypeName[0:3],
+ _ApplicationTypeName[3:13],
+ _ApplicationTypeName[13:19],
+}
+
+// ApplicationTypeString retrieves an enum value from the enum constants string name.
+// Throws an error if the param is not part of the enum.
+func ApplicationTypeString(s string) (ApplicationType, error) {
+ if val, ok := _ApplicationTypeNameToValueMap[s]; ok {
+ return val, nil
+ }
+
+ if val, ok := _ApplicationTypeNameToValueMap[strings.ToLower(s)]; ok {
+ return val, nil
+ }
+ return 0, fmt.Errorf("%s does not belong to ApplicationType values", s)
+}
+
+// ApplicationTypeValues returns all values of the enum
+func ApplicationTypeValues() []ApplicationType {
+ return _ApplicationTypeValues
+}
+
+// ApplicationTypeStrings returns a slice of all String values of the enum
+func ApplicationTypeStrings() []string {
+ strs := make([]string, len(_ApplicationTypeNames))
+ copy(strs, _ApplicationTypeNames)
+ return strs
+}
+
+// IsAApplicationType returns "true" if the value is listed in the enum definition. "false" otherwise
+func (i ApplicationType) IsAApplicationType() bool {
+ for _, v := range _ApplicationTypeValues {
+ if i == v {
+ return true
+ }
+ }
+ return false
+}
+
+// MarshalJSON implements the json.Marshaler interface for ApplicationType
+func (i ApplicationType) MarshalJSON() ([]byte, error) {
+ return json.Marshal(i.String())
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface for ApplicationType
+func (i *ApplicationType) UnmarshalJSON(data []byte) error {
+ var s string
+ if err := json.Unmarshal(data, &s); err != nil {
+ return fmt.Errorf("ApplicationType should be a string, got %s", data)
+ }
+
+ var err error
+ *i, err = ApplicationTypeString(s)
+ return err
+}
+
+// MarshalText implements the encoding.TextMarshaler interface for ApplicationType
+func (i ApplicationType) MarshalText() ([]byte, error) {
+ return []byte(i.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface for ApplicationType
+func (i *ApplicationType) UnmarshalText(text []byte) error {
+ var err error
+ *i, err = ApplicationTypeString(string(text))
+ return err
+}
+
+// MarshalYAML implements a YAML Marshaler for ApplicationType
+func (i ApplicationType) MarshalYAML() (interface{}, error) {
+ return i.String(), nil
+}
+
+// UnmarshalYAML implements a YAML Unmarshaler for ApplicationType
+func (i *ApplicationType) UnmarshalYAML(unmarshal func(interface{}) error) error {
+ var s string
+ if err := unmarshal(&s); err != nil {
+ return err
+ }
+
+ var err error
+ *i, err = ApplicationTypeString(s)
+ return err
+}
+
+func (i ApplicationType) Value() (driver.Value, error) {
+ return i.String(), nil
+}
+
+func (i *ApplicationType) Scan(value interface{}) error {
+ if value == nil {
+ return nil
+ }
+
+ var str string
+ switch v := value.(type) {
+ case []byte:
+ str = string(v)
+ case string:
+ str = v
+ case fmt.Stringer:
+ str = v.String()
+ default:
+ return fmt.Errorf("invalid value of ApplicationType: %[1]T(%[1]v)", value)
+ }
+
+ val, err := ApplicationTypeString(str)
+ if err != nil {
+ return err
+ }
+
+ *i = val
+ return nil
+}
+
+// MarshalGQL implements the graphql.Marshaler interface for ApplicationType
+func (i ApplicationType) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(i.String()))
+}
+
+// UnmarshalGQL implements the graphql.Unmarshaler interface for ApplicationType
+func (i *ApplicationType) UnmarshalGQL(value interface{}) error {
+ str, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("ApplicationType should be a string, got %T", value)
+ }
+
+ var err error
+ *i, err = ApplicationTypeString(str)
+ return err
+}
+
+const _AccessTokenTypeName = "bearerJWT"
+
+var _AccessTokenTypeIndex = [...]uint8{0, 6, 9}
+
+const _AccessTokenTypeLowerName = "bearerjwt"
+
+func (i AccessTokenType) String() string {
+ if i < 0 || i >= AccessTokenType(len(_AccessTokenTypeIndex)-1) {
+ return fmt.Sprintf("AccessTokenType(%d)", i)
+ }
+ return _AccessTokenTypeName[_AccessTokenTypeIndex[i]:_AccessTokenTypeIndex[i+1]]
+}
+
+// An "invalid array index" compiler error signifies that the constant values have changed.
+// Re-run the stringer command to generate them again.
+func _AccessTokenTypeNoOp() {
+ var x [1]struct{}
+ _ = x[AccessTokenTypeBearer-(0)]
+ _ = x[AccessTokenTypeJWT-(1)]
+}
+
+var _AccessTokenTypeValues = []AccessTokenType{AccessTokenTypeBearer, AccessTokenTypeJWT}
+
+var _AccessTokenTypeNameToValueMap = map[string]AccessTokenType{
+ _AccessTokenTypeName[0:6]: AccessTokenTypeBearer,
+ _AccessTokenTypeLowerName[0:6]: AccessTokenTypeBearer,
+ _AccessTokenTypeName[6:9]: AccessTokenTypeJWT,
+ _AccessTokenTypeLowerName[6:9]: AccessTokenTypeJWT,
+}
+
+var _AccessTokenTypeNames = []string{
+ _AccessTokenTypeName[0:6],
+ _AccessTokenTypeName[6:9],
+}
+
+// AccessTokenTypeString retrieves an enum value from the enum constants string name.
+// Throws an error if the param is not part of the enum.
+func AccessTokenTypeString(s string) (AccessTokenType, error) {
+ if val, ok := _AccessTokenTypeNameToValueMap[s]; ok {
+ return val, nil
+ }
+
+ if val, ok := _AccessTokenTypeNameToValueMap[strings.ToLower(s)]; ok {
+ return val, nil
+ }
+ return 0, fmt.Errorf("%s does not belong to AccessTokenType values", s)
+}
+
+// AccessTokenTypeValues returns all values of the enum
+func AccessTokenTypeValues() []AccessTokenType {
+ return _AccessTokenTypeValues
+}
+
+// AccessTokenTypeStrings returns a slice of all String values of the enum
+func AccessTokenTypeStrings() []string {
+ strs := make([]string, len(_AccessTokenTypeNames))
+ copy(strs, _AccessTokenTypeNames)
+ return strs
+}
+
+// IsAAccessTokenType returns "true" if the value is listed in the enum definition. "false" otherwise
+func (i AccessTokenType) IsAAccessTokenType() bool {
+ for _, v := range _AccessTokenTypeValues {
+ if i == v {
+ return true
+ }
+ }
+ return false
+}
+
+// MarshalJSON implements the json.Marshaler interface for AccessTokenType
+func (i AccessTokenType) MarshalJSON() ([]byte, error) {
+ return json.Marshal(i.String())
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface for AccessTokenType
+func (i *AccessTokenType) UnmarshalJSON(data []byte) error {
+ var s string
+ if err := json.Unmarshal(data, &s); err != nil {
+ return fmt.Errorf("AccessTokenType should be a string, got %s", data)
+ }
+
+ var err error
+ *i, err = AccessTokenTypeString(s)
+ return err
+}
+
+// MarshalText implements the encoding.TextMarshaler interface for AccessTokenType
+func (i AccessTokenType) MarshalText() ([]byte, error) {
+ return []byte(i.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface for AccessTokenType
+func (i *AccessTokenType) UnmarshalText(text []byte) error {
+ var err error
+ *i, err = AccessTokenTypeString(string(text))
+ return err
+}
+
+// MarshalYAML implements a YAML Marshaler for AccessTokenType
+func (i AccessTokenType) MarshalYAML() (interface{}, error) {
+ return i.String(), nil
+}
+
+// UnmarshalYAML implements a YAML Unmarshaler for AccessTokenType
+func (i *AccessTokenType) UnmarshalYAML(unmarshal func(interface{}) error) error {
+ var s string
+ if err := unmarshal(&s); err != nil {
+ return err
+ }
+
+ var err error
+ *i, err = AccessTokenTypeString(s)
+ return err
+}
+
+func (i AccessTokenType) Value() (driver.Value, error) {
+ return i.String(), nil
+}
+
+func (i *AccessTokenType) Scan(value interface{}) error {
+ if value == nil {
+ return nil
+ }
+
+ var str string
+ switch v := value.(type) {
+ case []byte:
+ str = string(v)
+ case string:
+ str = v
+ case fmt.Stringer:
+ str = v.String()
+ default:
+ return fmt.Errorf("invalid value of AccessTokenType: %[1]T(%[1]v)", value)
+ }
+
+ val, err := AccessTokenTypeString(str)
+ if err != nil {
+ return err
+ }
+
+ *i = val
+ return nil
+}
+
+// MarshalGQL implements the graphql.Marshaler interface for AccessTokenType
+func (i AccessTokenType) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(i.String()))
+}
+
+// UnmarshalGQL implements the graphql.Unmarshaler interface for AccessTokenType
+func (i *AccessTokenType) UnmarshalGQL(value interface{}) error {
+ str, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("AccessTokenType should be a string, got %T", value)
+ }
+
+ var err error
+ *i, err = AccessTokenTypeString(str)
+ return err
+}
diff --git a/pkg/op/auth_request.go b/pkg/op/auth_request.go
new file mode 100644
index 0000000..b1434cc
--- /dev/null
+++ b/pkg/op/auth_request.go
@@ -0,0 +1,680 @@
+package op
+
+import (
+ "bytes"
+ "context"
+ _ "embed"
+ "errors"
+ "fmt"
+ "html/template"
+ "log/slog"
+ "net"
+ "net/http"
+ "net/url"
+ "slices"
+ "strings"
+ "time"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/bmatcuk/doublestar/v4"
+)
+
+type AuthRequest interface {
+ GetID() string
+ GetACR() string
+ GetAMR() []string
+ GetAudience() []string
+ GetAuthTime() time.Time
+ GetClientID() string
+ GetCodeChallenge() *oidc.CodeChallenge
+ GetNonce() string
+ GetRedirectURI() string
+ GetResponseType() oidc.ResponseType
+ GetResponseMode() oidc.ResponseMode
+ GetScopes() []string
+ GetState() string
+ GetSubject() string
+ Done() bool
+}
+
+// AuthRequestSessionState should be implemented if [OpenID Connect Session Management](https://openid.net/specs/openid-connect-session-1_0.html) is supported
+type AuthRequestSessionState interface {
+ // GetSessionState returns session_state.
+ // session_state is related to OpenID Connect Session Management.
+ GetSessionState() string
+}
+
+type Authorizer interface {
+ Storage() Storage
+ Decoder() httphelper.Decoder
+ Encoder() httphelper.Encoder
+ IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
+ Crypto() Crypto
+ RequestObjectSupported() bool
+ Logger() *slog.Logger
+}
+
+// AuthorizeValidator is an extension of Authorizer interface
+// implementing its own validation mechanism for the auth request
+type AuthorizeValidator interface {
+ Authorizer
+ ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, *IDTokenHintVerifier) (string, error)
+}
+
+type CodeResponseType struct {
+ Code string `schema:"code"`
+ State string `schema:"state,omitempty"`
+ SessionState string `schema:"session_state,omitempty"`
+}
+
+func authorizeHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ Authorize(w, r, authorizer)
+ }
+}
+
+func AuthorizeCallbackHandler(authorizer Authorizer) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ AuthorizeCallback(w, r, authorizer)
+ }
+}
+
+// Authorize handles the authorization request, including
+// parsing, validating, storing and finally redirecting to the login handler
+func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
+ ctx, span := tracer.Start(r.Context(), "Authorize")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
+ if err != nil {
+ AuthRequestError(w, r, nil, err, authorizer)
+ return
+ }
+ if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
+ err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
+ if err != nil {
+ AuthRequestError(w, r, nil, err, authorizer)
+ return
+ }
+ }
+ if authReq.ClientID == "" {
+ AuthRequestError(w, r, nil, fmt.Errorf("auth request is missing client_id"), authorizer)
+ return
+ }
+ if authReq.RedirectURI == "" {
+ AuthRequestError(w, r, nil, fmt.Errorf("auth request is missing redirect_uri"), authorizer)
+ return
+ }
+
+ var client Client
+ validation := func(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
+ client, err = authorizer.Storage().GetClientByClientID(ctx, authReq.ClientID)
+ if err != nil {
+ return "", oidc.ErrInvalidRequestRedirectURI().WithDescription("unable to retrieve client by id").WithParent(err)
+ }
+ return ValidateAuthRequestClient(ctx, authReq, client, verifier)
+ }
+ if validator, ok := authorizer.(AuthorizeValidator); ok {
+ validation = validator.ValidateAuthRequest
+ }
+ userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
+ if err != nil {
+ AuthRequestError(w, r, authReq, err, authorizer)
+ return
+ }
+ if authReq.RequestParam != "" {
+ AuthRequestError(w, r, authReq, oidc.ErrRequestNotSupported(), authorizer)
+ return
+ }
+ req, err := authorizer.Storage().CreateAuthRequest(ctx, authReq, userID)
+ if err != nil {
+ AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer)
+ return
+ }
+ RedirectToLogin(req.GetID(), client, w, r)
+}
+
+// ParseAuthorizeRequest parsed the http request into an oidc.AuthRequest
+func ParseAuthorizeRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.AuthRequest, error) {
+ err := r.ParseForm()
+ if err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
+ }
+ authReq := new(oidc.AuthRequest)
+ err = decoder.Decode(authReq, r.Form)
+ if err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse auth request").WithParent(err)
+ }
+ return authReq, nil
+}
+
+// ParseRequestObject parse the `request` parameter, validates the token including the signature
+// and copies the token claims into the auth request
+func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, issuer string) error {
+ requestObject := new(oidc.RequestObject)
+ payload, err := oidc.ParseToken(authReq.RequestParam, requestObject)
+ if err != nil {
+ return err
+ }
+
+ if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
+ return oidc.ErrInvalidRequest().WithDescription("missing or wrong client id in request")
+ }
+ if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
+ return oidc.ErrInvalidRequest().WithDescription("missing or wrong response type in request")
+ }
+ if requestObject.Issuer != requestObject.ClientID {
+ return oidc.ErrInvalidRequest().WithDescription("missing or wrong issuer in request")
+ }
+ if !slices.Contains(requestObject.Audience, issuer) {
+ return oidc.ErrInvalidRequest().WithDescription("issuer missing in audience")
+ }
+ keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
+ if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
+ return oidc.ErrInvalidRequest().WithParent(err).WithDescription(err.Error())
+ }
+ CopyRequestObjectToAuthRequest(authReq, requestObject)
+ return nil
+}
+
+// CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request
+// and clears the `RequestParam` of the auth request
+func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oidc.RequestObject) {
+ if slices.Contains(authReq.Scopes, oidc.ScopeOpenID) && len(requestObject.Scopes) > 0 {
+ authReq.Scopes = requestObject.Scopes
+ }
+ if requestObject.RedirectURI != "" {
+ authReq.RedirectURI = requestObject.RedirectURI
+ }
+ if requestObject.State != "" {
+ authReq.State = requestObject.State
+ }
+ if requestObject.ResponseMode != "" {
+ authReq.ResponseMode = requestObject.ResponseMode
+ }
+ if requestObject.Nonce != "" {
+ authReq.Nonce = requestObject.Nonce
+ }
+ if requestObject.Display != "" {
+ authReq.Display = requestObject.Display
+ }
+ if len(requestObject.Prompt) > 0 {
+ authReq.Prompt = requestObject.Prompt
+ }
+ if requestObject.MaxAge != nil {
+ authReq.MaxAge = requestObject.MaxAge
+ }
+ if len(requestObject.UILocales) > 0 {
+ authReq.UILocales = requestObject.UILocales
+ }
+ if requestObject.IDTokenHint != "" {
+ authReq.IDTokenHint = requestObject.IDTokenHint
+ }
+ if requestObject.LoginHint != "" {
+ authReq.LoginHint = requestObject.LoginHint
+ }
+ if len(requestObject.ACRValues) > 0 {
+ authReq.ACRValues = requestObject.ACRValues
+ }
+ if requestObject.CodeChallenge != "" {
+ authReq.CodeChallenge = requestObject.CodeChallenge
+ }
+ if requestObject.CodeChallengeMethod != "" {
+ authReq.CodeChallengeMethod = requestObject.CodeChallengeMethod
+ }
+ authReq.RequestParam = ""
+}
+
+// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed.
+//
+// Deprecated: Use [ValidateAuthRequestClient] to prevent querying for the Client twice.
+func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier *IDTokenHintVerifier) (sub string, err error) {
+ ctx, span := tracer.Start(ctx, "ValidateAuthRequest")
+ defer span.End()
+
+ client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
+ if err != nil {
+ return "", oidc.ErrInvalidRequestRedirectURI().WithDescription("unable to retrieve client by id").WithParent(err)
+ }
+ return ValidateAuthRequestClient(ctx, authReq, client, verifier)
+}
+
+// ValidateAuthRequestClient validates the Auth request against the passed client.
+// If id_token_hint is part of the request, the subject of the token is returned.
+func ValidateAuthRequestClient(ctx context.Context, authReq *oidc.AuthRequest, client Client, verifier *IDTokenHintVerifier) (sub string, err error) {
+ ctx, span := tracer.Start(ctx, "ValidateAuthRequestClient")
+ defer span.End()
+
+ if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
+ return "", err
+ }
+ authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
+ if err != nil {
+ return "", err
+ }
+ authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
+ if err != nil {
+ return "", err
+ }
+ if err := ValidateAuthReqResponseType(client, authReq.ResponseType); err != nil {
+ return "", err
+ }
+ return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier)
+}
+
+// ValidateAuthReqPrompt validates the passed prompt values and sets max_age to 0 if prompt login is present
+func ValidateAuthReqPrompt(prompts []string, maxAge *uint) (_ *uint, err error) {
+ for _, prompt := range prompts {
+ if prompt == oidc.PromptNone && len(prompts) > 1 {
+ return nil, oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value")
+ }
+ if prompt == oidc.PromptLogin {
+ maxAge = oidc.NewMaxAge(0)
+ }
+ }
+ return maxAge, nil
+}
+
+// ValidateAuthReqScopes validates the passed scopes and deletes any unsupported scopes.
+// An error is returned if scopes is empty.
+func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
+ if len(scopes) == 0 {
+ return nil, oidc.ErrInvalidRequest().
+ WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " +
+ "If you have any questions, you may contact the administrator of the application.")
+ }
+ scopes = slices.DeleteFunc(scopes, func(scope string) bool {
+ return !(scope == oidc.ScopeOpenID ||
+ scope == oidc.ScopeProfile ||
+ scope == oidc.ScopeEmail ||
+ scope == oidc.ScopePhone ||
+ scope == oidc.ScopeAddress ||
+ scope == oidc.ScopeOfflineAccess) &&
+ !client.IsScopeAllowed(scope)
+ })
+ return scopes, nil
+}
+
+// checkURIAgainstRedirects just checks aginst the valid redirect URIs and ignores
+// other factors.
+func checkURIAgainstRedirects(client Client, uri string) error {
+ if slices.Contains(client.RedirectURIs(), uri) {
+ return nil
+ }
+ if globClient, ok := client.(HasRedirectGlobs); ok {
+ for _, uriGlob := range globClient.RedirectURIGlobs() {
+ isMatch, err := doublestar.Match(uriGlob, uri)
+ if err != nil {
+ return oidc.ErrServerError().WithParent(err)
+ }
+ if isMatch {
+ return nil
+ }
+ }
+ }
+ return oidc.ErrInvalidRequestRedirectURI().
+ WithDescription("The requested redirect_uri is missing in the client configuration. " +
+ "If you have any questions, you may contact the administrator of the application.")
+}
+
+// ValidateAuthReqRedirectURI validates the passed redirect_uri and response_type to the registered uris and client type
+func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.ResponseType) error {
+ if uri == "" {
+ return oidc.ErrInvalidRequestRedirectURI().WithDescription("The redirect_uri is missing in the request. " +
+ "Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
+ }
+ if client.ApplicationType() == ApplicationTypeNative {
+ return validateAuthReqRedirectURINative(client, uri)
+ }
+ if strings.HasPrefix(uri, "https://") {
+ return checkURIAgainstRedirects(client, uri)
+ }
+ if err := checkURIAgainstRedirects(client, uri); err != nil {
+ return err
+ }
+ if strings.HasPrefix(uri, "http://") {
+ if client.DevMode() {
+ return nil
+ }
+ if responseType == oidc.ResponseTypeCode && IsConfidentialType(client) {
+ return nil
+ }
+ return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is http and is not allowed. " +
+ "If you have any questions, you may contact the administrator of the application.")
+ }
+ return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is using a custom schema and is not allowed. " +
+ "If you have any questions, you may contact the administrator of the application.")
+}
+
+// ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type
+func validateAuthReqRedirectURINative(client Client, uri string) error {
+ parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri)
+ isCustomSchema := !(strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://"))
+ if err := checkURIAgainstRedirects(client, uri); err == nil {
+ if client.DevMode() {
+ return nil
+ }
+ if !isLoopback && strings.HasPrefix(uri, "https://") {
+ return nil
+ }
+ // The RedirectURIs are only valid for native clients when localhost or non-"http://" and "https://"
+ if isLoopback || isCustomSchema {
+ return nil
+ }
+ return oidc.ErrInvalidRequestRedirectURI().WithDescription("This client's redirect_uri is http and is not allowed. " +
+ "If you have any questions, you may contact the administrator of the application.")
+ }
+ if !isLoopback {
+ return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration. " +
+ "If you have any questions, you may contact the administrator of the application.")
+ }
+ for _, uri := range client.RedirectURIs() {
+ redirectURI, ok := HTTPLoopbackOrLocalhost(uri)
+ if ok && equalURI(parsedURL, redirectURI) {
+ return nil
+ }
+ }
+ return oidc.ErrInvalidRequestRedirectURI().WithDescription("The requested redirect_uri is missing in the client configuration." +
+ " If you have any questions, you may contact the administrator of the application.")
+}
+
+func equalURI(url1, url2 *url.URL) bool {
+ return url1.Path == url2.Path && url1.RawQuery == url2.RawQuery
+}
+
+func HTTPLoopbackOrLocalhost(rawURL string) (*url.URL, bool) {
+ parsedURL, err := url.Parse(rawURL)
+ if err != nil {
+ return nil, false
+ }
+ if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
+ hostName := parsedURL.Hostname()
+ return parsedURL, hostName == "localhost" || net.ParseIP(hostName).IsLoopback()
+ }
+ return nil, false
+}
+
+// ValidateAuthReqResponseType validates the passed response_type to the registered response types
+func ValidateAuthReqResponseType(client Client, responseType oidc.ResponseType) error {
+ if responseType == "" {
+ return oidc.ErrInvalidRequest().WithDescription("The response type is missing in your request. " +
+ "If you have any questions, you may contact the administrator of the application.")
+ }
+ if !ContainsResponseType(client.ResponseTypes(), responseType) {
+ return oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " +
+ "If you have any questions, you may contact the administrator of the application.")
+ }
+ return nil
+}
+
+// ValidateAuthReqIDTokenHint validates the id_token_hint (if passed as parameter in the request)
+// and returns the `sub` claim
+func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier *IDTokenHintVerifier) (string, error) {
+ if idTokenHint == "" {
+ return "", nil
+ }
+ claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, idTokenHint, verifier)
+ if err != nil && !errors.As(err, &IDTokenHintExpiredError{}) {
+ return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " +
+ "If you have any questions, you may contact the administrator of the application.").WithParent(err)
+ }
+ return claims.GetSubject(), nil
+}
+
+// RedirectToLogin redirects the end user to the Login UI for authentication
+func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) {
+ login := client.LoginURL(authReqID)
+ http.Redirect(w, r, login, http.StatusFound)
+}
+
+// AuthorizeCallback handles the callback after authentication in the Login UI
+func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
+ ctx, span := tracer.Start(r.Context(), "AuthorizeCallback")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ id, err := ParseAuthorizeCallbackRequest(r)
+ if err != nil {
+ AuthRequestError(w, r, nil, err, authorizer)
+ return
+ }
+ authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
+ if err != nil {
+ AuthRequestError(w, r, nil, err, authorizer)
+ return
+ }
+ if !authReq.Done() {
+ AuthRequestError(w, r, authReq,
+ oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required."),
+ authorizer)
+ return
+ }
+ AuthResponse(authReq, authorizer, w, r)
+}
+
+func ParseAuthorizeCallbackRequest(r *http.Request) (id string, err error) {
+ if err = r.ParseForm(); err != nil {
+ return "", fmt.Errorf("cannot parse form: %w", err)
+ }
+ id = r.Form.Get("id")
+ if id == "" {
+ return "", errors.New("auth request callback is missing id")
+ }
+ return id, nil
+}
+
+// AuthResponse creates the successful authentication response (either code or tokens)
+func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
+ ctx, span := tracer.Start(r.Context(), "AuthResponse")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
+ if err != nil {
+ AuthRequestError(w, r, authReq, err, authorizer)
+ return
+ }
+ if authReq.GetResponseType() == oidc.ResponseTypeCode {
+ AuthResponseCode(w, r, authReq, authorizer)
+ return
+ }
+ AuthResponseToken(w, r, authReq, authorizer, client)
+}
+
+// AuthResponseCode handles the creation of a successful authentication response using an authorization code
+func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
+ ctx, span := tracer.Start(r.Context(), "AuthResponseCode")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ var err error
+ if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
+ err = handleFormPostResponse(w, r, authReq, authorizer)
+ } else {
+ err = handleRedirectResponse(w, r, authReq, authorizer)
+ }
+
+ if err != nil {
+ AuthRequestError(w, r, authReq, err, authorizer)
+ }
+}
+
+// handleFormPostResponse processes the authentication response using form post method
+func handleFormPostResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error {
+ codeResponse, err := BuildAuthResponseCodeResponsePayload(r.Context(), authReq, authorizer)
+ if err != nil {
+ return err
+ }
+ return AuthResponseFormPost(w, authReq.GetRedirectURI(), codeResponse, authorizer.Encoder())
+}
+
+// handleRedirectResponse processes the authentication response using the redirect method
+func handleRedirectResponse(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) error {
+ callbackURL, err := BuildAuthResponseCallbackURL(r.Context(), authReq, authorizer)
+ if err != nil {
+ return err
+ }
+ http.Redirect(w, r, callbackURL, http.StatusFound)
+ return nil
+}
+
+// BuildAuthResponseCodeResponsePayload generates the authorization code response payload for the authentication request
+func BuildAuthResponseCodeResponsePayload(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (*CodeResponseType, error) {
+ code, err := CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto())
+ if err != nil {
+ return nil, err
+ }
+
+ sessionState := ""
+ if authRequestSessionState, ok := authReq.(AuthRequestSessionState); ok {
+ sessionState = authRequestSessionState.GetSessionState()
+ }
+
+ return &CodeResponseType{
+ Code: code,
+ State: authReq.GetState(),
+ SessionState: sessionState,
+ }, nil
+}
+
+// BuildAuthResponseCallbackURL generates the callback URL for a successful authorization code response
+func BuildAuthResponseCallbackURL(ctx context.Context, authReq AuthRequest, authorizer Authorizer) (string, error) {
+ codeResponse, err := BuildAuthResponseCodeResponsePayload(ctx, authReq, authorizer)
+ if err != nil {
+ return "", err
+ }
+
+ return AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), codeResponse, authorizer.Encoder())
+}
+
+// AuthResponseToken creates the successful token(s) authentication response
+func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) {
+ ctx, span := tracer.Start(r.Context(), "AuthResponseToken")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
+ resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "", "")
+ if err != nil {
+ AuthRequestError(w, r, authReq, err, authorizer)
+ return
+ }
+
+ if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
+ err := AuthResponseFormPost(w, authReq.GetRedirectURI(), resp, authorizer.Encoder())
+ if err != nil {
+ AuthRequestError(w, r, authReq, err, authorizer)
+ return
+ }
+
+ return
+ }
+
+ callback, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
+ if err != nil {
+ AuthRequestError(w, r, authReq, err, authorizer)
+ return
+ }
+ http.Redirect(w, r, callback, http.StatusFound)
+}
+
+// CreateAuthRequestCode creates and stores a code for the auth code response
+func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) {
+ ctx, span := tracer.Start(ctx, "CreateAuthRequestCode")
+ defer span.End()
+
+ code, err := BuildAuthRequestCode(authReq, crypto)
+ if err != nil {
+ return "", err
+ }
+ if err := storage.SaveAuthCode(ctx, authReq.GetID(), code); err != nil {
+ return "", err
+ }
+ return code, nil
+}
+
+// BuildAuthRequestCode builds the string representation of the auth code
+func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
+ return crypto.Encrypt(authReq.GetID())
+}
+
+// AuthResponseURL encodes the authorization response (successful and error) and sets it as query or fragment values
+// depending on the response_mode and response_type
+func AuthResponseURL(redirectURI string, responseType oidc.ResponseType, responseMode oidc.ResponseMode, response any, encoder httphelper.Encoder) (string, error) {
+ uri, err := url.Parse(redirectURI)
+ if err != nil {
+ return "", oidc.ErrServerError().WithParent(err)
+ }
+ params, err := httphelper.URLEncodeParams(response, encoder)
+ if err != nil {
+ return "", oidc.ErrServerError().WithParent(err)
+ }
+ // return explicitly requested mode
+ if responseMode == oidc.ResponseModeQuery {
+ return mergeQueryParams(uri, params), nil
+ }
+ if responseMode == oidc.ResponseModeFragment {
+ return setFragment(uri, params), nil
+ }
+ // implicit must use fragment mode is not specified by client
+ if responseType == oidc.ResponseTypeIDToken || responseType == oidc.ResponseTypeIDTokenOnly {
+ return setFragment(uri, params), nil
+ }
+ // if we get here it's code flow: defaults to query
+ return mergeQueryParams(uri, params), nil
+}
+
+//go:embed form_post.html.tmpl
+var formPostHtmlTemplate string
+
+var formPostTmpl = template.Must(template.New("form_post").Parse(formPostHtmlTemplate))
+
+// AuthResponseFormPost responds a html page that automatically submits the form which contains the auth response parameters
+func AuthResponseFormPost(res http.ResponseWriter, redirectURI string, response any, encoder httphelper.Encoder) error {
+ values := make(map[string][]string)
+ err := encoder.Encode(response, values)
+ if err != nil {
+ return oidc.ErrServerError().WithParent(err)
+ }
+
+ params := &struct {
+ RedirectURI string
+ Params any
+ }{
+ RedirectURI: redirectURI,
+ Params: values,
+ }
+
+ var buf bytes.Buffer
+ err = formPostTmpl.Execute(&buf, params)
+ if err != nil {
+ return oidc.ErrServerError().WithParent(err)
+ }
+
+ res.Header().Set("Cache-Control", "no-store")
+ res.WriteHeader(http.StatusOK)
+ _, err = buf.WriteTo(res)
+ if err != nil {
+ return oidc.ErrServerError().WithParent(err)
+ }
+
+ return nil
+}
+
+func setFragment(uri *url.URL, params url.Values) string {
+ uri.Fragment = params.Encode()
+ return uri.String()
+}
+
+func mergeQueryParams(uri *url.URL, params url.Values) string {
+ queries := uri.Query()
+ for param, values := range params {
+ for _, value := range values {
+ queries.Add(param, value)
+ }
+ }
+ uri.RawQuery = queries.Encode()
+ return uri.String()
+}
diff --git a/pkg/op/auth_request_test.go b/pkg/op/auth_request_test.go
new file mode 100644
index 0000000..d1ea965
--- /dev/null
+++ b/pkg/op/auth_request_test.go
@@ -0,0 +1,1612 @@
+package op_test
+
+import (
+ "context"
+ "errors"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "reflect"
+ "testing"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage"
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op/mock"
+ "github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/zitadel/schema"
+)
+
+func TestAuthorize(t *testing.T) {
+ tests := []struct {
+ name string
+ req *http.Request
+ expect func(a *mock.MockAuthorizerMockRecorder)
+ }{
+ {
+ name: "parse error", // used to panic, see issue #315
+ req: httptest.NewRequest(http.MethodPost, "/?;", nil),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ w := httptest.NewRecorder()
+ authorizer := mock.NewMockAuthorizer(gomock.NewController(t))
+
+ expect := authorizer.EXPECT()
+ expect.Decoder().Return(schema.NewDecoder())
+ expect.Logger().Return(slog.Default())
+
+ if tt.expect != nil {
+ tt.expect(expect)
+ }
+
+ op.Authorize(w, tt.req, authorizer)
+ })
+ }
+}
+
+func TestParseAuthorizeRequest(t *testing.T) {
+ type args struct {
+ r *http.Request
+ decoder httphelper.Decoder
+ }
+ type res struct {
+ want *oidc.AuthRequest
+ err bool
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ "parsing form error",
+ args{
+ &http.Request{URL: &url.URL{RawQuery: "invalid=%%param"}},
+ schema.NewDecoder(),
+ },
+ res{
+ nil,
+ true,
+ },
+ },
+ {
+ "decoding error",
+ args{
+ &http.Request{URL: &url.URL{RawQuery: "unknown=value"}},
+ func() httphelper.Decoder {
+ decoder := schema.NewDecoder()
+ decoder.IgnoreUnknownKeys(false)
+ return decoder
+ }(),
+ },
+ res{
+ nil,
+ true,
+ },
+ },
+ {
+ "parsing ok",
+ args{
+ &http.Request{URL: &url.URL{RawQuery: "scope=openid"}},
+ func() httphelper.Decoder {
+ decoder := schema.NewDecoder()
+ decoder.IgnoreUnknownKeys(false)
+ return decoder
+ }(),
+ },
+ res{
+ &oidc.AuthRequest{Scopes: oidc.SpaceDelimitedArray{"openid"}},
+ false,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := op.ParseAuthorizeRequest(tt.args.r, tt.args.decoder)
+ if (err != nil) != tt.res.err {
+ t.Errorf("ParseAuthorizeRequest() error = %v, wantErr %v", err, tt.res.err)
+ }
+ if !reflect.DeepEqual(got, tt.res.want) {
+ t.Errorf("ParseAuthorizeRequest() got = %v, want %v", got, tt.res.want)
+ }
+ })
+ }
+}
+
+func TestValidateAuthRequest(t *testing.T) {
+ type args struct {
+ authRequest *oidc.AuthRequest
+ storage op.Storage
+ verifier *op.IDTokenHintVerifier
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr error
+ }{
+ {
+ "scope missing fails",
+ args{&oidc.AuthRequest{}, mock.NewMockStorageExpectValidClientID(t), nil},
+ oidc.ErrInvalidRequest(),
+ },
+ {
+ "response_type missing fails",
+ args{&oidc.AuthRequest{Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil},
+ oidc.ErrInvalidRequest(),
+ },
+ {
+ "client_id missing fails",
+ args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, mock.NewMockStorageExpectValidClientID(t), nil},
+ oidc.ErrInvalidRequest(),
+ },
+ {
+ "redirect_uri missing fails",
+ args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, mock.NewMockStorageExpectValidClientID(t), nil},
+ oidc.ErrInvalidRequest(),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := op.ValidateAuthRequest(context.TODO(), tt.args.authRequest, tt.args.storage, tt.args.verifier)
+ if tt.wantErr == nil && err != nil {
+ t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
+ }
+ if tt.wantErr != nil && !errors.Is(err, tt.wantErr) {
+ t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestValidateAuthReqPrompt(t *testing.T) {
+ type args struct {
+ prompts []string
+ maxAge *uint
+ }
+ type res struct {
+ maxAge *uint
+ err error
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ "no prompts and maxAge, ok",
+ args{
+ nil,
+ nil,
+ },
+ res{
+ nil,
+ nil,
+ },
+ },
+ {
+ "no prompts but maxAge, ok",
+ args{
+ nil,
+ oidc.NewMaxAge(10),
+ },
+ res{
+ oidc.NewMaxAge(10),
+ nil,
+ },
+ },
+ {
+ "prompt none, ok",
+ args{
+ []string{"none"},
+ oidc.NewMaxAge(10),
+ },
+ res{
+ oidc.NewMaxAge(10),
+ nil,
+ },
+ },
+ {
+ "prompt none with others, err",
+ args{
+ []string{"none", "login"},
+ oidc.NewMaxAge(10),
+ },
+ res{
+ nil,
+ oidc.ErrInvalidRequest(),
+ },
+ },
+ {
+ "prompt login, ok",
+ args{
+ []string{"login"},
+ nil,
+ },
+ res{
+ oidc.NewMaxAge(0),
+ nil,
+ },
+ },
+ {
+ "prompt login with maxAge, ok",
+ args{
+ []string{"login"},
+ oidc.NewMaxAge(10),
+ },
+ res{
+ oidc.NewMaxAge(0),
+ nil,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ maxAge, err := op.ValidateAuthReqPrompt(tt.args.prompts, tt.args.maxAge)
+ if tt.res.err == nil && err != nil {
+ t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
+ }
+ if tt.res.err != nil && !errors.Is(err, tt.res.err) {
+ t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.res.err)
+ }
+ assert.Equal(t, tt.res.maxAge, maxAge)
+ })
+ }
+}
+
+func TestValidateAuthReqScopes(t *testing.T) {
+ type args struct {
+ client op.Client
+ scopes []string
+ }
+ type res struct {
+ err bool
+ scopes []string
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ "scopes missing fails",
+ args{},
+ res{
+ err: true,
+ },
+ },
+ {
+ "scope ok",
+ args{
+ mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
+ []string{"openid"},
+ },
+ res{
+ scopes: []string{"openid"},
+ },
+ },
+ {
+ "scope with drop ok",
+ args{
+ mock.NewClientExpectAny(t, op.ApplicationTypeWeb),
+ []string{"openid", "email", "unknown"},
+ },
+ res{
+ scopes: []string{"openid", "email"},
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ scopes, err := op.ValidateAuthReqScopes(tt.args.client, tt.args.scopes)
+ if (err != nil) != tt.res.err {
+ t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.res.err)
+ }
+ assert.ElementsMatch(t, scopes, tt.res.scopes)
+ })
+ }
+}
+
+func TestValidateAuthReqRedirectURI(t *testing.T) {
+ type args struct {
+ uri string
+ client op.Client
+ responseType oidc.ResponseType
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ }{
+ {
+ "empty fails",
+ args{
+ "",
+ mock.NewClientWithConfig(t, []string{"https://registered.com/callback"}, op.ApplicationTypeWeb, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "unregistered https fails",
+ args{
+ "https://unregistered.com/callback",
+ mock.NewClientWithConfig(t, []string{"https://registered.com/callback"}, op.ApplicationTypeWeb, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "unregistered http fails",
+ args{
+ "http://unregistered.com/callback",
+ mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeWeb, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "code flow registered https web ok",
+ args{
+ "https://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"https://registered.com/callback"}, op.ApplicationTypeWeb, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow registered https native ok",
+ args{
+ "https://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"https://registered.com/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow registered https user agent ok",
+ args{
+ "https://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"https://registered.com/callback"}, op.ApplicationTypeUserAgent, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow registered http confidential (web) ok",
+ args{
+ "http://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeWeb, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow registered http not confidential (native) fails",
+ args{
+ "http://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "code flow registered http not confidential (user agent) fails",
+ args{
+ "http://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeUserAgent, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "code flow registered http localhost native ok",
+ args{
+ "http://localhost:4200/callback",
+ mock.NewClientWithConfig(t, []string{"http://localhost/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow registered http loopback v4 native ok",
+ args{
+ "http://127.0.0.1:4200/callback",
+ mock.NewClientWithConfig(t, []string{"http://127.0.0.1/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow registered http loopback v6 native ok",
+ args{
+ "http://[::1]:4200/callback",
+ mock.NewClientWithConfig(t, []string{"http://[::1]/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow registered https loopback v4 native ok",
+ args{
+ "https://127.0.0.1:4200/callback",
+ mock.NewClientWithConfig(t, []string{"https://127.0.0.1/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow registered https loopback v6 native ok",
+ args{
+ "https://[::1]:4200/callback",
+ mock.NewClientWithConfig(t, []string{"https://[::1]/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow unregistered http native fails",
+ args{
+ "http://unregistered.com/callback",
+ mock.NewClientWithConfig(t, []string{"http://locahost/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "code flow unregistered custom native fails",
+ args{
+ "unregistered://callback",
+ mock.NewClientWithConfig(t, []string{"registered://callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "code flow unregistered loopback native fails",
+ args{
+ "http://[::1]:4200/unregistered",
+ mock.NewClientWithConfig(t, []string{"http://[::1]:4200/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "code flow registered custom not native (web) fails",
+ args{
+ "custom://callback",
+ mock.NewClientWithConfig(t, []string{"custom://callback"}, op.ApplicationTypeWeb, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "code flow registered custom not native (user agent) fails",
+ args{
+ "custom://callback",
+ mock.NewClientWithConfig(t, []string{"custom://callback"}, op.ApplicationTypeUserAgent, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ {
+ "code flow registered custom native ok",
+ args{
+ "custom://callback",
+ mock.NewClientWithConfig(t, []string{"custom://callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow dev mode http ok",
+ args{
+ "http://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeUserAgent, nil, true),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "implicit flow registered ok",
+ args{
+ "https://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"https://registered.com/callback"}, op.ApplicationTypeUserAgent, nil, false),
+ oidc.ResponseTypeIDToken,
+ },
+ false,
+ },
+ {
+ "implicit flow unregistered fails",
+ args{
+ "https://unregistered.com/callback",
+ mock.NewClientWithConfig(t, []string{"https://registered.com/callback"}, op.ApplicationTypeUserAgent, nil, false),
+ oidc.ResponseTypeIDToken,
+ },
+ true,
+ },
+ {
+ "implicit flow registered http localhost native ok",
+ args{
+ "http://localhost:9999/callback",
+ mock.NewClientWithConfig(t, []string{"http://localhost:9999/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeIDToken,
+ },
+ false,
+ },
+ {
+ "implicit flow registered http localhost web fails",
+ args{
+ "http://localhost:9999/callback",
+ mock.NewClientWithConfig(t, []string{"http://localhost:9999/callback"}, op.ApplicationTypeWeb, nil, false),
+ oidc.ResponseTypeIDToken,
+ },
+ true,
+ },
+ {
+ "implicit flow registered http localhost user agent fails",
+ args{
+ "http://localhost:9999/callback",
+ mock.NewClientWithConfig(t, []string{"http://localhost:9999/callback"}, op.ApplicationTypeUserAgent, nil, false),
+ oidc.ResponseTypeIDToken,
+ },
+ true,
+ },
+ {
+ "implicit flow http non localhost fails",
+ args{
+ "http://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeIDToken,
+ },
+ true,
+ },
+ {
+ "implicit flow custom fails",
+ args{
+ "custom://callback",
+ mock.NewClientWithConfig(t, []string{"custom://callback"}, op.ApplicationTypeNative, nil, false),
+ oidc.ResponseTypeIDToken,
+ },
+ false,
+ },
+ {
+ "implicit flow dev mode http ok",
+ args{
+ "http://registered.com/callback",
+ mock.NewClientWithConfig(t, []string{"http://registered.com/callback"}, op.ApplicationTypeUserAgent, nil, true),
+ oidc.ResponseTypeIDToken,
+ },
+ false,
+ },
+ {
+ "code flow dev mode has redirect globs regular ok",
+ args{
+ "http://registered.com/callback",
+ mock.NewHasRedirectGlobsWithConfig(t, []string{"http://registered.com/*"}, op.ApplicationTypeUserAgent, nil, true),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow dev mode has redirect globs wildcard ok",
+ args{
+ "http://registered.com/callback",
+ mock.NewHasRedirectGlobsWithConfig(t, []string{"http://registered.com/*"}, op.ApplicationTypeUserAgent, nil, true),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow dev mode has redirect globs double star ok",
+ args{
+ "http://registered.com/callback",
+ mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/*"}, op.ApplicationTypeUserAgent, nil, true),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow dev mode has redirect globs double star ok",
+ args{
+ "http://registered.com/callback",
+ mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/*"}, op.ApplicationTypeUserAgent, nil, true),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow dev mode has redirect globs IPv6 ok",
+ args{
+ "http://[::1]:80/callback",
+ mock.NewHasRedirectGlobsWithConfig(t, []string{"http://\\[::1\\]:80/*"}, op.ApplicationTypeUserAgent, nil, true),
+ oidc.ResponseTypeCode,
+ },
+ false,
+ },
+ {
+ "code flow dev mode has redirect globs bad pattern",
+ args{
+ "http://registered.com/callback",
+ mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/\\"}, op.ApplicationTypeUserAgent, nil, true),
+ oidc.ResponseTypeCode,
+ },
+ true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if err := op.ValidateAuthReqRedirectURI(tt.args.client, tt.args.uri, tt.args.responseType); (err != nil) != tt.wantErr {
+ t.Errorf("ValidateRedirectURI() error = %v, wantErr %v", err.Error(), tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestLoopbackOrLocalhost(t *testing.T) {
+ type args struct {
+ url string
+ }
+ tests := []struct {
+ name string
+ args args
+ want bool
+ }{
+ {
+ "not parsable, false",
+ args{url: string('\n')},
+ false,
+ },
+ {
+ "not http, false",
+ args{url: "localhost/test"},
+ false,
+ },
+ {
+ "not http, false",
+ args{url: "http://localhost.com/test"},
+ false,
+ },
+ {
+ "v4 no port ok",
+ args{url: "http://127.0.0.1/test"},
+ true,
+ },
+ {
+ "v6 short no port ok",
+ args{url: "http://[::1]/test"},
+ true,
+ },
+ {
+ "v6 long no port ok",
+ args{url: "http://[0:0:0:0:0:0:0:1]/test"},
+ true,
+ },
+ {
+ "locahost no port ok",
+ args{url: "http://localhost/test"},
+ true,
+ },
+ {
+ "v4 with port ok",
+ args{url: "http://127.0.0.1:4200/test"},
+ true,
+ },
+ {
+ "v6 short with port ok",
+ args{url: "http://[::1]:4200/test"},
+ true,
+ },
+ {
+ "v6 long with port ok",
+ args{url: "http://[0:0:0:0:0:0:0:1]:4200/test"},
+ true,
+ },
+ {
+ "localhost with port ok",
+ args{url: "http://localhost:4200/test"},
+ true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if _, got := op.HTTPLoopbackOrLocalhost(tt.args.url); got != tt.want {
+ t.Errorf("loopbackOrLocalhost() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestValidateAuthReqResponseType(t *testing.T) {
+ type args struct {
+ responseType oidc.ResponseType
+ client op.Client
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ }{
+ {
+ "empty response type",
+ args{
+ "",
+ mock.NewClientWithConfig(t, nil, op.ApplicationTypeNative, []oidc.ResponseType{oidc.ResponseTypeCode}, true),
+ },
+ true,
+ },
+ {
+ "response type missing in client config",
+ args{
+ oidc.ResponseTypeIDToken,
+ mock.NewClientWithConfig(t, nil, op.ApplicationTypeNative, []oidc.ResponseType{oidc.ResponseTypeCode}, true),
+ },
+ true,
+ },
+ {
+ "valid response type",
+ args{
+ oidc.ResponseTypeCode,
+ mock.NewClientWithConfig(t, nil, op.ApplicationTypeNative, []oidc.ResponseType{oidc.ResponseTypeCode}, true),
+ },
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if err := op.ValidateAuthReqResponseType(tt.args.client, tt.args.responseType); (err != nil) != tt.wantErr {
+ t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestRedirectToLogin(t *testing.T) {
+ type args struct {
+ authReqID string
+ client op.Client
+ w http.ResponseWriter
+ r *http.Request
+ }
+ tests := []struct {
+ name string
+ args args
+ }{
+ {
+ "redirect ok",
+ args{
+ "id",
+ mock.NewClientExpectAny(t, op.ApplicationTypeNative),
+ httptest.NewRecorder(),
+ httptest.NewRequest("GET", "/authorize", nil),
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ op.RedirectToLogin(tt.args.authReqID, tt.args.client, tt.args.w, tt.args.r)
+ rec := tt.args.w.(*httptest.ResponseRecorder)
+ require.Equal(t, http.StatusFound, rec.Code)
+ require.Equal(t, "/login?id=id", rec.Header().Get("location"))
+ })
+ }
+}
+
+func TestAuthResponseURL(t *testing.T) {
+ type args struct {
+ redirectURI string
+ responseType oidc.ResponseType
+ responseMode oidc.ResponseMode
+ response any
+ encoder httphelper.Encoder
+ }
+ type res struct {
+ url string
+ err error
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ "encoding error",
+ args{
+ "uri",
+ oidc.ResponseTypeCode,
+ "",
+ map[string]any{"test": "test"},
+ &mockEncoder{
+ errors.New("error encoding"),
+ },
+ },
+ res{
+ "",
+ oidc.ErrServerError(),
+ },
+ },
+ {
+ "response mode query",
+ args{
+ "uri",
+ oidc.ResponseTypeIDToken,
+ oidc.ResponseModeQuery,
+ map[string][]string{"test": {"test"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri?test=test",
+ nil,
+ },
+ },
+ {
+ "response mode fragment",
+ args{
+ "uri",
+ oidc.ResponseTypeCode,
+ oidc.ResponseModeFragment,
+ map[string][]string{"test": {"test"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri#test=test",
+ nil,
+ },
+ },
+ {
+ "response type code",
+ args{
+ "uri",
+ oidc.ResponseTypeCode,
+ "",
+ map[string][]string{"test": {"test"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri?test=test",
+ nil,
+ },
+ },
+ {
+ "response type id token",
+ args{
+ "uri",
+ oidc.ResponseTypeIDToken,
+ "",
+ map[string][]string{"test": {"test"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri#test=test",
+ nil,
+ },
+ },
+ {
+ "with query",
+ args{
+ "uri?param=value",
+ oidc.ResponseTypeCode,
+ "",
+ map[string][]string{"test": {"test"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri?param=value&test=test",
+ nil,
+ },
+ },
+ {
+ "with query response type id token",
+ args{
+ "uri?param=value",
+ oidc.ResponseTypeIDToken,
+ "",
+ map[string][]string{"test": {"test"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri?param=value#test=test",
+ nil,
+ },
+ },
+ {
+ "with existing query",
+ args{
+ "uri?test=value",
+ oidc.ResponseTypeCode,
+ "",
+ map[string][]string{"test": {"test"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri?test=value&test=test",
+ nil,
+ },
+ },
+ {
+ "with existing query response type id token",
+ args{
+ "uri?test=value",
+ oidc.ResponseTypeIDToken,
+ "",
+ map[string][]string{"test": {"test"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri?test=value#test=test",
+ nil,
+ },
+ },
+ {
+ "with existing query and multiple values",
+ args{
+ "uri?test=value",
+ oidc.ResponseTypeCode,
+ "",
+ map[string][]string{"test": {"test", "test2"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri?test=value&test=test&test=test2",
+ nil,
+ },
+ },
+ {
+ "with existing query and multiple values response type id token",
+ args{
+ "uri?test=value",
+ oidc.ResponseTypeIDToken,
+ "",
+ map[string][]string{"test": {"test", "test2"}},
+ &mockEncoder{},
+ },
+ res{
+ "uri?test=value#test=test&test=test2",
+ nil,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := op.AuthResponseURL(tt.args.redirectURI, tt.args.responseType, tt.args.responseMode, tt.args.response, tt.args.encoder)
+ if tt.res.err == nil && err != nil {
+ t.Errorf("ValidateAuthRequest() unexpected error = %v", err)
+ }
+ if tt.res.err != nil && !errors.Is(err, tt.res.err) {
+ t.Errorf("ValidateAuthRequest() unexpected error = %v, want = %v", err, tt.res.err)
+ }
+ if got != tt.res.url {
+ t.Errorf("AuthResponseURL() got = %v, want %v", got, tt.res.url)
+ }
+ })
+ }
+}
+
+type mockEncoder struct {
+ err error
+}
+
+func (m *mockEncoder) Encode(src any, dst map[string][]string) error {
+ if m.err != nil {
+ return m.err
+ }
+ for s, strings := range src.(map[string][]string) {
+ dst[s] = strings
+ }
+ return nil
+}
+
+// mockCrypto implements the op.Crypto interface
+// and in always equals out. (It doesn't crypt anything).
+// When returnErr != nil, that error is always returned instread.
+type mockCrypto struct {
+ returnErr error
+}
+
+func (c *mockCrypto) Encrypt(s string) (string, error) {
+ if c.returnErr != nil {
+ return "", c.returnErr
+ }
+ return s, nil
+}
+
+func (c *mockCrypto) Decrypt(s string) (string, error) {
+ if c.returnErr != nil {
+ return "", c.returnErr
+ }
+ return s, nil
+}
+
+func TestAuthResponseCode(t *testing.T) {
+ type args struct {
+ authReq op.AuthRequest
+ authorizer func(*testing.T) op.Authorizer
+ }
+ type res struct {
+ wantCode int
+ wantLocationHeader string
+ wantCacheControlHeader string
+ wantBody string
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ name: "create code error",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ TransferState: "state1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{
+ returnErr: io.ErrClosedPipe,
+ })
+ authorizer.EXPECT().Logger().Return(slog.Default())
+ return authorizer
+ },
+ },
+ res: res{
+ wantCode: http.StatusBadRequest,
+ wantBody: "io: read/write on closed pipe\n",
+ },
+ },
+ {
+ name: "success with state",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ TransferState: "state1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantCode: http.StatusFound,
+ wantLocationHeader: "/auth/callback/?code=id1&state=state1",
+ wantBody: "",
+ },
+ },
+ {
+ name: "success with state and session_state",
+ args: args{
+ authReq: &storage.AuthRequestWithSessionState{
+ AuthRequest: &storage.AuthRequest{
+ ID: "id1",
+ TransferState: "state1",
+ },
+ SessionState: "session_state1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantCode: http.StatusFound,
+ wantLocationHeader: "/auth/callback/?code=id1&session_state=session_state1&state=state1",
+ wantBody: "",
+ },
+ },
+ {
+ name: "success without state", // reproduce issue #415
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ TransferState: "",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantCode: http.StatusFound,
+ wantLocationHeader: "/auth/callback/?code=id1",
+ wantBody: "",
+ },
+ },
+ {
+ name: "success form_post",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ CallbackURI: "https://example.com/callback",
+ TransferState: "state1",
+ ResponseMode: "form_post",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantCode: http.StatusOK,
+ wantCacheControlHeader: "no-store",
+ wantBody: "\n\n \n\n\n\n",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := httptest.NewRequest(http.MethodPost, "/auth/callback/", nil)
+ w := httptest.NewRecorder()
+ op.AuthResponseCode(w, r, tt.args.authReq, tt.args.authorizer(t))
+ resp := w.Result()
+ defer resp.Body.Close()
+ assert.Equal(t, tt.res.wantCode, resp.StatusCode)
+ assert.Equal(t, tt.res.wantLocationHeader, resp.Header.Get("Location"))
+ assert.Equal(t, tt.res.wantCacheControlHeader, resp.Header.Get("Cache-Control"))
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Equal(t, tt.res.wantBody, string(body))
+ })
+ }
+}
+
+func Test_parseAuthorizeCallbackRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ url string
+ wantId string
+ wantErr bool
+ }{
+ {
+ name: "parse error",
+ url: "/?id;=99",
+ wantErr: true,
+ },
+ {
+ name: "missing id",
+ url: "/",
+ wantErr: true,
+ },
+ {
+ name: "ok",
+ url: "/?id=99",
+ wantId: "99",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := httptest.NewRequest(http.MethodGet, tt.url, nil)
+ gotId, err := op.ParseAuthorizeCallbackRequest(r)
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.wantId, gotId)
+ })
+ }
+}
+
+func TestBuildAuthResponseCodeResponsePayload(t *testing.T) {
+ type args struct {
+ authReq op.AuthRequest
+ authorizer func(*testing.T) op.Authorizer
+ }
+ type res struct {
+ wantCode string
+ wantState string
+ wantSessionState string
+ wantErr bool
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ name: "create code error",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{
+ returnErr: io.ErrClosedPipe,
+ })
+ return authorizer
+ },
+ },
+ res: res{
+ wantErr: true,
+ },
+ },
+ {
+ name: "success with state",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ TransferState: "state1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ return authorizer
+ },
+ },
+ res: res{
+ wantCode: "id1",
+ wantState: "state1",
+ },
+ },
+ {
+ name: "success without state",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ TransferState: "",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ return authorizer
+ },
+ },
+ res: res{
+ wantCode: "id1",
+ wantState: "",
+ },
+ },
+ {
+ name: "success with session_state",
+ args: args{
+ authReq: &storage.AuthRequestWithSessionState{
+ AuthRequest: &storage.AuthRequest{
+ ID: "id1",
+ TransferState: "state1",
+ },
+ SessionState: "session_state1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ return authorizer
+ },
+ },
+ res: res{
+ wantCode: "id1",
+ wantState: "state1",
+ wantSessionState: "session_state1",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := op.BuildAuthResponseCodeResponsePayload(context.Background(), tt.args.authReq, tt.args.authorizer(t))
+ if tt.res.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, tt.res.wantCode, got.Code)
+ assert.Equal(t, tt.res.wantState, got.State)
+ assert.Equal(t, tt.res.wantSessionState, got.SessionState)
+ })
+ }
+}
+
+func TestValidateAuthReqIDTokenHint(t *testing.T) {
+ token, _ := tu.ValidIDToken()
+ tests := []struct {
+ name string
+ idTokenHint string
+ want string
+ wantErr error
+ }{
+ {
+ name: "empty",
+ },
+ {
+ name: "verify err",
+ idTokenHint: "foo",
+ wantErr: oidc.ErrLoginRequired(),
+ },
+ {
+ name: "ok",
+ idTokenHint: token,
+ want: tu.ValidSubject,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := op.ValidateAuthReqIDTokenHint(context.Background(), tt.idTokenHint, op.NewIDTokenHintVerifier(tu.ValidIssuer, tu.KeySet{}))
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestBuildAuthResponseCallbackURL(t *testing.T) {
+ type args struct {
+ authReq op.AuthRequest
+ authorizer func(*testing.T) op.Authorizer
+ }
+ type res struct {
+ wantURL string
+ wantErr bool
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ name: "error when generating code response",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{
+ returnErr: io.ErrClosedPipe,
+ })
+ return authorizer
+ },
+ },
+ res: res{
+ wantErr: true,
+ },
+ },
+ {
+ name: "error when generating callback URL",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ CallbackURI: "://invalid-url",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantErr: true,
+ },
+ },
+ {
+ name: "success with state",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ CallbackURI: "https://example.com/callback",
+ TransferState: "state1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantURL: "https://example.com/callback?code=id1&state=state1",
+ wantErr: false,
+ },
+ },
+ {
+ name: "success without state",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ CallbackURI: "https://example.com/callback",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantURL: "https://example.com/callback?code=id1",
+ wantErr: false,
+ },
+ },
+ {
+ name: "success with session_state",
+ args: args{
+ authReq: &storage.AuthRequestWithSessionState{
+ AuthRequest: &storage.AuthRequest{
+ ID: "id1",
+ CallbackURI: "https://example.com/callback",
+ TransferState: "state1",
+ },
+ SessionState: "session_state1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantURL: "https://example.com/callback?code=id1&session_state=session_state1&state=state1",
+ wantErr: false,
+ },
+ },
+ {
+ name: "success with existing query parameters",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ CallbackURI: "https://example.com/callback?param=value",
+ TransferState: "state1",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantURL: "https://example.com/callback?param=value&code=id1&state=state1",
+ wantErr: false,
+ },
+ },
+ {
+ name: "success with fragment response mode",
+ args: args{
+ authReq: &storage.AuthRequest{
+ ID: "id1",
+ CallbackURI: "https://example.com/callback",
+ TransferState: "state1",
+ ResponseMode: "fragment",
+ },
+ authorizer: func(t *testing.T) op.Authorizer {
+ ctrl := gomock.NewController(t)
+ storage := mock.NewMockStorage(ctrl)
+ storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
+
+ authorizer := mock.NewMockAuthorizer(ctrl)
+ authorizer.EXPECT().Storage().Return(storage)
+ authorizer.EXPECT().Crypto().Return(&mockCrypto{})
+ authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
+ return authorizer
+ },
+ },
+ res: res{
+ wantURL: "https://example.com/callback#code=id1&state=state1",
+ wantErr: false,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := op.BuildAuthResponseCallbackURL(context.Background(), tt.args.authReq, tt.args.authorizer(t))
+ if tt.res.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+
+ if tt.res.wantURL != "" {
+ // Parse the URLs to compare components instead of direct string comparison
+ expectedURL, err := url.Parse(tt.res.wantURL)
+ require.NoError(t, err)
+ actualURL, err := url.Parse(got)
+ require.NoError(t, err)
+
+ // Compare the base parts (scheme, host, path)
+ assert.Equal(t, expectedURL.Scheme, actualURL.Scheme)
+ assert.Equal(t, expectedURL.Host, actualURL.Host)
+ assert.Equal(t, expectedURL.Path, actualURL.Path)
+
+ // Compare the fragment if any
+ assert.Equal(t, expectedURL.Fragment, actualURL.Fragment)
+
+ // For query parameters, compare them independently of order
+ expectedQuery := expectedURL.Query()
+ actualQuery := actualURL.Query()
+
+ assert.Equal(t, len(expectedQuery), len(actualQuery), "Query parameter count does not match")
+
+ for key, expectedValues := range expectedQuery {
+ actualValues, exists := actualQuery[key]
+ assert.True(t, exists, "Expected query parameter %s not found", key)
+ assert.ElementsMatch(t, expectedValues, actualValues, "Values for parameter %s don't match", key)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go
deleted file mode 100644
index c25f60d..0000000
--- a/pkg/op/authrequest.go
+++ /dev/null
@@ -1,217 +0,0 @@
-package op
-
-import (
- "context"
- "errors"
- "fmt"
- "net/http"
- "strings"
-
- "github.com/gorilla/mux"
- "github.com/gorilla/schema"
-
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/rp"
- "github.com/caos/oidc/pkg/utils"
-)
-
-type Authorizer interface {
- Storage() Storage
- Decoder() *schema.Decoder
- Encoder() *schema.Encoder
- Signer() Signer
- IDTokenVerifier() rp.Verifier
- Crypto() Crypto
- Issuer() string
-}
-
-type ValidationAuthorizer interface {
- Authorizer
- ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, rp.Verifier) (string, error)
-}
-
-func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
- err := r.ParseForm()
- if err != nil {
- AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder())
- return
- }
- authReq := new(oidc.AuthRequest)
- err = authorizer.Decoder().Decode(authReq, r.Form)
- if err != nil {
- AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
- return
- }
- validation := ValidateAuthRequest
- if validater, ok := authorizer.(ValidationAuthorizer); ok {
- validation = validater.ValidateAuthRequest
- }
- userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenVerifier())
- if err != nil {
- AuthRequestError(w, r, authReq, err, authorizer.Encoder())
- return
- }
- req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq, userID)
- if err != nil {
- AuthRequestError(w, r, authReq, err, authorizer.Encoder())
- return
- }
- client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
- if err != nil {
- AuthRequestError(w, r, req, err, authorizer.Encoder())
- return
- }
- RedirectToLogin(req.GetID(), client, w, r)
-}
-
-func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier rp.Verifier) (string, error) {
- if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
- return "", err
- }
- if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
- return "", err
- }
- if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil {
- return "", err
- }
- return ValidateAuthReqIDTokenHint(ctx, authReq.IDTokenHint, verifier)
-}
-
-func ValidateAuthReqScopes(scopes []string) error {
- if len(scopes) == 0 {
- return ErrInvalidRequest("scope missing")
- }
- if !utils.Contains(scopes, oidc.ScopeOpenID) {
- return ErrInvalidRequest("scope openid missing")
- }
- return nil
-}
-
-func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
- if uri == "" {
- return ErrInvalidRequestRedirectURI("redirect_uri must not be empty")
- }
- client, err := storage.GetClientByClientID(ctx, client_id)
- if err != nil {
- return ErrServerError(err.Error())
- }
- if !utils.Contains(client.RedirectURIs(), uri) {
- return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
- }
- if strings.HasPrefix(uri, "https://") {
- return nil
- }
- if responseType == oidc.ResponseTypeCode {
- if strings.HasPrefix(uri, "http://") && IsConfidentialType(client) {
- return nil
- }
- if client.ApplicationType() == ApplicationTypeNative {
- return nil
- }
- return ErrInvalidRequest("redirect_uri not allowed")
- } else {
- if client.ApplicationType() != ApplicationTypeNative {
- return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
- }
- if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) {
- return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
- }
- }
- return nil
-}
-
-func ValidateAuthReqResponseType(responseType oidc.ResponseType) error {
- if responseType == "" {
- return ErrInvalidRequest("response_type empty")
- }
- return nil
-}
-
-func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier rp.Verifier) (string, error) {
- if idTokenHint == "" {
- return "", nil
- }
- claims, err := verifier.Verify(ctx, "", idTokenHint)
- if err != nil {
- return "", ErrInvalidRequest("id_token_hint invalid")
- }
- return claims.Subject, nil
-}
-
-func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) {
- login := client.LoginURL(authReqID)
- http.Redirect(w, r, login, http.StatusFound)
-}
-
-func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
- params := mux.Vars(r)
- id := params["id"]
-
- authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
- if err != nil {
- AuthRequestError(w, r, nil, err, authorizer.Encoder())
- return
- }
- if !authReq.Done() {
- AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder())
- return
- }
- AuthResponse(authReq, authorizer, w, r)
-}
-
-func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
- client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
- if err != nil {
-
- }
- if authReq.GetResponseType() == oidc.ResponseTypeCode {
- AuthResponseCode(w, r, authReq, authorizer)
- return
- }
- AuthResponseToken(w, r, authReq, authorizer, client)
- return
-}
-
-func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
- code, err := CreateAuthRequestCode(r.Context(), authReq, authorizer.Storage(), authorizer.Crypto())
- if err != nil {
- AuthRequestError(w, r, authReq, err, authorizer.Encoder())
- return
- }
- callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
- if authReq.GetState() != "" {
- callback = callback + "&state=" + authReq.GetState()
- }
- http.Redirect(w, r, callback, http.StatusFound)
-}
-
-func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) {
- createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
- resp, err := CreateTokenResponse(r.Context(), authReq, client, authorizer, createAccessToken, "")
- if err != nil {
- AuthRequestError(w, r, authReq, err, authorizer.Encoder())
- return
- }
- params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
- if err != nil {
- AuthRequestError(w, r, authReq, err, authorizer.Encoder())
- return
- }
- callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
- http.Redirect(w, r, callback, http.StatusFound)
-}
-
-func CreateAuthRequestCode(ctx context.Context, authReq AuthRequest, storage Storage, crypto Crypto) (string, error) {
- code, err := BuildAuthRequestCode(authReq, crypto)
- if err != nil {
- return "", err
- }
- if err := storage.SaveAuthCode(ctx, authReq.GetID(), code); err != nil {
- return "", err
- }
- return code, nil
-}
-
-func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
- return crypto.Encrypt(authReq.GetID())
-}
diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go
deleted file mode 100644
index dca72fa..0000000
--- a/pkg/op/authrequest_test.go
+++ /dev/null
@@ -1,299 +0,0 @@
-package op_test
-
-import (
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-
- "github.com/stretchr/testify/require"
-
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/op"
- "github.com/caos/oidc/pkg/op/mock"
- "github.com/caos/oidc/pkg/rp"
-)
-
-func TestAuthorize(t *testing.T) {
- // testCallback := func(t *testing.T, clienID string) callbackHandler {
- // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) {
- // // require.Equal(t, clientID, client.)
- // }
- // }
- // testErr := func(t *testing.T, expected error) errorHandler {
- // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
- // require.Equal(t, expected, err)
- // }
- // }
- type args struct {
- w http.ResponseWriter
- r *http.Request
- authorizer op.Authorizer
- }
- tests := []struct {
- name string
- args args
- }{
- {
- "parsing fails",
- args{
- httptest.NewRecorder(),
- &http.Request{Method: "POST", Body: nil},
- mock.NewAuthorizerExpectValid(t, true),
- // testCallback(t, ""),
- // testErr(t, ErrInvalidRequest("cannot parse form")),
- },
- },
- {
- "decoding fails",
- args{
- httptest.NewRecorder(),
- func() *http.Request {
- r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo"))
- r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- return r
- }(),
- mock.NewAuthorizerExpectValid(t, true),
- // testCallback(t, ""),
- // testErr(t, ErrInvalidRequest("cannot parse auth request")),
- },
- },
- // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer)
- })
- }
-}
-
-func TestValidateAuthRequest(t *testing.T) {
- type args struct {
- authRequest *oidc.AuthRequest
- storage op.Storage
- verifier rp.Verifier
- }
- tests := []struct {
- name string
- args args
- wantErr bool
- }{
- //TODO:
- // {
- // "oauth2 spec"
- // }
- {
- "scope missing fails",
- args{&oidc.AuthRequest{}, nil, nil},
- true,
- },
- {
- "scope openid missing fails",
- args{&oidc.AuthRequest{Scopes: []string{"profile"}}, nil, nil},
- true,
- },
- {
- "response_type missing fails",
- args{&oidc.AuthRequest{Scopes: []string{"openid"}}, nil, nil},
- true,
- },
- {
- "client_id missing fails",
- args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, nil, nil},
- true,
- },
- {
- "redirect_uri missing fails",
- args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, nil, nil},
- true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- _, err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage, tt.args.verifier)
- if (err != nil) != tt.wantErr {
- t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
-func TestValidateAuthReqScopes(t *testing.T) {
- type args struct {
- scopes []string
- }
- tests := []struct {
- name string
- args args
- wantErr bool
- }{
- {
- "scopes missing fails", args{}, true,
- },
- {
- "scope openid missing fails", args{[]string{"email"}}, true,
- },
- {
- "scope ok", args{[]string{"openid"}}, false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr {
- t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
-func TestValidateAuthReqRedirectURI(t *testing.T) {
- type args struct {
- uri string
- clientID string
- responseType oidc.ResponseType
- storage op.OPStorage
- }
- tests := []struct {
- name string
- args args
- wantErr bool
- }{
- {
- "empty fails",
- args{"", "", oidc.ResponseTypeCode, nil},
- true,
- },
- {
- "unregistered fails",
- args{"https://unregistered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
- true,
- },
- {
- "storage error fails",
- args{"https://registered.com/callback", "non_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectInvalidClientID(t)},
- true,
- },
- {
- "code flow registered http not confidential fails",
- args{"http://registered.com/callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
- true,
- },
- {
- "code flow registered http confidential ok",
- args{"http://registered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
- false,
- },
- {
- "code flow registered custom not native fails",
- args{"custom://callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
- true,
- },
- {
- "code flow registered custom native ok",
- args{"http://registered.com/callback", "native_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
- false,
- },
- {
- "implicit flow registered ok",
- args{"https://registered.com/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
- false,
- },
- {
- "implicit flow registered http localhost native ok",
- args{"http://localhost:9999/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
- false,
- },
- {
- "implicit flow registered http localhost user agent fails",
- args{"http://localhost:9999/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
- true,
- },
- {
- "implicit flow http non localhost fails",
- args{"http://registered.com/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
- true,
- },
- {
- "implicit flow custom fails",
- args{"custom://callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
- true,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if err := op.ValidateAuthReqRedirectURI(nil, tt.args.uri, tt.args.clientID, tt.args.responseType, tt.args.storage); (err != nil) != tt.wantErr {
- t.Errorf("ValidateRedirectURI() error = %v, wantErr %v", err.Error(), tt.wantErr)
- }
- })
- }
-}
-
-func TestRedirectToLogin(t *testing.T) {
- type args struct {
- authReqID string
- client op.Client
- w http.ResponseWriter
- r *http.Request
- }
- tests := []struct {
- name string
- args args
- }{
- {
- "redirect ok",
- args{
- "id",
- mock.NewClientExpectAny(t, op.ApplicationTypeNative),
- httptest.NewRecorder(),
- httptest.NewRequest("GET", "/authorize", nil),
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- op.RedirectToLogin(tt.args.authReqID, tt.args.client, tt.args.w, tt.args.r)
- rec := tt.args.w.(*httptest.ResponseRecorder)
- require.Equal(t, http.StatusFound, rec.Code)
- require.Equal(t, "/login?id=id", rec.Header().Get("location"))
- })
- }
-}
-
-func TestAuthorizeCallback(t *testing.T) {
- type args struct {
- w http.ResponseWriter
- r *http.Request
- authorizer op.Authorizer
- }
- tests := []struct {
- name string
- args args
- }{
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer)
- })
- }
-}
-
-func TestAuthResponse(t *testing.T) {
- type args struct {
- authReq op.AuthRequest
- authorizer op.Authorizer
- w http.ResponseWriter
- r *http.Request
- }
- tests := []struct {
- name string
- args args
- }{
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r)
- })
- }
-}
diff --git a/pkg/op/client.go b/pkg/op/client.go
index a61e31d..a4f44d3 100644
--- a/pkg/op/client.go
+++ b/pkg/op/client.go
@@ -1,33 +1,201 @@
package op
-import "time"
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/url"
+ "time"
-const (
- ApplicationTypeWeb ApplicationType = iota
- ApplicationTypeUserAgent
- ApplicationTypeNative
-
- AccessTokenTypeBearer AccessTokenType = iota
- AccessTokenTypeJWT
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
)
-type Client interface {
- GetID() string
- RedirectURIs() []string
- PostLogoutRedirectURIs() []string
- ApplicationType() ApplicationType
- GetAuthMethod() AuthMethod
- LoginURL(string) string
- AccessTokenType() AccessTokenType
- IDTokenLifetime() time.Duration
-}
+//go:generate go get github.com/dmarkham/enumer
+//go:generate go run github.com/dmarkham/enumer -linecomment -sql -json -text -yaml -gqlgen -type=ApplicationType,AccessTokenType
+//go:generate go mod tidy
-func IsConfidentialType(c Client) bool {
- return c.ApplicationType() == ApplicationTypeWeb
-}
+const (
+ ApplicationTypeWeb ApplicationType = iota // web
+ ApplicationTypeUserAgent // user_agent
+ ApplicationTypeNative // native
+)
+
+const (
+ AccessTokenTypeBearer AccessTokenType = iota // bearer
+ AccessTokenTypeJWT // JWT
+)
type ApplicationType int
type AuthMethod string
type AccessTokenType int
+
+type Client interface {
+ GetID() string
+ RedirectURIs() []string
+ PostLogoutRedirectURIs() []string
+ ApplicationType() ApplicationType
+ AuthMethod() oidc.AuthMethod
+ ResponseTypes() []oidc.ResponseType
+ GrantTypes() []oidc.GrantType
+ LoginURL(string) string
+ AccessTokenType() AccessTokenType
+ IDTokenLifetime() time.Duration
+ DevMode() bool
+ RestrictAdditionalIdTokenScopes() func(scopes []string) []string
+ RestrictAdditionalAccessTokenScopes() func(scopes []string) []string
+ IsScopeAllowed(scope string) bool
+ IDTokenUserinfoClaimsAssertion() bool
+ ClockSkew() time.Duration
+}
+
+// HasRedirectGlobs is an optional interface that can be implemented by implementors of
+// Client. See https://pkg.go.dev/path#Match for glob
+// interpretation. Redirect URIs that match either the non-glob version or the
+// glob version will be accepted. Glob URIs are only partially supported for native
+// clients: "http://" is not allowed except for loopback or in dev mode.
+//
+// Note that globbing / wildcards are not permitted by the OIDC
+// standard and implementing this interface can have security implications.
+// It is advised to only return a client of this type in rare cases,
+// such as DevMode for the client being enabled.
+// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
+type HasRedirectGlobs interface {
+ Client
+ RedirectURIGlobs() []string
+ PostLogoutRedirectURIGlobs() []string
+}
+
+func ContainsResponseType(types []oidc.ResponseType, responseType oidc.ResponseType) bool {
+ for _, t := range types {
+ if t == responseType {
+ return true
+ }
+ }
+ return false
+}
+
+func IsConfidentialType(c Client) bool {
+ return c.ApplicationType() == ApplicationTypeWeb
+}
+
+var (
+ ErrInvalidAuthHeader = errors.New("invalid basic auth header")
+ ErrNoClientCredentials = errors.New("no client credentials provided")
+ ErrMissingClientID = errors.New("client_id missing from request")
+)
+
+type ClientJWTProfile interface {
+ JWTProfileVerifier(context.Context) *JWTProfileVerifier
+}
+
+func ClientJWTAuth(ctx context.Context, ca oidc.ClientAssertionParams, verifier ClientJWTProfile) (clientID string, err error) {
+ ctx, span := tracer.Start(ctx, "ClientJWTAuth")
+ defer span.End()
+
+ if ca.ClientAssertion == "" {
+ return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
+ }
+
+ profile, err := VerifyJWTAssertion(ctx, ca.ClientAssertion, verifier.JWTProfileVerifier(ctx))
+ if err != nil {
+ return "", oidc.ErrUnauthorizedClient().WithParent(err).WithDescription("JWT assertion failed")
+ }
+ return profile.Issuer, nil
+}
+
+func ClientBasicAuth(r *http.Request, storage Storage) (clientID string, err error) {
+ ctx, span := tracer.Start(r.Context(), "ClientBasicAuth")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ clientID, clientSecret, ok := r.BasicAuth()
+ if !ok {
+ return "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
+ }
+ clientID, err = url.QueryUnescape(clientID)
+ if err != nil {
+ return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
+ }
+ clientSecret, err = url.QueryUnescape(clientSecret)
+ if err != nil {
+ return "", oidc.ErrInvalidClient().WithParent(ErrInvalidAuthHeader)
+ }
+ if err := storage.AuthorizeClientIDSecret(r.Context(), clientID, clientSecret); err != nil {
+ return "", oidc.ErrUnauthorizedClient().WithParent(err)
+ }
+ return clientID, nil
+}
+
+type ClientProvider interface {
+ Decoder() httphelper.Decoder
+ Storage() Storage
+}
+
+type clientData struct {
+ ClientID string `schema:"client_id"`
+ oidc.ClientAssertionParams
+}
+
+// ClientIDFromRequest parses the request form and tries to obtain the client ID
+// and reports if it is authenticated, using a JWT or static client secrets over
+// http basic auth.
+//
+// If the Provider implements IntrospectorJWTProfile and "client_assertion" is
+// present in the form data, JWT assertion will be verified and the
+// client ID is taken from there.
+// If any of them is absent, basic auth is attempted.
+// In absence of basic auth data, the unauthenticated client id from the form
+// data is returned.
+//
+// If no client id can be obtained by any method, oidc.ErrInvalidClient
+// is returned with ErrMissingClientID wrapped in it.
+func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, authenticated bool, err error) {
+ err = r.ParseForm()
+ if err != nil {
+ return "", false, oidc.ErrInvalidRequest().WithDescription("cannot parse form").WithParent(err)
+ }
+
+ ctx, span := tracer.Start(r.Context(), "ClientIDFromRequest")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ data := new(clientData)
+ if err = p.Decoder().Decode(data, r.Form); err != nil {
+ return "", false, err
+ }
+
+ JWTProfile, ok := p.(ClientJWTProfile)
+ if ok && data.ClientAssertion != "" {
+ // if JWTProfile is supported and client sent an assertion, check it and use it as response
+ // regardless if it succeeded or failed
+ clientID, err = ClientJWTAuth(r.Context(), data.ClientAssertionParams, JWTProfile)
+ return clientID, err == nil, err
+ }
+ // try basic auth
+ clientID, err = ClientBasicAuth(r, p.Storage())
+ // if that succeeded, use it
+ if err == nil {
+ return clientID, true, nil
+ }
+ // if the client did not send a Basic Auth Header, ignore the `ErrNoClientCredentials`
+ // but return other errors immediately
+ if !errors.Is(err, ErrNoClientCredentials) {
+ return "", false, err
+ }
+
+ // if the client did not authenticate (public clients) it must at least send a client_id
+ if data.ClientID == "" {
+ return "", false, oidc.ErrInvalidClient().WithParent(ErrMissingClientID)
+ }
+ return data.ClientID, false, nil
+}
+
+type ClientCredentials struct {
+ ClientID string `schema:"client_id"`
+ ClientSecret string `schema:"client_secret"` // Client secret from Basic auth or request body
+ ClientAssertion string `schema:"client_assertion"` // JWT
+ ClientAssertionType string `schema:"client_assertion_type"`
+}
diff --git a/pkg/op/client_test.go b/pkg/op/client_test.go
new file mode 100644
index 0000000..b416630
--- /dev/null
+++ b/pkg/op/client_test.go
@@ -0,0 +1,253 @@
+package op_test
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op/mock"
+ "github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/zitadel/schema"
+)
+
+type testClientJWTProfile struct{}
+
+func (testClientJWTProfile) JWTProfileVerifier(context.Context) *op.JWTProfileVerifier { return nil }
+
+func TestClientJWTAuth(t *testing.T) {
+ type args struct {
+ ctx context.Context
+ ca oidc.ClientAssertionParams
+ verifier op.ClientJWTProfile
+ }
+ tests := []struct {
+ name string
+ args args
+ wantClientID string
+ wantErr error
+ }{
+ {
+ name: "empty assertion",
+ args: args{
+ context.Background(),
+ oidc.ClientAssertionParams{},
+ testClientJWTProfile{},
+ },
+ wantErr: op.ErrNoClientCredentials,
+ },
+ {
+ name: "verification error",
+ args: args{
+ context.Background(),
+ oidc.ClientAssertionParams{
+ ClientAssertion: "foo",
+ },
+ testClientJWTProfile{},
+ },
+ wantErr: oidc.ErrParse,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotClientID, err := op.ClientJWTAuth(tt.args.ctx, tt.args.ca, tt.args.verifier)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.wantClientID, gotClientID)
+ })
+ }
+}
+
+func TestClientBasicAuth(t *testing.T) {
+ errWrong := errors.New("wrong secret")
+
+ type args struct {
+ username string
+ password string
+ }
+ tests := []struct {
+ name string
+ args *args
+ storage op.Storage
+ wantClientID string
+ wantErr error
+ }{
+ {
+ name: "no args",
+ wantErr: op.ErrNoClientCredentials,
+ },
+ {
+ name: "username unescape err",
+ args: &args{
+ username: "%",
+ password: "bar",
+ },
+ wantErr: op.ErrInvalidAuthHeader,
+ },
+ {
+ name: "password unescape err",
+ args: &args{
+ username: "foo",
+ password: "%",
+ },
+ wantErr: op.ErrInvalidAuthHeader,
+ },
+ {
+ name: "auth error",
+ args: &args{
+ username: "foo",
+ password: "wrong",
+ },
+ storage: func() op.Storage {
+ s := mock.NewMockStorage(gomock.NewController(t))
+ s.EXPECT().AuthorizeClientIDSecret(gomock.Any(), "foo", "wrong").Return(errWrong)
+ return s
+ }(),
+ wantErr: errWrong,
+ },
+ {
+ name: "auth error",
+ args: &args{
+ username: "foo",
+ password: "bar",
+ },
+ storage: func() op.Storage {
+ s := mock.NewMockStorage(gomock.NewController(t))
+ s.EXPECT().AuthorizeClientIDSecret(gomock.Any(), "foo", "bar").Return(nil)
+ return s
+ }(),
+ wantClientID: "foo",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := httptest.NewRequest(http.MethodGet, "/foo", nil)
+ if tt.args != nil {
+ r.SetBasicAuth(tt.args.username, tt.args.password)
+ }
+
+ gotClientID, err := op.ClientBasicAuth(r, tt.storage)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.wantClientID, gotClientID)
+ })
+ }
+}
+
+type errReader struct{}
+
+func (errReader) Read([]byte) (int, error) {
+ return 0, io.ErrNoProgress
+}
+
+type testClientProvider struct {
+ storage op.Storage
+}
+
+func (testClientProvider) Decoder() httphelper.Decoder {
+ return schema.NewDecoder()
+}
+
+func (p testClientProvider) Storage() op.Storage {
+ return p.storage
+}
+
+func TestClientIDFromRequest(t *testing.T) {
+ type args struct {
+ body io.Reader
+ p op.ClientProvider
+ }
+ type basicAuth struct {
+ username string
+ password string
+ }
+ tests := []struct {
+ name string
+ args args
+ basicAuth *basicAuth
+ wantClientID string
+ wantAuthenticated bool
+ wantErr bool
+ }{
+ {
+ name: "parse error",
+ args: args{
+ body: errReader{},
+ },
+ wantErr: true,
+ },
+ {
+ name: "unauthenticated",
+ args: args{
+ body: strings.NewReader(
+ url.Values{
+ "client_id": []string{"foo"},
+ }.Encode(),
+ ),
+ p: testClientProvider{
+ storage: mock.NewStorage(t),
+ },
+ },
+ wantClientID: "foo",
+ wantAuthenticated: false,
+ },
+ {
+ name: "authenticated",
+ args: args{
+ body: strings.NewReader(
+ url.Values{}.Encode(),
+ ),
+ p: testClientProvider{
+ storage: func() op.Storage {
+ s := mock.NewMockStorage(gomock.NewController(t))
+ s.EXPECT().AuthorizeClientIDSecret(gomock.Any(), "foo", "bar").Return(nil)
+ return s
+ }(),
+ },
+ },
+ basicAuth: &basicAuth{
+ username: "foo",
+ password: "bar",
+ },
+ wantClientID: "foo",
+ wantAuthenticated: true,
+ },
+ {
+ name: "missing client id",
+ args: args{
+ body: strings.NewReader(
+ url.Values{}.Encode(),
+ ),
+ p: testClientProvider{
+ storage: mock.NewStorage(t),
+ },
+ },
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ if tt.basicAuth != nil {
+ r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
+ }
+
+ gotClientID, gotAuthenticated, err := op.ClientIDFromRequest(r, tt.args.p)
+ if tt.wantErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.wantClientID, gotClientID)
+ assert.Equal(t, tt.wantAuthenticated, gotAuthenticated)
+ })
+ }
+}
diff --git a/pkg/op/config.go b/pkg/op/config.go
index c52609a..b271765 100644
--- a/pkg/op/config.go
+++ b/pkg/op/config.go
@@ -2,52 +2,183 @@ package op
import (
"errors"
+ "log"
+ "net/http"
"net/url"
- "os"
"strings"
+
+ "github.com/muhlemmer/httpforwarded"
+ "golang.org/x/text/language"
+)
+
+var (
+ ErrInvalidIssuerPath = errors.New("no fragments or query allowed for issuer")
+ ErrInvalidIssuerNoIssuer = errors.New("missing issuer")
+ ErrInvalidIssuerURL = errors.New("invalid url for issuer")
+ ErrInvalidIssuerMissingHost = errors.New("host for issuer missing")
+ ErrInvalidIssuerHTTPS = errors.New("scheme for issuer must be `https`")
)
type Configuration interface {
- Issuer() string
- AuthorizationEndpoint() Endpoint
- TokenEndpoint() Endpoint
- UserinfoEndpoint() Endpoint
- EndSessionEndpoint() Endpoint
- KeysEndpoint() Endpoint
+ IssuerFromRequest(r *http.Request) string
+ Insecure() bool
+ AuthorizationEndpoint() *Endpoint
+ TokenEndpoint() *Endpoint
+ IntrospectionEndpoint() *Endpoint
+ UserinfoEndpoint() *Endpoint
+ RevocationEndpoint() *Endpoint
+ EndSessionEndpoint() *Endpoint
+ KeysEndpoint() *Endpoint
+ DeviceAuthorizationEndpoint() *Endpoint
+ CheckSessionIframe() *Endpoint
AuthMethodPostSupported() bool
+ CodeMethodS256Supported() bool
+ AuthMethodPrivateKeyJWTSupported() bool
+ TokenEndpointSigningAlgorithmsSupported() []string
+ GrantTypeRefreshTokenSupported() bool
+ GrantTypeTokenExchangeSupported() bool
+ GrantTypeJWTAuthorizationSupported() bool
+ GrantTypeClientCredentialsSupported() bool
+ GrantTypeDeviceCodeSupported() bool
+ IntrospectionAuthMethodPrivateKeyJWTSupported() bool
+ IntrospectionEndpointSigningAlgorithmsSupported() []string
+ RevocationAuthMethodPrivateKeyJWTSupported() bool
+ RevocationEndpointSigningAlgorithmsSupported() []string
+ RequestObjectSupported() bool
+ RequestObjectSigningAlgorithmsSupported() []string
+
+ SupportedUILocales() []language.Tag
+ DeviceAuthorization() DeviceAuthorizationConfig
+
+ BackChannelLogoutSupported() bool
+ BackChannelLogoutSessionSupported() bool
}
-func ValidateIssuer(issuer string) error {
+type IssuerFromRequest func(r *http.Request) string
+
+func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) {
+ return issuerFromForwardedOrHost(path, new(issuerConfig))
+}
+
+type IssuerFromOption func(c *issuerConfig)
+
+// WithIssuerFromCustomHeaders can be used to customize the header names used.
+// The same rules apply where the first successful host is returned.
+func WithIssuerFromCustomHeaders(headers ...string) IssuerFromOption {
+ return func(c *issuerConfig) {
+ for i, h := range headers {
+ headers[i] = http.CanonicalHeaderKey(h)
+ }
+ c.headers = headers
+ }
+}
+
+type issuerConfig struct {
+ headers []string
+}
+
+// IssuerFromForwardedOrHost tries to establish the Issuer based
+// on the Forwarded header host field.
+// If multiple Forwarded headers are present, the first mention
+// of the host field will be used.
+// If the Forwarded header is not present, no host field is found,
+// or there is a parser error the Request Host will be used as a fallback.
+// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Forwarded
+func IssuerFromForwardedOrHost(path string, opts ...IssuerFromOption) func(bool) (IssuerFromRequest, error) {
+ c := &issuerConfig{
+ headers: []string{http.CanonicalHeaderKey("forwarded")},
+ }
+ for _, opt := range opts {
+ opt(c)
+ }
+
+ return issuerFromForwardedOrHost(path, c)
+}
+
+func issuerFromForwardedOrHost(path string, c *issuerConfig) func(bool) (IssuerFromRequest, error) {
+ return func(allowInsecure bool) (IssuerFromRequest, error) {
+ issuerPath, err := url.Parse(path)
+ if err != nil {
+ return nil, ErrInvalidIssuerURL
+ }
+ if err := ValidateIssuerPath(issuerPath); err != nil {
+ return nil, err
+ }
+ return func(r *http.Request) string {
+ if host, ok := hostFromForwarded(r, c.headers); ok {
+ return dynamicIssuer(host, path, allowInsecure)
+ }
+ return dynamicIssuer(r.Host, path, allowInsecure)
+ }, nil
+ }
+}
+
+func hostFromForwarded(r *http.Request, headers []string) (host string, ok bool) {
+ for _, header := range headers {
+ hosts, err := httpforwarded.ParseParameter("host", r.Header[header])
+ if err != nil {
+ log.Printf("Err: issuer from forwarded header: %v", err) // TODO change to slog on next branch
+ continue
+ }
+ if len(hosts) > 0 {
+ return hosts[0], true
+ }
+ }
+ return "", false
+}
+
+func StaticIssuer(issuer string) func(bool) (IssuerFromRequest, error) {
+ return func(allowInsecure bool) (IssuerFromRequest, error) {
+ if err := ValidateIssuer(issuer, allowInsecure); err != nil {
+ return nil, err
+ }
+ return func(_ *http.Request) string {
+ return issuer
+ }, nil
+ }
+}
+
+func ValidateIssuer(issuer string, allowInsecure bool) error {
if issuer == "" {
- return errors.New("missing issuer")
+ return ErrInvalidIssuerNoIssuer
}
u, err := url.Parse(issuer)
if err != nil {
- return errors.New("invalid url for issuer")
+ return ErrInvalidIssuerURL
}
if u.Host == "" {
- return errors.New("host for issuer missing")
+ return ErrInvalidIssuerMissingHost
}
if u.Scheme != "https" {
- if !devLocalAllowed(u) {
- return errors.New("scheme for issuer must be `https`")
+ if !devLocalAllowed(u, allowInsecure) {
+ return ErrInvalidIssuerHTTPS
}
}
- if u.Fragment != "" || len(u.Query()) > 0 {
- return errors.New("no fragments or query allowed for issuer")
+ return ValidateIssuerPath(u)
+}
+
+func ValidateIssuerPath(issuer *url.URL) error {
+ if issuer.Fragment != "" || len(issuer.Query()) > 0 {
+ return ErrInvalidIssuerPath
}
return nil
}
-func devLocalAllowed(url *url.URL) bool {
- _, b := os.LookupEnv("CAOS_OIDC_DEV")
- if !b {
- return b
+func devLocalAllowed(url *url.URL, allowInsecure bool) bool {
+ if !allowInsecure {
+ return false
}
- return url.Scheme == "http" &&
- url.Host == "localhost" ||
- url.Host == "127.0.0.1" ||
- url.Host == "::1" ||
- strings.HasPrefix(url.Host, "localhost:")
+ return url.Scheme == "http"
+}
+
+func dynamicIssuer(issuer, path string, allowInsecure bool) string {
+ schema := "https"
+ if allowInsecure {
+ schema = "http"
+ }
+ if len(path) > 0 && !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+ return schema + "://" + issuer + path
}
diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go
index 56cf2eb..d739348 100644
--- a/pkg/op/config_test.go
+++ b/pkg/op/config_test.go
@@ -1,12 +1,19 @@
package op
-import "testing"
+import (
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "testing"
-import "os"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
func TestValidateIssuer(t *testing.T) {
type args struct {
- issuer string
+ issuer string
+ allowInsecure bool
}
tests := []struct {
name string
@@ -15,62 +22,97 @@ func TestValidateIssuer(t *testing.T) {
}{
{
"missing issuer fails",
- args{""},
+ args{
+ issuer: "",
+ },
true,
},
{
"invalid url for issuer fails",
- args{":issuer"},
- true,
- },
- {
- "invalid url for issuer fails",
- args{":issuer"},
+ args{
+ issuer: ":issuer",
+ },
true,
},
{
"host for issuer missing fails",
- args{"https:///issuer"},
- true,
- },
- {
- "host for not https fails",
- args{"http://issuer.com"},
+ args{
+ issuer: "https:///issuer",
+ },
true,
},
{
"host with fragment fails",
- args{"https://issuer.com/#issuer"},
+ args{
+ issuer: "https://issuer.com/#issuer",
+ },
true,
},
{
"host with query fails",
- args{"https://issuer.com?issuer=me"},
+ args{
+ issuer: "https://issuer.com?issuer=me",
+ },
+ true,
+ },
+ {
+ "host with http fails",
+ args{
+ issuer: "http://issuer.com",
+ },
true,
},
{
"host with https ok",
- args{"https://issuer.com"},
+ args{
+ issuer: "https://issuer.com",
+ },
false,
},
{
- "localhost with http ok",
- args{"http://localhost:9999"},
+ "custom scheme fails",
+ args{
+ issuer: "custom://localhost:9999",
+ },
+ true,
+ },
+ {
+ "http with allowInsecure ok",
+ args{
+ issuer: "http://localhost:9999",
+ allowInsecure: true,
+ },
+ false,
+ },
+ {
+ "https with allowInsecure ok",
+ args{
+ issuer: "https://localhost:9999",
+ allowInsecure: true,
+ },
+ false,
+ },
+ {
+ "custom scheme with allowInsecure fails",
+ args{
+ issuer: "custom://localhost:9999",
+ allowInsecure: true,
+ },
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
+ if err := ValidateIssuer(tt.args.issuer, tt.args.allowInsecure); (err != nil) != tt.wantErr {
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
-func TestValidateIssuerDevLocalAllowed(t *testing.T) {
+func TestValidateIssuerPath(t *testing.T) {
type args struct {
- issuer string
+ issuerPath *url.URL
}
tests := []struct {
name string
@@ -78,16 +120,343 @@ func TestValidateIssuerDevLocalAllowed(t *testing.T) {
wantErr bool
}{
{
- "localhost with http ok",
- args{"http://localhost:9999"},
+ "empty ok",
+ args{func() *url.URL {
+ u, _ := url.Parse("")
+ return u
+ }()},
false,
},
+ {
+ "custom ok",
+ args{func() *url.URL {
+ u, _ := url.Parse("/custom")
+ return u
+ }()},
+ false,
+ },
+ {
+ "fragment fails",
+ args{func() *url.URL {
+ u, _ := url.Parse("#fragment")
+ return u
+ }()},
+ true,
+ },
+ {
+ "query fails",
+ args{func() *url.URL {
+ u, _ := url.Parse("?query=value")
+ return u
+ }()},
+ true,
+ },
}
- os.Setenv("CAOS_OIDC_DEV", "")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
- t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
+ if err := ValidateIssuerPath(tt.args.issuerPath); (err != nil) != tt.wantErr {
+ t.Errorf("ValidateIssuerPath() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestIssuerFromHost(t *testing.T) {
+ type args struct {
+ path string
+ allowInsecure bool
+ target string
+ }
+ type res struct {
+ issuer string
+ err error
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ "invalid issuer path",
+ args{
+ path: "/#fragment",
+ allowInsecure: false,
+ },
+ res{
+ issuer: "",
+ err: ErrInvalidIssuerPath,
+ },
+ },
+ {
+ "empty path secure",
+ args{
+ path: "",
+ allowInsecure: false,
+ target: "https://issuer.com",
+ },
+ res{
+ issuer: "https://issuer.com",
+ err: nil,
+ },
+ },
+ {
+ "custom path secure",
+ args{
+ path: "/custom/",
+ allowInsecure: false,
+ target: "https://issuer.com",
+ },
+ res{
+ issuer: "https://issuer.com/custom/",
+ err: nil,
+ },
+ },
+ {
+ "custom path no leading slash",
+ args{
+ path: "custom/",
+ allowInsecure: false,
+ target: "https://issuer.com",
+ },
+ res{
+ issuer: "https://issuer.com/custom/",
+ err: nil,
+ },
+ },
+ {
+ "empty path unsecure",
+ args{
+ path: "",
+ allowInsecure: true,
+ target: "http://issuer.com",
+ },
+ res{
+ issuer: "http://issuer.com",
+ err: nil,
+ },
+ },
+ {
+ "custom path insecure",
+ args{
+ path: "/custom/",
+ allowInsecure: true,
+ target: "http://issuer.com",
+ },
+ res{
+ issuer: "http://issuer.com/custom/",
+ err: nil,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ issuer, err := IssuerFromHost(tt.args.path)(tt.args.allowInsecure)
+ if tt.res.err == nil {
+ assert.NoError(t, err)
+ req := httptest.NewRequest("", tt.args.target, nil)
+ assert.Equal(t, tt.res.issuer, issuer(req))
+ }
+ if tt.res.err != nil {
+ assert.ErrorIs(t, err, tt.res.err)
+ }
+ })
+ }
+}
+
+func TestIssuerFromForwardedOrHost(t *testing.T) {
+ type args struct {
+ path string
+ opts []IssuerFromOption
+ target string
+ header map[string][]string
+ }
+ type res struct {
+ issuer string
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ "header parse error",
+ args{
+ path: "/custom/",
+ target: "https://issuer.com",
+ header: map[string][]string{"Forwarded": {"~~~~"}},
+ },
+ res{
+ issuer: "https://issuer.com/custom/",
+ },
+ },
+ {
+ "no forwarded header",
+ args{
+ path: "/custom/",
+ target: "https://issuer.com",
+ },
+ res{
+ issuer: "https://issuer.com/custom/",
+ },
+ },
+ // by=;for=;host=;proto=
+ {
+ "forwarded header without host",
+ args{
+ path: "/custom/",
+ target: "https://issuer.com",
+ header: map[string][]string{"Forwarded": {
+ `by=identifier;for=identifier;proto=https`,
+ }},
+ },
+ res{
+ issuer: "https://issuer.com/custom/",
+ },
+ },
+ {
+ "forwarded header with host",
+ args{
+ path: "/custom/",
+ target: "https://issuer.com",
+ header: map[string][]string{"Forwarded": {
+ `by=identifier;for=identifier;host=first.com;proto=https`,
+ }},
+ },
+ res{
+ issuer: "https://first.com/custom/",
+ },
+ },
+ {
+ "forwarded header with multiple hosts",
+ args{
+ path: "/custom/",
+ target: "https://issuer.com",
+ header: map[string][]string{"Forwarded": {
+ `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`,
+ }},
+ },
+ res{
+ issuer: "https://first.com/custom/",
+ },
+ },
+ {
+ "multiple forwarded headers hosts",
+ args{
+ path: "/custom/",
+ target: "https://issuer.com",
+ header: map[string][]string{"Forwarded": {
+ `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`,
+ `by=identifier;for=identifier;host=third.com;proto=https`,
+ }},
+ },
+ res{
+ issuer: "https://first.com/custom/",
+ },
+ },
+ {
+ "custom header first",
+ args{
+ path: "/custom/",
+ target: "https://issuer.com",
+ header: map[string][]string{
+ "Forwarded": {
+ `by=identifier;for=identifier;host=first.com;proto=https,host=second.com`,
+ `by=identifier;for=identifier;host=third.com;proto=https`,
+ },
+ "X-Custom-Forwarded": {
+ `by=identifier;for=identifier;host=custom.com;proto=https,host=custom2.com`,
+ },
+ },
+ opts: []IssuerFromOption{
+ WithIssuerFromCustomHeaders("x-custom-forwarded"),
+ },
+ },
+ res{
+ issuer: "https://custom.com/custom/",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ issuer, err := IssuerFromForwardedOrHost(tt.args.path, tt.args.opts...)(false)
+ require.NoError(t, err)
+ req := httptest.NewRequest("", tt.args.target, nil)
+ for k, v := range tt.args.header {
+ req.Header[http.CanonicalHeaderKey(k)] = v
+ }
+ assert.Equal(t, tt.res.issuer, issuer(req))
+ })
+ }
+}
+
+func TestStaticIssuer(t *testing.T) {
+ type args struct {
+ issuer string
+ allowInsecure bool
+ }
+ type res struct {
+ issuer string
+ err error
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ "invalid issuer",
+ args{
+ issuer: "",
+ allowInsecure: false,
+ },
+ res{
+ issuer: "",
+ err: ErrInvalidIssuerNoIssuer,
+ },
+ },
+ {
+ "empty path secure",
+ args{
+ issuer: "https://issuer.com",
+ allowInsecure: false,
+ },
+ res{
+ issuer: "https://issuer.com",
+ err: nil,
+ },
+ },
+ {
+ "custom path secure",
+ args{
+ issuer: "https://issuer.com/custom/",
+ allowInsecure: false,
+ },
+ res{
+ issuer: "https://issuer.com/custom/",
+ err: nil,
+ },
+ },
+ {
+ "unsecure",
+ args{
+ issuer: "http://issuer.com",
+ allowInsecure: true,
+ },
+ res{
+ issuer: "http://issuer.com",
+ err: nil,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ issuer, err := StaticIssuer(tt.args.issuer)(tt.args.allowInsecure)
+ if tt.res.err == nil {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.res.issuer, issuer(nil))
+ }
+ if tt.res.err != nil {
+ assert.ErrorIs(t, err, tt.res.err)
}
})
}
diff --git a/pkg/op/context.go b/pkg/op/context.go
new file mode 100644
index 0000000..7cff5a7
--- /dev/null
+++ b/pkg/op/context.go
@@ -0,0 +1,53 @@
+package op
+
+import (
+ "context"
+ "net/http"
+)
+
+type key int
+
+const (
+ issuerKey key = 0
+)
+
+type IssuerInterceptor struct {
+ issuerFromRequest IssuerFromRequest
+}
+
+// NewIssuerInterceptor will set the issuer into the context
+// by the provided IssuerFromRequest (e.g. returned from StaticIssuer or IssuerFromHost)
+func NewIssuerInterceptor(issuerFromRequest IssuerFromRequest) *IssuerInterceptor {
+ return &IssuerInterceptor{
+ issuerFromRequest: issuerFromRequest,
+ }
+}
+
+func (i *IssuerInterceptor) Handler(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ i.setIssuerCtx(w, r, next)
+ })
+}
+
+func (i *IssuerInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ i.setIssuerCtx(w, r, next)
+ }
+}
+
+// IssuerFromContext reads the issuer from the context (set by an IssuerInterceptor)
+// it will return an empty string if not found
+func IssuerFromContext(ctx context.Context) string {
+ ctxIssuer, _ := ctx.Value(issuerKey).(string)
+ return ctxIssuer
+}
+
+// ContextWithIssuer returns a new context with issuer set to it.
+func ContextWithIssuer(ctx context.Context, issuer string) context.Context {
+ return context.WithValue(ctx, issuerKey, issuer)
+}
+
+func (i *IssuerInterceptor) setIssuerCtx(w http.ResponseWriter, r *http.Request, next http.Handler) {
+ r = r.WithContext(ContextWithIssuer(r.Context(), i.issuerFromRequest(r)))
+ next.ServeHTTP(w, r)
+}
diff --git a/pkg/op/context_test.go b/pkg/op/context_test.go
new file mode 100644
index 0000000..e6bfcec
--- /dev/null
+++ b/pkg/op/context_test.go
@@ -0,0 +1,76 @@
+package op
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestIssuerInterceptor(t *testing.T) {
+ type fields struct {
+ issuerFromRequest IssuerFromRequest
+ }
+ type args struct {
+ r *http.Request
+ next http.Handler
+ }
+ type res struct {
+ issuer string
+ }
+ tests := []struct {
+ name string
+ fields fields
+ args args
+ res res
+ }{
+ {
+ "empty",
+ fields{
+ func(r *http.Request) string {
+ return ""
+ },
+ },
+ args{},
+ res{
+ issuer: "",
+ },
+ },
+ {
+ "static",
+ fields{
+ func(r *http.Request) string {
+ return "static"
+ },
+ },
+ args{},
+ res{
+ issuer: "static",
+ },
+ },
+ {
+ "host",
+ fields{
+ func(r *http.Request) string {
+ return r.Host
+ },
+ },
+ args{},
+ res{
+ issuer: "issuer.com",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ i := NewIssuerInterceptor(tt.fields.issuerFromRequest)
+ next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, tt.res.issuer, IssuerFromContext(r.Context()))
+ })
+ req := httptest.NewRequest("", "https://issuer.com", nil)
+ i.Handler(next).ServeHTTP(nil, req)
+ i.HandlerFunc(next).ServeHTTP(nil, req)
+ })
+ }
+}
diff --git a/pkg/op/crypto.go b/pkg/op/crypto.go
index e95157d..01aaad3 100644
--- a/pkg/op/crypto.go
+++ b/pkg/op/crypto.go
@@ -1,7 +1,7 @@
package op
import (
- "github.com/caos/oidc/pkg/utils"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto"
)
type Crypto interface {
@@ -18,9 +18,9 @@ func NewAESCrypto(key [32]byte) Crypto {
}
func (c *aesCrypto) Encrypt(s string) (string, error) {
- return utils.EncryptAES(s, c.key)
+ return crypto.EncryptAES(s, c.key)
}
func (c *aesCrypto) Decrypt(s string) (string, error) {
- return utils.DecryptAES(s, c.key)
+ return crypto.DecryptAES(s, c.key)
}
diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go
deleted file mode 100644
index a16d4d3..0000000
--- a/pkg/op/default_op.go
+++ /dev/null
@@ -1,345 +0,0 @@
-package op
-
-import (
- "context"
- "errors"
- "net/http"
- "time"
-
- "github.com/gorilla/schema"
- "gopkg.in/square/go-jose.v2"
-
- "github.com/caos/logging"
-
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/rp"
-)
-
-const (
- defaultAuthorizationEndpoint = "authorize"
- defaulTokenEndpoint = "oauth/token"
- defaultIntrospectEndpoint = "introspect"
- defaultUserinfoEndpoint = "userinfo"
- defaultEndSessionEndpoint = "end_session"
- defaultKeysEndpoint = "keys"
-
- AuthMethodBasic AuthMethod = "client_secret_basic"
- AuthMethodPost = "client_secret_post"
- AuthMethodNone = "none"
-)
-
-var (
- DefaultEndpoints = &endpoints{
- Authorization: NewEndpoint(defaultAuthorizationEndpoint),
- Token: NewEndpoint(defaulTokenEndpoint),
- Introspection: NewEndpoint(defaultIntrospectEndpoint),
- Userinfo: NewEndpoint(defaultUserinfoEndpoint),
- EndSession: NewEndpoint(defaultEndSessionEndpoint),
- JwksURI: NewEndpoint(defaultKeysEndpoint),
- }
-)
-
-type DefaultOP struct {
- config *Config
- endpoints *endpoints
- storage Storage
- signer Signer
- verifier rp.Verifier
- crypto Crypto
- http http.Handler
- decoder *schema.Decoder
- encoder *schema.Encoder
- interceptor HttpInterceptor
- retry func(int) (bool, int)
- timer <-chan time.Time
-}
-
-type Config struct {
- Issuer string
- CryptoKey [32]byte
- DefaultLogoutRedirectURI string
- // ScopesSupported: oidc.SupportedScopes,
- // ResponseTypesSupported: responseTypes,
- // GrantTypesSupported: oidc.SupportedGrantTypes,
- // ClaimsSupported: oidc.SupportedClaims,
- // IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm},
- // SubjectTypesSupported: []string{"public"},
- // TokenEndpointAuthMethodsSupported:
-}
-
-type endpoints struct {
- Authorization Endpoint
- Token Endpoint
- Introspection Endpoint
- Userinfo Endpoint
- EndSession Endpoint
- CheckSessionIframe Endpoint
- JwksURI Endpoint
-}
-
-type DefaultOPOpts func(o *DefaultOP) error
-
-func WithCustomAuthEndpoint(endpoint Endpoint) DefaultOPOpts {
- return func(o *DefaultOP) error {
- if err := endpoint.Validate(); err != nil {
- return err
- }
- o.endpoints.Authorization = endpoint
- return nil
- }
-}
-
-func WithCustomTokenEndpoint(endpoint Endpoint) DefaultOPOpts {
- return func(o *DefaultOP) error {
- if err := endpoint.Validate(); err != nil {
- return err
- }
- o.endpoints.Token = endpoint
- return nil
- }
-}
-
-func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
- return func(o *DefaultOP) error {
- if err := endpoint.Validate(); err != nil {
- return err
- }
- o.endpoints.Userinfo = endpoint
- return nil
- }
-}
-
-func WithCustomEndSessionEndpoint(endpoint Endpoint) DefaultOPOpts {
- return func(o *DefaultOP) error {
- if err := endpoint.Validate(); err != nil {
- return err
- }
- o.endpoints.EndSession = endpoint
- return nil
- }
-}
-
-func WithCustomKeysEndpoint(endpoint Endpoint) DefaultOPOpts {
- return func(o *DefaultOP) error {
- if err := endpoint.Validate(); err != nil {
- return err
- }
- o.endpoints.JwksURI = endpoint
- return nil
- }
-}
-
-func WithHttpInterceptor(h HttpInterceptor) DefaultOPOpts {
- return func(o *DefaultOP) error {
- o.interceptor = h
- return nil
- }
-}
-
-func WithRetry(max int, sleep time.Duration) DefaultOPOpts {
- return func(o *DefaultOP) error {
- o.retry = func(count int) (bool, int) {
- count++
- if count == max {
- return false, count
- }
- time.Sleep(sleep)
- return true, count
- }
- return nil
- }
-}
-
-func WithTimer(timer <-chan time.Time) DefaultOPOpts {
- return func(o *DefaultOP) error {
- o.timer = timer
- return nil
- }
-}
-
-func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
- err := ValidateIssuer(config.Issuer)
- if err != nil {
- return nil, err
- }
-
- p := &DefaultOP{
- config: config,
- storage: storage,
- endpoints: DefaultEndpoints,
- timer: make(<-chan time.Time),
- }
-
- for _, optFunc := range opOpts {
- if err := optFunc(p); err != nil {
- return nil, err
- }
- }
-
- keyCh := make(chan jose.SigningKey)
- p.signer = NewDefaultSigner(ctx, storage, keyCh)
- go p.ensureKey(ctx, storage, keyCh, p.timer)
-
- p.verifier = rp.NewDefaultVerifier(config.Issuer, "", p, rp.WithIgnoreAudience(), rp.WithIgnoreExpiration())
-
- p.http = CreateRouter(p, p.interceptor)
-
- p.decoder = schema.NewDecoder()
- p.decoder.IgnoreUnknownKeys(true)
-
- p.encoder = schema.NewEncoder()
-
- p.crypto = NewAESCrypto(config.CryptoKey)
-
- return p, nil
-}
-
-func (p *DefaultOP) Issuer() string {
- return p.config.Issuer
-}
-
-func (p *DefaultOP) AuthorizationEndpoint() Endpoint {
- return p.endpoints.Authorization
-}
-
-func (p *DefaultOP) TokenEndpoint() Endpoint {
- return Endpoint(p.endpoints.Token)
-}
-
-func (p *DefaultOP) UserinfoEndpoint() Endpoint {
- return Endpoint(p.endpoints.Userinfo)
-}
-
-func (p *DefaultOP) EndSessionEndpoint() Endpoint {
- return Endpoint(p.endpoints.EndSession)
-}
-
-func (p *DefaultOP) KeysEndpoint() Endpoint {
- return Endpoint(p.endpoints.JwksURI)
-}
-
-func (p *DefaultOP) AuthMethodPostSupported() bool {
- return true //TODO: config
-}
-
-func (p *DefaultOP) HttpHandler() http.Handler {
- return p.http
-}
-
-func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) {
- Discover(w, CreateDiscoveryConfig(p, p.Signer()))
-}
-
-func (p *DefaultOP) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
- keyID := ""
- for _, sig := range jws.Signatures {
- keyID = sig.Header.KeyID
- break
- }
- keySet, err := p.Storage().GetKeySet(ctx)
- if err != nil {
- return nil, errors.New("error fetching keys")
- }
- payload, err, ok := rp.CheckKey(keyID, keySet.Keys, jws)
- if !ok {
- return nil, errors.New("invalid kid")
- }
- return payload, err
-}
-
-func (p *DefaultOP) Decoder() *schema.Decoder {
- return p.decoder
-}
-
-func (p *DefaultOP) Encoder() *schema.Encoder {
- return p.encoder
-}
-
-func (p *DefaultOP) Storage() Storage {
- return p.storage
-}
-
-func (p *DefaultOP) Signer() Signer {
- return p.signer
-}
-
-func (p *DefaultOP) Crypto() Crypto {
- return p.crypto
-}
-func (p *DefaultOP) HandleReady(w http.ResponseWriter, r *http.Request) {
- probes := []ProbesFn{
- ReadySigner(p.Signer()),
- ReadyStorage(p.Storage()),
- }
- Readiness(w, r, probes...)
-}
-
-func (p *DefaultOP) HandleKeys(w http.ResponseWriter, r *http.Request) {
- Keys(w, r, p)
-}
-
-func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
- Authorize(w, r, p)
-}
-
-func (p *DefaultOP) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) {
- AuthorizeCallback(w, r, p)
-}
-
-func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) {
- reqType := r.FormValue("grant_type")
- if reqType == "" {
- RequestError(w, r, ErrInvalidRequest("grant_type missing"))
- return
- }
- if reqType == string(oidc.GrantTypeCode) {
- CodeExchange(w, r, p)
- return
- }
- TokenExchange(w, r, p)
-}
-
-func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
- Userinfo(w, r, p)
-}
-
-func (p *DefaultOP) HandleEndSession(w http.ResponseWriter, r *http.Request) {
- EndSession(w, r, p)
-}
-
-func (p *DefaultOP) DefaultLogoutRedirectURI() string {
- return p.config.DefaultLogoutRedirectURI
-}
-func (p *DefaultOP) IDTokenVerifier() rp.Verifier {
- return p.verifier
-}
-
-func (p *DefaultOP) ensureKey(ctx context.Context, storage Storage, keyCh chan<- jose.SigningKey, timer <-chan time.Time) {
- count := 0
- timer = time.After(0)
- errCh := make(chan error)
- go storage.GetSigningKey(ctx, keyCh, errCh, timer)
- for {
- select {
- case <-ctx.Done():
- return
- case err := <-errCh:
- if err == nil {
- continue
- }
- _, ok := err.(StorageNotFoundError)
- if ok {
- err := storage.SaveNewKeyPair(ctx)
- if err == nil {
- continue
- }
- }
- ok, count = p.retry(count)
- if ok {
- timer = time.After(0)
- continue
- }
- logging.Log("OP-n6ynVE").WithError(err).Panic("error in key signer")
- }
- }
-}
diff --git a/pkg/op/device.go b/pkg/op/device.go
new file mode 100644
index 0000000..866cbc4
--- /dev/null
+++ b/pkg/op/device.go
@@ -0,0 +1,359 @@
+package op
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "math/big"
+ "net/http"
+ "net/url"
+ "slices"
+ "strings"
+ "time"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type DeviceAuthorizationConfig struct {
+ Lifetime time.Duration
+ PollInterval time.Duration
+
+ // UserFormURL is the complete URL where the user must go to authorize the device.
+ // Deprecated: use UserFormPath instead.
+ UserFormURL string
+
+ // UserFormPath is the path where the user must go to authorize the device.
+ // The hostname for the URL is taken from the request by IssuerFromContext.
+ UserFormPath string
+ UserCode UserCodeConfig
+}
+
+type UserCodeConfig struct {
+ CharSet string
+ CharAmount int
+ DashInterval int
+}
+
+const (
+ CharSetBase20 = "BCDFGHJKLMNPQRSTVWXZ"
+ CharSetDigits = "0123456789"
+)
+
+var (
+ UserCodeBase20 = UserCodeConfig{
+ CharSet: CharSetBase20,
+ CharAmount: 8,
+ DashInterval: 4,
+ }
+ UserCodeDigits = UserCodeConfig{
+ CharSet: CharSetDigits,
+ CharAmount: 9,
+ DashInterval: 3,
+ }
+)
+
+func DeviceAuthorizationHandler(o OpenIDProvider) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if err := DeviceAuthorization(w, r, o); err != nil {
+ RequestError(w, r, err, o.Logger())
+ }
+ }
+}
+
+func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvider) error {
+ ctx, span := tracer.Start(r.Context(), "DeviceAuthorization")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ req, err := ParseDeviceCodeRequest(r, o)
+ if err != nil {
+ return err
+ }
+ response, err := createDeviceAuthorization(r.Context(), req, req.ClientID, o)
+ if err != nil {
+ return err
+ }
+
+ httphelper.MarshalJSON(w, response)
+ return nil
+}
+
+func createDeviceAuthorization(ctx context.Context, req *oidc.DeviceAuthorizationRequest, clientID string, o OpenIDProvider) (*oidc.DeviceAuthorizationResponse, error) {
+ ctx, span := tracer.Start(ctx, "createDeviceAuthorization")
+ defer span.End()
+
+ storage, err := assertDeviceStorage(o.Storage())
+ if err != nil {
+ return nil, err
+ }
+ config := o.DeviceAuthorization()
+
+ deviceCode, _ := NewDeviceCode(RecommendedDeviceCodeBytes)
+ userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
+ if err != nil {
+ return nil, NewStatusError(err, http.StatusInternalServerError)
+ }
+
+ expires := time.Now().Add(config.Lifetime)
+ err = storage.StoreDeviceAuthorization(ctx, clientID, deviceCode, userCode, expires, req.Scopes)
+ if err != nil {
+ return nil, NewStatusError(err, http.StatusInternalServerError)
+ }
+
+ var verification *url.URL
+ if config.UserFormURL != "" {
+ if verification, err = url.Parse(config.UserFormURL); err != nil {
+ err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for device user form")
+ return nil, NewStatusError(err, http.StatusInternalServerError)
+ }
+ } else {
+ if verification, err = url.Parse(IssuerFromContext(ctx)); err != nil {
+ err = oidc.ErrServerError().WithParent(err).WithDescription("invalid URL for issuer")
+ return nil, NewStatusError(err, http.StatusInternalServerError)
+ }
+ verification.Path = config.UserFormPath
+ }
+
+ response := &oidc.DeviceAuthorizationResponse{
+ DeviceCode: deviceCode,
+ UserCode: userCode,
+ VerificationURI: verification.String(),
+ ExpiresIn: int(config.Lifetime / time.Second),
+ Interval: int(config.PollInterval / time.Second),
+ }
+
+ verification.RawQuery = "user_code=" + userCode
+ response.VerificationURIComplete = verification.String()
+ return response, nil
+}
+
+func ParseDeviceCodeRequest(r *http.Request, o OpenIDProvider) (*oidc.DeviceAuthorizationRequest, error) {
+ ctx, span := tracer.Start(r.Context(), "ParseDeviceCodeRequest")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ clientID, _, err := ClientIDFromRequest(r, o)
+ if err != nil {
+ return nil, err
+ }
+ client, err := o.Storage().GetClientByClientID(r.Context(), clientID)
+ if err != nil {
+ return nil, err
+ }
+ if !ValidateGrantType(client, oidc.GrantTypeDeviceCode) {
+ return nil, oidc.ErrUnauthorizedClient().WithDescription("client missing grant type " + string(oidc.GrantTypeCode))
+ }
+
+ req := new(oidc.DeviceAuthorizationRequest)
+ if err := o.Decoder().Decode(req, r.Form); err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("cannot parse device authentication request").WithParent(err)
+ }
+ req.ClientID = clientID
+
+ return req, nil
+}
+
+// 16 bytes gives 128 bit of entropy.
+// results in a 22 character base64 encoded string.
+const RecommendedDeviceCodeBytes = 16
+
+// NewDeviceCode generates a new cryptographically secure device code as a base64 encoded string.
+// The length of the string is nBytes * 4 / 3.
+// An error is never returned.
+//
+// TODO(v4): change return type to string alone.
+func NewDeviceCode(nBytes int) (string, error) {
+ bytes := make([]byte, nBytes)
+ rand.Read(bytes)
+ return base64.RawURLEncoding.EncodeToString(bytes), nil
+}
+
+func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
+ var buf strings.Builder
+ if dashInterval > 0 {
+ buf.Grow(charAmount + charAmount/dashInterval - 1)
+ } else {
+ buf.Grow(charAmount)
+ }
+
+ max := big.NewInt(int64(len(charSet)))
+
+ for i := 0; i < charAmount; i++ {
+ if dashInterval != 0 && i != 0 && i%dashInterval == 0 {
+ buf.WriteByte('-')
+ }
+
+ bi, err := rand.Int(rand.Reader, max)
+ if err != nil {
+ return "", fmt.Errorf("%w getting entropy for user code", err)
+ }
+
+ buf.WriteRune(charSet[int(bi.Int64())])
+ }
+
+ return buf.String(), nil
+}
+
+func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
+ ctx, span := tracer.Start(r.Context(), "DeviceAccessToken")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ if err := deviceAccessToken(w, r, exchanger); err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ }
+}
+
+func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) error {
+ // use a limited context timeout shorter as the default
+ // poll interval of 5 seconds.
+ ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second)
+ defer cancel()
+ r = r.WithContext(ctx)
+
+ clientID, clientAuthenticated, err := ClientIDFromRequest(r, exchanger)
+ if err != nil {
+ return err
+ }
+
+ req, err := ParseDeviceAccessTokenRequest(r, exchanger)
+ if err != nil {
+ return err
+ }
+ tokenRequest, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
+ if err != nil {
+ return err
+ }
+
+ client, err := exchanger.Storage().GetClientByClientID(ctx, clientID)
+ if err != nil {
+ return err
+ }
+ if clientAuthenticated != IsConfidentialType(client) {
+ return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
+ WithDescription("confidential client requires authentication")
+ }
+
+ resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
+ if err != nil {
+ return err
+ }
+
+ httphelper.MarshalJSON(w, resp)
+ return nil
+}
+
+func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.DeviceAccessTokenRequest, error) {
+ req := new(oidc.DeviceAccessTokenRequest)
+ if err := exchanger.Decoder().Decode(req, r.PostForm); err != nil {
+ return nil, err
+ }
+ return req, nil
+}
+
+// DeviceAuthorizationState describes the current state of
+// the device authorization flow.
+// It implements the [IDTokenRequest] interface.
+type DeviceAuthorizationState struct {
+ ClientID string
+ Audience []string
+ Scopes []string
+ Expires time.Time // The time after we consider the authorization request timed-out
+ Done bool // The user authenticated and approved the authorization request
+ Denied bool // The user authenticated and denied the authorization request
+
+ // The following fields are populated after Done == true
+ Subject string
+ AMR []string
+ AuthTime time.Time
+}
+
+func (r *DeviceAuthorizationState) GetAMR() []string {
+ return r.AMR
+}
+
+func (r *DeviceAuthorizationState) GetAudience() []string {
+ if !slices.Contains(r.Audience, r.ClientID) {
+ r.Audience = append(r.Audience, r.ClientID)
+ }
+ return r.Audience
+}
+
+func (r *DeviceAuthorizationState) GetAuthTime() time.Time {
+ return r.AuthTime
+}
+
+func (r *DeviceAuthorizationState) GetClientID() string {
+ return r.ClientID
+}
+
+func (r *DeviceAuthorizationState) GetScopes() []string {
+ return r.Scopes
+}
+
+func (r *DeviceAuthorizationState) GetSubject() string {
+ return r.Subject
+}
+
+func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
+ ctx, span := tracer.Start(ctx, "CheckDeviceAuthorizationState")
+ defer span.End()
+
+ storage, err := assertDeviceStorage(exchanger.Storage())
+ if err != nil {
+ return nil, err
+ }
+
+ state, err := storage.GetDeviceAuthorizatonState(ctx, clientID, deviceCode)
+ if errors.Is(err, context.DeadlineExceeded) {
+ return nil, oidc.ErrSlowDown().WithParent(err)
+ }
+ if err != nil {
+ return nil, oidc.ErrAccessDenied().WithParent(err)
+ }
+ if state.Denied {
+ return state, oidc.ErrAccessDenied()
+ }
+ if state.Done {
+ return state, nil
+ }
+ if time.Now().After(state.Expires) {
+ return state, oidc.ErrExpiredDeviceCode()
+ }
+ return state, oidc.ErrAuthorizationPending()
+}
+
+func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) {
+ /* TODO(v4):
+ Change the TokenRequest argument type to *DeviceAuthorizationState.
+ Breaking change that can not be done for v3.
+ */
+ ctx, span := tracer.Start(ctx, "CreateDeviceTokenResponse")
+ defer span.End()
+
+ accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "")
+ if err != nil {
+ return nil, err
+ }
+
+ response := &oidc.AccessTokenResponse{
+ AccessToken: accessToken,
+ RefreshToken: refreshToken,
+ TokenType: oidc.BearerToken,
+ ExpiresIn: uint64(validity.Seconds()),
+ Scope: tokenRequest.GetScopes(),
+ }
+
+ // TODO(v4): remove type assertion
+ if idTokenRequest, ok := tokenRequest.(IDTokenRequest); ok && slices.Contains(tokenRequest.GetScopes(), oidc.ScopeOpenID) {
+ response.IDToken, err = CreateIDToken(ctx, IssuerFromContext(ctx), idTokenRequest, client.IDTokenLifetime(), accessToken, "", creator.Storage(), client)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return response, nil
+}
diff --git a/pkg/op/device_test.go b/pkg/op/device_test.go
new file mode 100644
index 0000000..a7b5c4e
--- /dev/null
+++ b/pkg/op/device_test.go
@@ -0,0 +1,538 @@
+package op_test
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "io"
+ mr "math/rand"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "github.com/muhlemmer/gu"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func Test_deviceAuthorizationHandler(t *testing.T) {
+ type conf struct {
+ UserFormURL string
+ UserFormPath string
+ }
+ tests := []struct {
+ name string
+ conf conf
+ }{
+ {
+ name: "UserFormURL",
+ conf: conf{
+ UserFormURL: "https://localhost:9998/device",
+ },
+ },
+ {
+ name: "UserFormPath",
+ conf: conf{
+ UserFormPath: "/device",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ conf := gu.PtrCopy(testConfig)
+ conf.DeviceAuthorization.UserFormURL = tt.conf.UserFormURL
+ conf.DeviceAuthorization.UserFormPath = tt.conf.UserFormPath
+ provider := newTestProvider(conf)
+
+ req := &oidc.DeviceAuthorizationRequest{
+ Scopes: []string{"foo", "bar"},
+ ClientID: "device",
+ }
+ values := make(url.Values)
+ testProvider.Encoder().Encode(req, values)
+ body := strings.NewReader(values.Encode())
+
+ r := httptest.NewRequest(http.MethodPost, "/", body)
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ r = r.WithContext(op.ContextWithIssuer(r.Context(), testIssuer))
+
+ w := httptest.NewRecorder()
+
+ runWithRandReader(mr.New(mr.NewSource(1)), func() {
+ op.DeviceAuthorizationHandler(provider)(w, r)
+ })
+
+ result := w.Result()
+
+ assert.Less(t, result.StatusCode, 300)
+
+ got, _ := io.ReadAll(result.Body)
+ assert.JSONEq(t, `{"device_code":"Uv38ByGCZU8WP18PmmIdcg", "expires_in":300, "interval":5, "user_code":"JKRV-FRGK", "verification_uri":"https://localhost:9998/device", "verification_uri_complete":"https://localhost:9998/device?user_code=JKRV-FRGK"}`, string(got))
+ })
+ }
+}
+
+func TestParseDeviceCodeRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ req *oidc.DeviceAuthorizationRequest
+ wantErr bool
+ }{
+ {
+ name: "empty request",
+ wantErr: true,
+ },
+ {
+ name: "missing grant type",
+ req: &oidc.DeviceAuthorizationRequest{
+ Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
+ ClientID: "web",
+ },
+ wantErr: true,
+ },
+ {
+ name: "client not found",
+ req: &oidc.DeviceAuthorizationRequest{
+ Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
+ ClientID: "foobar",
+ },
+ wantErr: true,
+ },
+ {
+ name: "success",
+ req: &oidc.DeviceAuthorizationRequest{
+ Scopes: oidc.SpaceDelimitedArray{"foo", "bar"},
+ ClientID: "device",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var body io.Reader
+ if tt.req != nil {
+ values := make(url.Values)
+ testProvider.Encoder().Encode(tt.req, values)
+ body = strings.NewReader(values.Encode())
+ }
+
+ r := httptest.NewRequest(http.MethodPost, "/", body)
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ got, err := op.ParseDeviceCodeRequest(r, testProvider)
+ if tt.wantErr {
+ require.Error(t, err)
+ return
+ }
+ assert.Equal(t, tt.req, got)
+ })
+ }
+}
+
+func runWithRandReader(r io.Reader, f func()) {
+ originalReader := rand.Reader
+ rand.Reader = r
+ defer func() {
+ rand.Reader = originalReader
+ }()
+
+ f()
+}
+
+func TestNewDeviceCode(t *testing.T) {
+ for i := 1; i <= 32; i++ {
+ got, err := op.NewDeviceCode(i)
+ require.NoError(t, err)
+ assert.Len(t, got, base64.RawURLEncoding.EncodedLen(i))
+ }
+}
+
+func TestNewUserCode(t *testing.T) {
+ type args struct {
+ charset []rune
+ charAmount int
+ dashInterval int
+ }
+ tests := []struct {
+ name string
+ args args
+ reader io.Reader
+ want string
+ wantErr bool
+ }{
+ {
+ name: "reader error",
+ args: args{
+ charset: []rune(op.CharSetBase20),
+ charAmount: 8,
+ dashInterval: 4,
+ },
+ reader: errReader{},
+ wantErr: true,
+ },
+ {
+ name: "base20",
+ args: args{
+ charset: []rune(op.CharSetBase20),
+ charAmount: 8,
+ dashInterval: 4,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ want: "XKCD-HTTD",
+ },
+ {
+ name: "digits",
+ args: args{
+ charset: []rune(op.CharSetDigits),
+ charAmount: 9,
+ dashInterval: 3,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ want: "271-256-225",
+ },
+ {
+ name: "no dashes",
+ args: args{
+ charset: []rune(op.CharSetDigits),
+ charAmount: 9,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ want: "271256225",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ runWithRandReader(tt.reader, func() {
+ got, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
+ if tt.wantErr {
+ require.ErrorIs(t, err, io.ErrNoProgress)
+ } else {
+ require.NoError(t, err)
+ }
+ assert.Equal(t, tt.want, got)
+ })
+
+ })
+ }
+
+ t.Run("crypto/rand", func(t *testing.T) {
+ const testN = 100000
+
+ for _, c := range []op.UserCodeConfig{op.UserCodeBase20, op.UserCodeDigits} {
+ t.Run(c.CharSet, func(t *testing.T) {
+ results := make(map[string]int)
+
+ for i := 0; i < testN; i++ {
+ code, err := op.NewUserCode([]rune(c.CharSet), c.CharAmount, c.DashInterval)
+ require.NoError(t, err)
+ results[code]++
+ }
+
+ t.Log(results)
+
+ var duplicates int
+ for code, count := range results {
+ assert.Less(t, count, 3, code)
+ if count == 2 {
+ duplicates++
+ }
+ }
+
+ })
+ }
+ })
+}
+
+func BenchmarkNewUserCode(b *testing.B) {
+ type args struct {
+ charset []rune
+ charAmount int
+ dashInterval int
+ }
+ tests := []struct {
+ name string
+ args args
+ reader io.Reader
+ }{
+ {
+ name: "math rand, base20",
+ args: args{
+ charset: []rune(op.CharSetBase20),
+ charAmount: 8,
+ dashInterval: 4,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ },
+ {
+ name: "math rand, digits",
+ args: args{
+ charset: []rune(op.CharSetDigits),
+ charAmount: 9,
+ dashInterval: 3,
+ },
+ reader: mr.New(mr.NewSource(1)),
+ },
+ {
+ name: "crypto rand, base20",
+ args: args{
+ charset: []rune(op.CharSetBase20),
+ charAmount: 8,
+ dashInterval: 4,
+ },
+ reader: rand.Reader,
+ },
+ {
+ name: "crypto rand, digits",
+ args: args{
+ charset: []rune(op.CharSetDigits),
+ charAmount: 9,
+ dashInterval: 3,
+ },
+ reader: rand.Reader,
+ },
+ }
+ for _, tt := range tests {
+ runWithRandReader(tt.reader, func() {
+ b.Run(tt.name, func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ _, err := op.NewUserCode(tt.args.charset, tt.args.charAmount, tt.args.dashInterval)
+ require.NoError(b, err)
+ }
+ })
+
+ })
+ }
+}
+
+func TestDeviceAccessToken(t *testing.T) {
+ storage := testProvider.Storage().(*storage.Storage)
+ storage.StoreDeviceAuthorization(context.Background(), "native", "qwerty", "yuiop", time.Now().Add(time.Minute), []string{"foo"})
+ storage.CompleteDeviceAuthorization(context.Background(), "yuiop", "tim")
+
+ values := make(url.Values)
+ values.Set("client_id", "native")
+ values.Set("grant_type", string(oidc.GrantTypeDeviceCode))
+ values.Set("device_code", "qwerty")
+
+ r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(values.Encode()))
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ w := httptest.NewRecorder()
+
+ op.DeviceAccessToken(w, r, testProvider)
+
+ result := w.Result()
+ got, _ := io.ReadAll(result.Body)
+ t.Log(string(got))
+ assert.Less(t, result.StatusCode, 300)
+ assert.NotEmpty(t, string(got))
+}
+
+func TestCheckDeviceAuthorizationState(t *testing.T) {
+ now := time.Now()
+
+ storage := testProvider.Storage().(*storage.Storage)
+ storage.StoreDeviceAuthorization(context.Background(), "native", "pending", "pending", now.Add(time.Minute), []string{"foo"})
+ storage.StoreDeviceAuthorization(context.Background(), "native", "denied", "denied", now.Add(time.Minute), []string{"foo"})
+ storage.StoreDeviceAuthorization(context.Background(), "native", "completed", "completed", now.Add(time.Minute), []string{"foo"})
+ storage.StoreDeviceAuthorization(context.Background(), "native", "expired", "expired", now.Add(-time.Minute), []string{"foo"})
+
+ storage.DenyDeviceAuthorization(context.Background(), "denied")
+ storage.CompleteDeviceAuthorization(context.Background(), "completed", "tim")
+
+ exceededCtx, cancel := context.WithTimeout(context.Background(), -time.Second)
+ defer cancel()
+
+ type args struct {
+ ctx context.Context
+ clientID string
+ deviceCode string
+ }
+ tests := []struct {
+ name string
+ args args
+ want *op.DeviceAuthorizationState
+ wantErr error
+ }{
+ {
+ name: "pending",
+ args: args{
+ ctx: context.Background(),
+ clientID: "native",
+ deviceCode: "pending",
+ },
+ want: &op.DeviceAuthorizationState{
+ ClientID: "native",
+ Scopes: []string{"foo"},
+ Expires: now.Add(time.Minute),
+ },
+ wantErr: oidc.ErrAuthorizationPending(),
+ },
+ {
+ name: "slow down",
+ args: args{
+ ctx: exceededCtx,
+ clientID: "native",
+ deviceCode: "ok",
+ },
+ wantErr: oidc.ErrSlowDown(),
+ },
+ {
+ name: "wrong client",
+ args: args{
+ ctx: context.Background(),
+ clientID: "foo",
+ deviceCode: "ok",
+ },
+ wantErr: oidc.ErrAccessDenied(),
+ },
+ {
+ name: "denied",
+ args: args{
+ ctx: context.Background(),
+ clientID: "native",
+ deviceCode: "denied",
+ },
+ want: &op.DeviceAuthorizationState{
+ ClientID: "native",
+ Scopes: []string{"foo"},
+ Expires: now.Add(time.Minute),
+ Denied: true,
+ },
+ wantErr: oidc.ErrAccessDenied(),
+ },
+ {
+ name: "completed",
+ args: args{
+ ctx: context.Background(),
+ clientID: "native",
+ deviceCode: "completed",
+ },
+ want: &op.DeviceAuthorizationState{
+ ClientID: "native",
+ Scopes: []string{"foo"},
+ Expires: now.Add(time.Minute),
+ Subject: "tim",
+ Done: true,
+ },
+ },
+ {
+ name: "expired",
+ args: args{
+ ctx: context.Background(),
+ clientID: "native",
+ deviceCode: "expired",
+ },
+ want: &op.DeviceAuthorizationState{
+ ClientID: "native",
+ Scopes: []string{"foo"},
+ Expires: now.Add(-time.Minute),
+ },
+ wantErr: oidc.ErrExpiredDeviceCode(),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := op.CheckDeviceAuthorizationState(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, testProvider)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestCreateDeviceTokenResponse(t *testing.T) {
+ tests := []struct {
+ name string
+ tokenRequest op.TokenRequest
+ wantAccessToken bool
+ wantRefreshToken bool
+ wantIDToken bool
+ wantErr bool
+ }{
+ {
+ name: "access token",
+ tokenRequest: &op.DeviceAuthorizationState{
+ ClientID: "client1",
+ Subject: "id1",
+ AMR: []string{"password"},
+ AuthTime: time.Now(),
+ },
+ wantAccessToken: true,
+ },
+ {
+ name: "access and refresh tokens",
+ tokenRequest: &op.DeviceAuthorizationState{
+ ClientID: "client1",
+ Subject: "id1",
+ AMR: []string{"password"},
+ AuthTime: time.Now(),
+ Scopes: []string{oidc.ScopeOfflineAccess},
+ },
+ wantAccessToken: true,
+ wantRefreshToken: true,
+ },
+ {
+ name: "access and id token",
+ tokenRequest: &op.DeviceAuthorizationState{
+ ClientID: "client1",
+ Subject: "id1",
+ AMR: []string{"password"},
+ AuthTime: time.Now(),
+ Scopes: []string{oidc.ScopeOpenID},
+ },
+ wantAccessToken: true,
+ wantIDToken: true,
+ },
+ {
+ name: "access, refresh and id token",
+ tokenRequest: &op.DeviceAuthorizationState{
+ ClientID: "client1",
+ Subject: "id1",
+ AMR: []string{"password"},
+ AuthTime: time.Now(),
+ Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID},
+ },
+ wantAccessToken: true,
+ wantRefreshToken: true,
+ wantIDToken: true,
+ },
+ {
+ name: "id token creation error",
+ tokenRequest: &op.DeviceAuthorizationState{
+ ClientID: "client1",
+ Subject: "foobar",
+ AMR: []string{"password"},
+ AuthTime: time.Now(),
+ Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID},
+ },
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ client, err := testProvider.Storage().GetClientByClientID(context.Background(), "native")
+ require.NoError(t, err)
+
+ got, err := op.CreateDeviceTokenResponse(context.Background(), tt.tokenRequest, testProvider, client)
+ if tt.wantErr {
+ require.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ assert.InDelta(t, 300, got.ExpiresIn, 2)
+ if tt.wantAccessToken {
+ assert.NotEmpty(t, got.AccessToken, "access token")
+ }
+ if tt.wantRefreshToken {
+ assert.NotEmpty(t, got.RefreshToken, "refresh token")
+ }
+ if tt.wantIDToken {
+ assert.NotEmpty(t, got.IDToken, "id token")
+ }
+ })
+ }
+}
diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go
index fd6e0a6..9b3ddb6 100644
--- a/pkg/op/discovery.go
+++ b/pkg/op/discovery.go
@@ -1,119 +1,242 @@
package op
import (
+ "context"
"net/http"
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/utils"
+ jose "github.com/go-jose/go-jose/v4"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
)
-func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
- utils.MarshalJSON(w, config)
+type DiscoverStorage interface {
+ SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error)
}
-func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration {
- return &oidc.DiscoveryConfiguration{
- Issuer: c.Issuer(),
- AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
- TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
- // IntrospectionEndpoint: c.Intro().Absolute(c.Issuer()),
- UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
- EndSessionEndpoint: c.EndSessionEndpoint().Absolute(c.Issuer()),
- // CheckSessionIframe: c.TokenEndpoint().Absolute(c.Issuer())(c.CheckSessionIframe),
- JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
- ScopesSupported: Scopes(c),
- ResponseTypesSupported: ResponseTypes(c),
- GrantTypesSupported: GrantTypes(c),
- ClaimsSupported: SupportedClaims(c),
- IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
- SubjectTypesSupported: SubjectTypes(c),
- TokenEndpointAuthMethodsSupported: AuthMethods(c),
+var DefaultSupportedScopes = []string{
+ oidc.ScopeOpenID,
+ oidc.ScopeProfile,
+ oidc.ScopeEmail,
+ oidc.ScopePhone,
+ oidc.ScopeAddress,
+ oidc.ScopeOfflineAccess,
+}
+
+func discoveryHandler(c Configuration, s DiscoverStorage) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ Discover(w, CreateDiscoveryConfig(r.Context(), c, s))
}
}
-const (
- ScopeOpenID = "openid"
- ScopeProfile = "profile"
- ScopeEmail = "email"
- ScopePhone = "phone"
- ScopeAddress = "address"
-)
+func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
+ httphelper.MarshalJSON(w, config)
+}
-var DefaultSupportedScopes = []string{
- ScopeOpenID,
- ScopeProfile,
- ScopeEmail,
- ScopePhone,
- ScopeAddress,
+func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage DiscoverStorage) *oidc.DiscoveryConfiguration {
+ issuer := IssuerFromContext(ctx)
+ return &oidc.DiscoveryConfiguration{
+ Issuer: issuer,
+ AuthorizationEndpoint: config.AuthorizationEndpoint().Absolute(issuer),
+ TokenEndpoint: config.TokenEndpoint().Absolute(issuer),
+ IntrospectionEndpoint: config.IntrospectionEndpoint().Absolute(issuer),
+ UserinfoEndpoint: config.UserinfoEndpoint().Absolute(issuer),
+ RevocationEndpoint: config.RevocationEndpoint().Absolute(issuer),
+ EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
+ JwksURI: config.KeysEndpoint().Absolute(issuer),
+ DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
+ CheckSessionIframe: config.CheckSessionIframe().Absolute(issuer),
+ ScopesSupported: Scopes(config),
+ ResponseTypesSupported: ResponseTypes(config),
+ GrantTypesSupported: GrantTypes(config),
+ SubjectTypesSupported: SubjectTypes(config),
+ IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
+ RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
+ TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
+ TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
+ IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
+ IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
+ RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
+ RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
+ ClaimsSupported: SupportedClaims(config),
+ CodeChallengeMethodsSupported: CodeChallengeMethods(config),
+ UILocalesSupported: config.SupportedUILocales(),
+ RequestParameterSupported: config.RequestObjectSupported(),
+ BackChannelLogoutSupported: config.BackChannelLogoutSupported(),
+ BackChannelLogoutSessionSupported: config.BackChannelLogoutSessionSupported(),
+ }
+}
+
+func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage DiscoverStorage, endpoints *Endpoints) *oidc.DiscoveryConfiguration {
+ issuer := IssuerFromContext(ctx)
+ return &oidc.DiscoveryConfiguration{
+ Issuer: issuer,
+ AuthorizationEndpoint: endpoints.Authorization.Absolute(issuer),
+ TokenEndpoint: endpoints.Token.Absolute(issuer),
+ IntrospectionEndpoint: endpoints.Introspection.Absolute(issuer),
+ UserinfoEndpoint: endpoints.Userinfo.Absolute(issuer),
+ RevocationEndpoint: endpoints.Revocation.Absolute(issuer),
+ EndSessionEndpoint: endpoints.EndSession.Absolute(issuer),
+ JwksURI: endpoints.JwksURI.Absolute(issuer),
+ DeviceAuthorizationEndpoint: endpoints.DeviceAuthorization.Absolute(issuer),
+ ScopesSupported: Scopes(config),
+ ResponseTypesSupported: ResponseTypes(config),
+ GrantTypesSupported: GrantTypes(config),
+ SubjectTypesSupported: SubjectTypes(config),
+ IDTokenSigningAlgValuesSupported: SigAlgorithms(ctx, storage),
+ RequestObjectSigningAlgValuesSupported: RequestObjectSigAlgorithms(config),
+ TokenEndpointAuthMethodsSupported: AuthMethodsTokenEndpoint(config),
+ TokenEndpointAuthSigningAlgValuesSupported: TokenSigAlgorithms(config),
+ IntrospectionEndpointAuthSigningAlgValuesSupported: IntrospectionSigAlgorithms(config),
+ IntrospectionEndpointAuthMethodsSupported: AuthMethodsIntrospectionEndpoint(config),
+ RevocationEndpointAuthSigningAlgValuesSupported: RevocationSigAlgorithms(config),
+ RevocationEndpointAuthMethodsSupported: AuthMethodsRevocationEndpoint(config),
+ ClaimsSupported: SupportedClaims(config),
+ CodeChallengeMethodsSupported: CodeChallengeMethods(config),
+ UILocalesSupported: config.SupportedUILocales(),
+ RequestParameterSupported: config.RequestObjectSupported(),
+ BackChannelLogoutSupported: config.BackChannelLogoutSupported(),
+ BackChannelLogoutSessionSupported: config.BackChannelLogoutSessionSupported(),
+ }
}
func Scopes(c Configuration) []string {
- return DefaultSupportedScopes //TODO: config
+ provider, ok := c.(*Provider)
+ if ok && provider.config.SupportedScopes != nil {
+ return provider.config.SupportedScopes
+ }
+ return DefaultSupportedScopes
}
func ResponseTypes(c Configuration) []string {
return []string{
- "code",
- "id_token",
- // "code token",
- // "code id_token",
- "id_token token",
- // "code id_token token"
- }
+ string(oidc.ResponseTypeCode),
+ string(oidc.ResponseTypeIDTokenOnly),
+ string(oidc.ResponseTypeIDToken),
+ } // TODO: ok for now, check later if dynamic needed
}
-func GrantTypes(c Configuration) []string {
- return []string{
- "client_credentials",
- "authorization_code",
- // "password",
- "urn:ietf:params:oauth:grant-type:token-exchange",
+func GrantTypes(c Configuration) []oidc.GrantType {
+ grantTypes := []oidc.GrantType{
+ oidc.GrantTypeCode,
+ oidc.GrantTypeImplicit,
}
-}
-
-func SupportedClaims(c Configuration) []string {
- return []string{ //TODO: config
- "sub",
- "aud",
- "exp",
- "iat",
- "iss",
- "auth_time",
- "nonce",
- "acr",
- "amr",
- "c_hash",
- "at_hash",
- "act",
- "scopes",
- "client_id",
- "azp",
- "preferred_username",
- "name",
- "family_name",
- "given_name",
- "locale",
- "email",
- "email_verified",
- "phone_number",
- "phone_number_verified",
+ if c.GrantTypeRefreshTokenSupported() {
+ grantTypes = append(grantTypes, oidc.GrantTypeRefreshToken)
}
-}
-
-func SigAlgorithms(s Signer) []string {
- return []string{string(s.SignatureAlgorithm())}
+ if c.GrantTypeClientCredentialsSupported() {
+ grantTypes = append(grantTypes, oidc.GrantTypeClientCredentials)
+ }
+ if c.GrantTypeTokenExchangeSupported() {
+ grantTypes = append(grantTypes, oidc.GrantTypeTokenExchange)
+ }
+ if c.GrantTypeJWTAuthorizationSupported() {
+ grantTypes = append(grantTypes, oidc.GrantTypeBearer)
+ }
+ if c.GrantTypeDeviceCodeSupported() {
+ grantTypes = append(grantTypes, oidc.GrantTypeDeviceCode)
+ }
+ return grantTypes
}
func SubjectTypes(c Configuration) []string {
- return []string{"public"} //TODO: config
+ return []string{"public"} // TODO: config
}
-func AuthMethods(c Configuration) []string {
- authMethods := []string{
- string(AuthMethodBasic),
+func SigAlgorithms(ctx context.Context, storage DiscoverStorage) []string {
+ ctx, span := tracer.Start(ctx, "SigAlgorithms")
+ defer span.End()
+
+ algorithms, err := storage.SignatureAlgorithms(ctx)
+ if err != nil {
+ return nil
+ }
+ algs := make([]string, len(algorithms))
+ for i, algorithm := range algorithms {
+ algs[i] = string(algorithm)
+ }
+ return algs
+}
+
+func RequestObjectSigAlgorithms(c Configuration) []string {
+ if !c.RequestObjectSupported() {
+ return nil
+ }
+ return c.RequestObjectSigningAlgorithmsSupported()
+}
+
+func AuthMethodsTokenEndpoint(c Configuration) []oidc.AuthMethod {
+ authMethods := []oidc.AuthMethod{
+ oidc.AuthMethodNone,
+ oidc.AuthMethodBasic,
}
if c.AuthMethodPostSupported() {
- authMethods = append(authMethods, string(AuthMethodPost))
+ authMethods = append(authMethods, oidc.AuthMethodPost)
+ }
+ if c.AuthMethodPrivateKeyJWTSupported() {
+ authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
}
return authMethods
}
+
+func TokenSigAlgorithms(c Configuration) []string {
+ if !c.AuthMethodPrivateKeyJWTSupported() {
+ return nil
+ }
+ return c.TokenEndpointSigningAlgorithmsSupported()
+}
+
+func IntrospectionSigAlgorithms(c Configuration) []string {
+ if !c.IntrospectionAuthMethodPrivateKeyJWTSupported() {
+ return nil
+ }
+ return c.IntrospectionEndpointSigningAlgorithmsSupported()
+}
+
+func AuthMethodsIntrospectionEndpoint(c Configuration) []oidc.AuthMethod {
+ authMethods := []oidc.AuthMethod{
+ oidc.AuthMethodBasic,
+ }
+ if c.AuthMethodPrivateKeyJWTSupported() {
+ authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
+ }
+ return authMethods
+}
+
+func RevocationSigAlgorithms(c Configuration) []string {
+ if !c.RevocationAuthMethodPrivateKeyJWTSupported() {
+ return nil
+ }
+ return c.RevocationEndpointSigningAlgorithmsSupported()
+}
+
+func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
+ authMethods := []oidc.AuthMethod{
+ oidc.AuthMethodNone,
+ oidc.AuthMethodBasic,
+ }
+ if c.AuthMethodPostSupported() {
+ authMethods = append(authMethods, oidc.AuthMethodPost)
+ }
+ if c.AuthMethodPrivateKeyJWTSupported() {
+ authMethods = append(authMethods, oidc.AuthMethodPrivateKeyJWT)
+ }
+ return authMethods
+}
+
+func SupportedClaims(c Configuration) []string {
+ provider, ok := c.(*Provider)
+ if ok && provider.config.SupportedClaims != nil {
+ return provider.config.SupportedClaims
+ }
+
+ return DefaultSupportedClaims
+}
+
+func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
+ codeMethods := make([]oidc.CodeChallengeMethod, 0, 1)
+ if c.CodeMethodS256Supported() {
+ codeMethods = append(codeMethods, oidc.CodeChallengeMethodS256)
+ }
+ return codeMethods
+}
diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go
index 39b39bc..63f1b98 100644
--- a/pkg/op/discovery_test.go
+++ b/pkg/op/discovery_test.go
@@ -1,17 +1,19 @@
package op_test
import (
+ "context"
"net/http"
"net/http/httptest"
- "reflect"
"testing"
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/op"
- "github.com/caos/oidc/pkg/op/mock"
+ jose "github.com/go-jose/go-jose/v4"
"github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "gopkg.in/square/go-jose.v2"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op/mock"
)
func TestDiscover(t *testing.T) {
@@ -36,15 +38,19 @@ func TestDiscover(t *testing.T) {
op.Discover(tt.args.w, tt.args.config)
rec := tt.args.w.(*httptest.ResponseRecorder)
require.Equal(t, http.StatusOK, rec.Code)
- require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
+ require.Equal(t,
+ `{"issuer":"https://issuer.com","request_uri_parameter_supported":false}
+`,
+ rec.Body.String())
})
}
}
func TestCreateDiscoveryConfig(t *testing.T) {
type args struct {
- c op.Configuration
- s op.Signer
+ ctx context.Context
+ c op.Configuration
+ s op.DiscoverStorage
}
tests := []struct {
name string
@@ -55,9 +61,8 @@ func TestCreateDiscoveryConfig(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
- }
+ got := op.CreateDiscoveryConfig(tt.args.ctx, tt.args.c, tt.args.s)
+ assert.Equal(t, tt.want, got)
})
}
}
@@ -76,12 +81,16 @@ func Test_scopes(t *testing.T) {
args{},
op.DefaultSupportedScopes,
},
+ {
+ "custom scopes",
+ args{newTestProvider(&op.Config{SupportedScopes: []string{"test1", "test2"}})},
+ []string{"test1", "test2"},
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if got := op.Scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("scopes() = %v, want %v", got, tt.want)
- }
+ got := op.Scopes(tt.args.c)
+ assert.Equal(t, tt.want, got)
})
}
}
@@ -95,13 +104,16 @@ func Test_ResponseTypes(t *testing.T) {
args args
want []string
}{
- // TODO: Add test cases.
+ {
+ "code and implicit flow",
+ args{},
+ []string{"code", "id_token", "id_token token"},
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if got := op.ResponseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("responseTypes() = %v, want %v", got, tt.want)
- }
+ got := op.ResponseTypes(tt.args.c)
+ assert.Equal(t, tt.want, got)
})
}
}
@@ -113,63 +125,53 @@ func Test_GrantTypes(t *testing.T) {
tests := []struct {
name string
args args
- want []string
- }{
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := op.GrantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("grantTypes() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestSupportedClaims(t *testing.T) {
- type args struct {
- c op.Configuration
- }
- tests := []struct {
- name string
- args args
- want []string
- }{
- // TODO: Add test cases.
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := op.SupportedClaims(tt.args.c); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("SupportedClaims() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Test_SigAlgorithms(t *testing.T) {
- m := mock.NewMockSigner(gomock.NewController((t)))
- type args struct {
- s op.Signer
- }
- tests := []struct {
- name string
- args args
- want []string
+ want []oidc.GrantType
}{
{
- "",
- args{func() op.Signer {
- m.EXPECT().SignatureAlgorithm().Return(jose.RS256)
- return m
- }()},
- []string{"RS256"},
+ "code and implicit flow",
+ args{
+ func() op.Configuration {
+ c := mock.NewMockConfiguration(gomock.NewController(t))
+ c.EXPECT().GrantTypeRefreshTokenSupported().Return(false)
+ c.EXPECT().GrantTypeTokenExchangeSupported().Return(false)
+ c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(false)
+ c.EXPECT().GrantTypeClientCredentialsSupported().Return(false)
+ c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
+ return c
+ }(),
+ },
+ []oidc.GrantType{
+ oidc.GrantTypeCode,
+ oidc.GrantTypeImplicit,
+ },
+ },
+ {
+ "code, implicit flow, refresh token, token exchange, jwt profile, client_credentials",
+ args{
+ func() op.Configuration {
+ c := mock.NewMockConfiguration(gomock.NewController(t))
+ c.EXPECT().GrantTypeRefreshTokenSupported().Return(true)
+ c.EXPECT().GrantTypeTokenExchangeSupported().Return(true)
+ c.EXPECT().GrantTypeJWTAuthorizationSupported().Return(true)
+ c.EXPECT().GrantTypeClientCredentialsSupported().Return(true)
+ c.EXPECT().GrantTypeDeviceCodeSupported().Return(false)
+ return c
+ }(),
+ },
+ []oidc.GrantType{
+ oidc.GrantTypeCode,
+ oidc.GrantTypeImplicit,
+ oidc.GrantTypeRefreshToken,
+ oidc.GrantTypeClientCredentials,
+ oidc.GrantTypeTokenExchange,
+ oidc.GrantTypeBearer,
+ },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if got := op.SigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
- }
+ got := op.GrantTypes(tt.args.c)
+ assert.Equal(t, tt.want, got)
})
}
}
@@ -191,15 +193,41 @@ func Test_SubjectTypes(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if got := op.SubjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
- }
+ got := op.SubjectTypes(tt.args.c)
+ assert.Equal(t, tt.want, got)
})
}
}
-func Test_AuthMethods(t *testing.T) {
- m := mock.NewMockConfiguration(gomock.NewController((t)))
+func Test_SigAlgorithms(t *testing.T) {
+ m := mock.NewMockDiscoverStorage(gomock.NewController(t))
+ type args struct {
+ s op.DiscoverStorage
+ }
+ tests := []struct {
+ name string
+ args args
+ want []string
+ }{
+ {
+ "",
+ args{func() op.DiscoverStorage {
+ m.EXPECT().SignatureAlgorithms(gomock.Any()).Return([]jose.SignatureAlgorithm{jose.RS256}, nil)
+ return m
+ }()},
+ []string{"RS256"},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := op.SigAlgorithms(context.Background(), tt.args.s)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_RequestObjectSigAlgorithms(t *testing.T) {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
type args struct {
c op.Configuration
}
@@ -209,27 +237,387 @@ func Test_AuthMethods(t *testing.T) {
want []string
}{
{
- "imlicit basic",
+ "not supported, empty",
args{func() op.Configuration {
- m.EXPECT().AuthMethodPostSupported().Return(false)
+ m.EXPECT().RequestObjectSupported().Return(false)
return m
}()},
- []string{string(op.AuthMethodBasic)},
+ nil,
},
{
- "basic and post",
+ "supported, empty",
args{func() op.Configuration {
- m.EXPECT().AuthMethodPostSupported().Return(true)
+ m.EXPECT().RequestObjectSupported().Return(true)
+ m.EXPECT().RequestObjectSigningAlgorithmsSupported().Return(nil)
return m
}()},
- []string{string(op.AuthMethodBasic), string(op.AuthMethodPost)},
+ nil,
+ },
+ {
+ "supported, list",
+ args{func() op.Configuration {
+ m.EXPECT().RequestObjectSupported().Return(true)
+ m.EXPECT().RequestObjectSigningAlgorithmsSupported().Return([]string{"RS256"})
+ return m
+ }()},
+ []string{"RS256"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if got := op.AuthMethods(tt.args.c); !reflect.DeepEqual(got, tt.want) {
- t.Errorf("authMethods() = %v, want %v", got, tt.want)
- }
+ got := op.RequestObjectSigAlgorithms(tt.args.c)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_AuthMethodsTokenEndpoint(t *testing.T) {
+ type args struct {
+ c op.Configuration
+ }
+ tests := []struct {
+ name string
+ args args
+ want []oidc.AuthMethod
+ }{
+ {
+ "none and basic",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().AuthMethodPostSupported().Return(false)
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
+ return m
+ }()},
+ []oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic},
+ },
+ {
+ "none, basic and post",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().AuthMethodPostSupported().Return(true)
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
+ return m
+ }()},
+ []oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost},
+ },
+ {
+ "none, basic, post and private_key_jwt",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().AuthMethodPostSupported().Return(true)
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
+ return m
+ }()},
+ []oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost, oidc.AuthMethodPrivateKeyJWT},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := op.AuthMethodsTokenEndpoint(tt.args.c)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_TokenSigAlgorithms(t *testing.T) {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ type args struct {
+ c op.Configuration
+ }
+ tests := []struct {
+ name string
+ args args
+ want []string
+ }{
+ {
+ "not supported, empty",
+ args{func() op.Configuration {
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
+ return m
+ }()},
+ nil,
+ },
+ {
+ "supported, empty",
+ args{func() op.Configuration {
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
+ m.EXPECT().TokenEndpointSigningAlgorithmsSupported().Return(nil)
+ return m
+ }()},
+ nil,
+ },
+ {
+ "supported, list",
+ args{func() op.Configuration {
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
+ m.EXPECT().TokenEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
+ return m
+ }()},
+ []string{"RS256"},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := op.TokenSigAlgorithms(tt.args.c)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_IntrospectionSigAlgorithms(t *testing.T) {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ type args struct {
+ c op.Configuration
+ }
+ tests := []struct {
+ name string
+ args args
+ want []string
+ }{
+ {
+ "not supported, empty",
+ args{func() op.Configuration {
+ m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(false)
+ return m
+ }()},
+ nil,
+ },
+ {
+ "supported, empty",
+ args{func() op.Configuration {
+ m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(true)
+ m.EXPECT().IntrospectionEndpointSigningAlgorithmsSupported().Return(nil)
+ return m
+ }()},
+ nil,
+ },
+ {
+ "supported, list",
+ args{func() op.Configuration {
+ m.EXPECT().IntrospectionAuthMethodPrivateKeyJWTSupported().Return(true)
+ m.EXPECT().IntrospectionEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
+ return m
+ }()},
+ []string{"RS256"},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := op.IntrospectionSigAlgorithms(tt.args.c)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_AuthMethodsIntrospectionEndpoint(t *testing.T) {
+ type args struct {
+ c op.Configuration
+ }
+ tests := []struct {
+ name string
+ args args
+ want []oidc.AuthMethod
+ }{
+ {
+ "basic only",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
+ return m
+ }()},
+ []oidc.AuthMethod{oidc.AuthMethodBasic},
+ },
+ {
+ "basic and private_key_jwt",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
+ return m
+ }()},
+ []oidc.AuthMethod{oidc.AuthMethodBasic, oidc.AuthMethodPrivateKeyJWT},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := op.AuthMethodsIntrospectionEndpoint(tt.args.c)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_RevocationSigAlgorithms(t *testing.T) {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ type args struct {
+ c op.Configuration
+ }
+ tests := []struct {
+ name string
+ args args
+ want []string
+ }{
+ {
+ "not supported, empty",
+ args{func() op.Configuration {
+ m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(false)
+ return m
+ }()},
+ nil,
+ },
+ {
+ "supported, empty",
+ args{func() op.Configuration {
+ m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(true)
+ m.EXPECT().RevocationEndpointSigningAlgorithmsSupported().Return(nil)
+ return m
+ }()},
+ nil,
+ },
+ {
+ "supported, list",
+ args{func() op.Configuration {
+ m.EXPECT().RevocationAuthMethodPrivateKeyJWTSupported().Return(true)
+ m.EXPECT().RevocationEndpointSigningAlgorithmsSupported().Return([]string{"RS256"})
+ return m
+ }()},
+ []string{"RS256"},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := op.RevocationSigAlgorithms(tt.args.c)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_AuthMethodsRevocationEndpoint(t *testing.T) {
+ type args struct {
+ c op.Configuration
+ }
+ tests := []struct {
+ name string
+ args args
+ want []oidc.AuthMethod
+ }{
+ {
+ "none and basic",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().AuthMethodPostSupported().Return(false)
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
+ return m
+ }()},
+ []oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic},
+ },
+ {
+ "none, basic and post",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().AuthMethodPostSupported().Return(true)
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(false)
+ return m
+ }()},
+ []oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost},
+ },
+ {
+ "none, basic, post and private_key_jwt",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().AuthMethodPostSupported().Return(true)
+ m.EXPECT().AuthMethodPrivateKeyJWTSupported().Return(true)
+ return m
+ }()},
+ []oidc.AuthMethod{oidc.AuthMethodNone, oidc.AuthMethodBasic, oidc.AuthMethodPost, oidc.AuthMethodPrivateKeyJWT},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := op.AuthMethodsRevocationEndpoint(tt.args.c)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestSupportedClaims(t *testing.T) {
+ type args struct {
+ c op.Configuration
+ }
+ tests := []struct {
+ name string
+ args args
+ want []string
+ }{
+ {
+ "scopes",
+ args{},
+ []string{
+ "sub",
+ "aud",
+ "exp",
+ "iat",
+ "iss",
+ "auth_time",
+ "nonce",
+ "acr",
+ "amr",
+ "c_hash",
+ "at_hash",
+ "act",
+ "scopes",
+ "client_id",
+ "azp",
+ "preferred_username",
+ "name",
+ "family_name",
+ "given_name",
+ "locale",
+ "email",
+ "email_verified",
+ "phone_number",
+ "phone_number_verified",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := op.SupportedClaims(tt.args.c)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_CodeChallengeMethods(t *testing.T) {
+ type args struct {
+ c op.Configuration
+ }
+ tests := []struct {
+ name string
+ args args
+ want []oidc.CodeChallengeMethod
+ }{
+ {
+ "not supported",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().CodeMethodS256Supported().Return(false)
+ return m
+ }()},
+ []oidc.CodeChallengeMethod{},
+ },
+ {
+ "S256",
+ args{func() op.Configuration {
+ m := mock.NewMockConfiguration(gomock.NewController(t))
+ m.EXPECT().CodeMethodS256Supported().Return(true)
+ return m
+ }()},
+ []oidc.CodeChallengeMethod{oidc.CodeChallengeMethodS256},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := op.CodeChallengeMethods(tt.args.c)
+ assert.Equal(t, tt.want, got)
})
}
}
diff --git a/pkg/op/endpoint.go b/pkg/op/endpoint.go
index 21907f4..1ac1cad 100644
--- a/pkg/op/endpoint.go
+++ b/pkg/op/endpoint.go
@@ -1,33 +1,47 @@
package op
-import "strings"
+import (
+ "errors"
+ "strings"
+)
type Endpoint struct {
path string
url string
}
-func NewEndpoint(path string) Endpoint {
- return Endpoint{path: path}
+func NewEndpoint(path string) *Endpoint {
+ return &Endpoint{path: path}
}
-func NewEndpointWithURL(path, url string) Endpoint {
- return Endpoint{path: path, url: url}
+func NewEndpointWithURL(path, url string) *Endpoint {
+ return &Endpoint{path: path, url: url}
}
-func (e Endpoint) Relative() string {
+func (e *Endpoint) Relative() string {
+ if e == nil {
+ return ""
+ }
return relativeEndpoint(e.path)
}
-func (e Endpoint) Absolute(host string) string {
+func (e *Endpoint) Absolute(host string) string {
+ if e == nil {
+ return ""
+ }
if e.url != "" {
return e.url
}
return absoluteEndpoint(host, e.path)
}
-func (e Endpoint) Validate() error {
- return nil //TODO:
+var ErrNilEndpoint = errors.New("nil endpoint")
+
+func (e *Endpoint) Validate() error {
+ if e == nil {
+ return ErrNilEndpoint
+ }
+ return nil // TODO:
}
func absoluteEndpoint(host, endpoint string) string {
diff --git a/pkg/op/endpoint_test.go b/pkg/op/endpoint_test.go
index fe00326..5b98c6e 100644
--- a/pkg/op/endpoint_test.go
+++ b/pkg/op/endpoint_test.go
@@ -3,13 +3,14 @@ package op_test
import (
"testing"
- "github.com/caos/oidc/pkg/op"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "github.com/stretchr/testify/require"
)
func TestEndpoint_Path(t *testing.T) {
tests := []struct {
name string
- e op.Endpoint
+ e *op.Endpoint
want string
}{
{
@@ -27,6 +28,11 @@ func TestEndpoint_Path(t *testing.T) {
op.NewEndpointWithURL("/test", "http://test.com/test"),
"/test",
},
+ {
+ "nil",
+ nil,
+ "",
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -43,7 +49,7 @@ func TestEndpoint_Absolute(t *testing.T) {
}
tests := []struct {
name string
- e op.Endpoint
+ e *op.Endpoint
args args
want string
}{
@@ -77,6 +83,12 @@ func TestEndpoint_Absolute(t *testing.T) {
args{"https://host"},
"https://test.com/test",
},
+ {
+ "nil",
+ nil,
+ args{"https://host"},
+ "",
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -87,20 +99,23 @@ func TestEndpoint_Absolute(t *testing.T) {
}
}
-//TODO: impl test
+// TODO: impl test
func TestEndpoint_Validate(t *testing.T) {
tests := []struct {
name string
- e op.Endpoint
- wantErr bool
+ e *op.Endpoint
+ wantErr error
}{
- // TODO: Add test cases.
+ {
+ "nil",
+ nil,
+ op.ErrNilEndpoint,
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if err := tt.e.Validate(); (err != nil) != tt.wantErr {
- t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
- }
+ err := tt.e.Validate()
+ require.ErrorIs(t, err, tt.wantErr)
})
}
}
diff --git a/pkg/op/error.go b/pkg/op/error.go
index f3c5857..272f85e 100644
--- a/pkg/op/error.go
+++ b/pkg/op/error.go
@@ -1,99 +1,197 @@
package op
import (
+ "context"
+ "errors"
"fmt"
+ "log/slog"
"net/http"
- "github.com/gorilla/schema"
-
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/utils"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
)
-const (
- InvalidRequest errorType = "invalid_request"
- ServerError errorType = "server_error"
-)
-
-var (
- ErrInvalidRequest = func(description string) *OAuthError {
- return &OAuthError{
- ErrorType: InvalidRequest,
- Description: description,
- }
- }
- ErrInvalidRequestRedirectURI = func(description string) *OAuthError {
- return &OAuthError{
- ErrorType: InvalidRequest,
- Description: description,
- redirectDisabled: true,
- }
- }
- ErrServerError = func(description string) *OAuthError {
- return &OAuthError{
- ErrorType: ServerError,
- Description: description,
- }
- }
-)
-
-type errorType string
-
type ErrAuthRequest interface {
GetRedirectURI() string
GetResponseType() oidc.ResponseType
GetState() string
}
-func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
+// LogAuthRequest is an optional interface,
+// that allows logging AuthRequest fields.
+// If the AuthRequest does not implement this interface,
+// no details shall be printed to the logs.
+type LogAuthRequest interface {
+ ErrAuthRequest
+ slog.LogValuer
+}
+
+func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) {
+ e := oidc.DefaultToServerError(err, err.Error())
+ logger := authorizer.Logger().With("oidc_error", e)
+
if authReq == nil {
+ logger.Log(r.Context(), e.LogLevel(), "auth request")
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
- e, ok := err.(*OAuthError)
- if !ok {
- e = new(OAuthError)
- e.ErrorType = ServerError
- e.Description = err.Error()
+
+ if logAuthReq, ok := authReq.(LogAuthRequest); ok {
+ logger = logger.With("auth_request", logAuthReq)
}
- e.state = authReq.GetState()
- if authReq.GetRedirectURI() == "" || e.redirectDisabled {
+
+ if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
+ logger.Log(r.Context(), e.LogLevel(), "auth request: not redirecting")
http.Error(w, e.Description, http.StatusBadRequest)
return
}
- params, err := utils.URLEncodeResponse(e, encoder)
+ e.State = authReq.GetState()
+ var sessionState string
+ authRequestSessionState, ok := authReq.(AuthRequestSessionState)
+ if ok {
+ sessionState = authRequestSessionState.GetSessionState()
+ }
+ e.SessionState = sessionState
+ var responseMode oidc.ResponseMode
+ if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
+ responseMode = rm.GetResponseMode()
+ }
+ url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, authorizer.Encoder())
if err != nil {
+ logger.ErrorContext(r.Context(), "auth response URL", "error", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
- url := authReq.GetRedirectURI()
- responseType := authReq.GetResponseType()
- if responseType == "" || responseType == oidc.ResponseTypeCode {
- url += "?" + params
- } else {
- url += "#" + params
- }
+ logger.Log(r.Context(), e.LogLevel(), "auth request")
http.Redirect(w, r, url, http.StatusFound)
}
-func RequestError(w http.ResponseWriter, r *http.Request, err error) {
- e, ok := err.(*OAuthError)
- if !ok {
- e = new(OAuthError)
- e.ErrorType = ServerError
- e.Description = err.Error()
+func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
+ e := oidc.DefaultToServerError(err, err.Error())
+ status := http.StatusBadRequest
+ if e.ErrorType == oidc.InvalidClient {
+ status = http.StatusUnauthorized
}
- w.WriteHeader(http.StatusBadRequest)
- utils.MarshalJSON(w, e)
+ logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
+ httphelper.MarshalJSONWithStatus(w, e, status)
}
-type OAuthError struct {
- ErrorType errorType `json:"error" schema:"error"`
- Description string `json:"error_description" schema:"error_description"`
- state string `json:"state" schema:"state"`
- redirectDisabled bool
+// TryErrorRedirect tries to handle an error by redirecting a client.
+// If this attempt fails, an error is returned that must be returned
+// to the client instead.
+func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
+ e := oidc.DefaultToServerError(parent, parent.Error())
+ logger = logger.With("oidc_error", e)
+
+ if authReq == nil {
+ logger.Log(ctx, e.LogLevel(), "auth request")
+ return nil, AsStatusError(e, http.StatusBadRequest)
+ }
+
+ if logAuthReq, ok := authReq.(LogAuthRequest); ok {
+ logger = logger.With("auth_request", logAuthReq)
+ }
+
+ if authReq.GetRedirectURI() == "" || e.IsRedirectDisabled() {
+ logger.Log(ctx, e.LogLevel(), "auth request: not redirecting")
+ return nil, AsStatusError(e, http.StatusBadRequest)
+ }
+
+ e.State = authReq.GetState()
+ var sessionState string
+ authRequestSessionState, ok := authReq.(AuthRequestSessionState)
+ if ok {
+ sessionState = authRequestSessionState.GetSessionState()
+ }
+ e.SessionState = sessionState
+ var responseMode oidc.ResponseMode
+ if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
+ responseMode = rm.GetResponseMode()
+ }
+ url, err := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), responseMode, e, encoder)
+ if err != nil {
+ logger.ErrorContext(ctx, "auth response URL", "error", err)
+ return nil, AsStatusError(err, http.StatusBadRequest)
+ }
+ logger.Log(ctx, e.LogLevel(), "auth request redirect", "url", url)
+ return NewRedirect(url), nil
}
-func (e *OAuthError) Error() string {
- return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
+// StatusError wraps an error with a HTTP status code.
+// The status code is passed to the handler's writer.
+type StatusError struct {
+ parent error
+ statusCode int
+}
+
+// NewStatusError sets the parent and statusCode to a new StatusError.
+// It is recommended for parent to be an [oidc.Error].
+//
+// Typically implementations should only use this to signal something
+// very specific, like an internal server error.
+// If a returned error is not a StatusError, the framework
+// will set a statusCode based on what the standard specifies,
+// which is [http.StatusBadRequest] for most of the time.
+// If the error encountered can described clearly with a [oidc.Error],
+// do not use this function, as it might break standard rules!
+func NewStatusError(parent error, statusCode int) StatusError {
+ return StatusError{
+ parent: parent,
+ statusCode: statusCode,
+ }
+}
+
+// AsStatusError unwraps a StatusError from err
+// and returns it unmodified if found.
+// If no StatuError was found, a new one is returned
+// with statusCode set to it as a default.
+func AsStatusError(err error, statusCode int) (target StatusError) {
+ if errors.As(err, &target) {
+ return target
+ }
+ return NewStatusError(err, statusCode)
+}
+
+func (e StatusError) Error() string {
+ return fmt.Sprintf("%s: %s", http.StatusText(e.statusCode), e.parent.Error())
+}
+
+func (e StatusError) Unwrap() error {
+ return e.parent
+}
+
+func (e StatusError) Is(err error) bool {
+ var target StatusError
+ if !errors.As(err, &target) {
+ return false
+ }
+ return errors.Is(e.parent, target.parent) &&
+ e.statusCode == target.statusCode
+}
+
+// WriteError asserts for a [StatusError] containing an [oidc.Error].
+// If no `StatusError` is found, the status code will default to [http.StatusBadRequest].
+// If no `oidc.Error` was found in the parent, the error type defaults to [oidc.ServerError].
+// When there was no `StatusError` and the `oidc.Error` is of type `oidc.ServerError`,
+// the status code will be set to [http.StatusInternalServerError]
+func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
+ var statusError StatusError
+ if errors.As(err, &statusError) {
+ writeError(w, r,
+ oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()),
+ statusError.statusCode, logger,
+ )
+ return
+ }
+ statusCode := http.StatusBadRequest
+ e := oidc.DefaultToServerError(err, err.Error())
+ if e.ErrorType == oidc.ServerError {
+ statusCode = http.StatusInternalServerError
+ }
+ writeError(w, r, e, statusCode, logger)
+}
+
+func writeError(w http.ResponseWriter, r *http.Request, err *oidc.Error, statusCode int, logger *slog.Logger) {
+ logger.Log(r.Context(), err.LogLevel(), "request error", "oidc_error", err, "status_code", statusCode)
+ httphelper.MarshalJSONWithStatus(w, err, statusCode)
}
diff --git a/pkg/op/error_test.go b/pkg/op/error_test.go
new file mode 100644
index 0000000..9271cf1
--- /dev/null
+++ b/pkg/op/error_test.go
@@ -0,0 +1,682 @@
+package op
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/zitadel/schema"
+)
+
+func TestAuthRequestError(t *testing.T) {
+ type args struct {
+ authReq ErrAuthRequest
+ err error
+ }
+ tests := []struct {
+ name string
+ args args
+ wantCode int
+ wantHeaders map[string]string
+ wantBody string
+ wantLog string
+ }{
+ {
+ name: "nil auth request",
+ args: args{
+ authReq: nil,
+ err: io.ErrClosedPipe,
+ },
+ wantCode: http.StatusBadRequest,
+ wantBody: "io: read/write on closed pipe\n",
+ wantLog: `{
+ "level":"ERROR",
+ "msg":"auth request",
+ "time":"not",
+ "oidc_error":{
+ "description":"io: read/write on closed pipe",
+ "parent":"io: read/write on closed pipe",
+ "type":"server_error"
+ }
+ }`,
+ },
+ {
+ name: "auth request, no redirect URI",
+ args: args{
+ authReq: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"a", "b"},
+ ResponseType: "responseType",
+ ClientID: "123",
+ State: "state1",
+ ResponseMode: oidc.ResponseModeQuery,
+ },
+ err: oidc.ErrInteractionRequired().WithDescription("sign in"),
+ },
+ wantCode: http.StatusBadRequest,
+ wantBody: "sign in\n",
+ wantLog: `{
+ "level":"WARN",
+ "msg":"auth request: not redirecting",
+ "time":"not",
+ "auth_request":{
+ "client_id":"123",
+ "redirect_uri":"",
+ "response_type":"responseType",
+ "scopes":"a b"
+ },
+ "oidc_error":{
+ "description":"sign in",
+ "type":"interaction_required"
+ }
+ }`,
+ },
+ {
+ name: "auth request, redirect disabled",
+ args: args{
+ authReq: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"a", "b"},
+ ResponseType: "responseType",
+ ClientID: "123",
+ RedirectURI: "http://example.com/callback",
+ State: "state1",
+ ResponseMode: oidc.ResponseModeQuery,
+ },
+ err: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
+ },
+ wantCode: http.StatusBadRequest,
+ wantBody: "oops\n",
+ wantLog: `{
+ "level":"WARN",
+ "msg":"auth request: not redirecting",
+ "time":"not",
+ "auth_request":{
+ "client_id":"123",
+ "redirect_uri":"http://example.com/callback",
+ "response_type":"responseType",
+ "scopes":"a b"
+ },
+ "oidc_error":{
+ "description":"oops",
+ "type":"invalid_request",
+ "redirect_disabled":true
+ }
+ }`,
+ },
+ {
+ name: "auth request, url parse error",
+ args: args{
+ authReq: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"a", "b"},
+ ResponseType: "responseType",
+ ClientID: "123",
+ RedirectURI: "can't parse this!\n",
+ State: "state1",
+ ResponseMode: oidc.ResponseModeQuery,
+ },
+ err: oidc.ErrInteractionRequired().WithDescription("sign in"),
+ },
+ wantCode: http.StatusBadRequest,
+ wantBody: "ErrorType=server_error Parent=parse \"can't parse this!\\n\": net/url: invalid control character in URL\n",
+ wantLog: `{
+ "level":"ERROR",
+ "msg":"auth response URL",
+ "time":"not",
+ "auth_request":{
+ "client_id":"123",
+ "redirect_uri":"can't parse this!\n",
+ "response_type":"responseType",
+ "scopes":"a b"
+ },
+ "error":{
+ "type":"server_error",
+ "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
+ },
+ "oidc_error":{
+ "description":"sign in",
+ "type":"interaction_required"
+ }
+ }`,
+ },
+ {
+ name: "auth request redirect",
+ args: args{
+ authReq: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"a", "b"},
+ ResponseType: "responseType",
+ ClientID: "123",
+ RedirectURI: "http://example.com/callback",
+ State: "state1",
+ ResponseMode: oidc.ResponseModeQuery,
+ },
+ err: oidc.ErrInteractionRequired().WithDescription("sign in"),
+ },
+ wantCode: http.StatusFound,
+ wantHeaders: map[string]string{"Location": "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"},
+ wantLog: `{
+ "level":"WARN",
+ "msg":"auth request",
+ "time":"not",
+ "auth_request":{
+ "client_id":"123",
+ "redirect_uri":"http://example.com/callback",
+ "response_type":"responseType",
+ "scopes":"a b"
+ },
+ "oidc_error":{
+ "description":"sign in",
+ "type":"interaction_required"
+ }
+ }`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ logOut := new(strings.Builder)
+ authorizer := &Provider{
+ encoder: schema.NewEncoder(),
+ logger: slog.New(
+ slog.NewJSONHandler(logOut, &slog.HandlerOptions{
+ Level: slog.LevelInfo,
+ }).WithAttrs([]slog.Attr{slog.String("time", "not")}),
+ ),
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest("POST", "/path", nil)
+ AuthRequestError(w, r, tt.args.authReq, tt.args.err, authorizer)
+
+ res := w.Result()
+ defer res.Body.Close()
+
+ assert.Equal(t, tt.wantCode, res.StatusCode)
+ for key, wantHeader := range tt.wantHeaders {
+ gotHeader := res.Header.Get(key)
+ assert.Equalf(t, wantHeader, gotHeader, "header %q", key)
+ }
+ gotBody, err := io.ReadAll(res.Body)
+ require.NoError(t, err, "read result body")
+ assert.Equal(t, tt.wantBody, string(gotBody), "result body")
+
+ gotLog := logOut.String()
+ t.Log(gotLog)
+ assert.JSONEq(t, tt.wantLog, gotLog, "log output")
+ })
+ }
+}
+
+func TestRequestError(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ wantCode int
+ wantBody string
+ wantLog string
+ }{
+ {
+ name: "server error",
+ err: io.ErrClosedPipe,
+ wantCode: http.StatusBadRequest,
+ wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
+ wantLog: `{
+ "level":"ERROR",
+ "msg":"request error",
+ "time":"not",
+ "oidc_error":{
+ "parent":"io: read/write on closed pipe",
+ "description":"io: read/write on closed pipe",
+ "type":"server_error"}
+ }`,
+ },
+ {
+ name: "invalid client",
+ err: oidc.ErrInvalidClient().WithDescription("not good"),
+ wantCode: http.StatusUnauthorized,
+ wantBody: `{"error":"invalid_client", "error_description":"not good"}`,
+ wantLog: `{
+ "level":"WARN",
+ "msg":"request error",
+ "time":"not",
+ "oidc_error":{
+ "description":"not good",
+ "type":"invalid_client"}
+ }`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ logOut := new(strings.Builder)
+ logger := slog.New(
+ slog.NewJSONHandler(logOut, &slog.HandlerOptions{
+ Level: slog.LevelInfo,
+ }).WithAttrs([]slog.Attr{slog.String("time", "not")}),
+ )
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest("POST", "/path", nil)
+ RequestError(w, r, tt.err, logger)
+
+ res := w.Result()
+ defer res.Body.Close()
+
+ assert.Equal(t, tt.wantCode, res.StatusCode, "status code")
+
+ gotBody, err := io.ReadAll(res.Body)
+ require.NoError(t, err, "read result body")
+ assert.JSONEq(t, tt.wantBody, string(gotBody), "result body")
+
+ gotLog := logOut.String()
+ t.Log(gotLog)
+ assert.JSONEq(t, tt.wantLog, gotLog, "log output")
+ })
+ }
+}
+
+func TestTryErrorRedirect(t *testing.T) {
+ type args struct {
+ ctx context.Context
+ authReq ErrAuthRequest
+ parent error
+ }
+ tests := []struct {
+ name string
+ args args
+ want *Redirect
+ wantErr error
+ wantLog string
+ }{
+ {
+ name: "nil auth request",
+ args: args{
+ ctx: context.Background(),
+ authReq: nil,
+ parent: io.ErrClosedPipe,
+ },
+ wantErr: NewStatusError(io.ErrClosedPipe, http.StatusBadRequest),
+ wantLog: `{
+ "level":"ERROR",
+ "msg":"auth request",
+ "time":"not",
+ "oidc_error":{
+ "description":"io: read/write on closed pipe",
+ "parent":"io: read/write on closed pipe",
+ "type":"server_error"
+ }
+ }`,
+ },
+ {
+ name: "auth request, no redirect URI",
+ args: args{
+ ctx: context.Background(),
+ authReq: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"a", "b"},
+ ResponseType: "responseType",
+ ClientID: "123",
+ State: "state1",
+ ResponseMode: oidc.ResponseModeQuery,
+ },
+ parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
+ },
+ wantErr: NewStatusError(oidc.ErrInteractionRequired().WithDescription("sign in"), http.StatusBadRequest),
+ wantLog: `{
+ "level":"WARN",
+ "msg":"auth request: not redirecting",
+ "time":"not",
+ "auth_request":{
+ "client_id":"123",
+ "redirect_uri":"",
+ "response_type":"responseType",
+ "scopes":"a b"
+ },
+ "oidc_error":{
+ "description":"sign in",
+ "type":"interaction_required"
+ }
+ }`,
+ },
+ {
+ name: "auth request, redirect disabled",
+ args: args{
+ ctx: context.Background(),
+ authReq: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"a", "b"},
+ ResponseType: "responseType",
+ ClientID: "123",
+ RedirectURI: "http://example.com/callback",
+ State: "state1",
+ ResponseMode: oidc.ResponseModeQuery,
+ },
+ parent: oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"),
+ },
+ wantErr: NewStatusError(oidc.ErrInvalidRequestRedirectURI().WithDescription("oops"), http.StatusBadRequest),
+ wantLog: `{
+ "level":"WARN",
+ "msg":"auth request: not redirecting",
+ "time":"not",
+ "auth_request":{
+ "client_id":"123",
+ "redirect_uri":"http://example.com/callback",
+ "response_type":"responseType",
+ "scopes":"a b"
+ },
+ "oidc_error":{
+ "description":"oops",
+ "type":"invalid_request",
+ "redirect_disabled":true
+ }
+ }`,
+ },
+ {
+ name: "auth request, url parse error",
+ args: args{
+ ctx: context.Background(),
+ authReq: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"a", "b"},
+ ResponseType: "responseType",
+ ClientID: "123",
+ RedirectURI: "can't parse this!\n",
+ State: "state1",
+ ResponseMode: oidc.ResponseModeQuery,
+ },
+ parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
+ },
+ wantErr: func() error {
+ //lint:ignore SA1007 just recreating the error for testing
+ _, err := url.Parse("can't parse this!\n")
+ err = oidc.ErrServerError().WithParent(err)
+ return NewStatusError(err, http.StatusBadRequest)
+ }(),
+ wantLog: `{
+ "level":"ERROR",
+ "msg":"auth response URL",
+ "time":"not",
+ "auth_request":{
+ "client_id":"123",
+ "redirect_uri":"can't parse this!\n",
+ "response_type":"responseType",
+ "scopes":"a b"
+ },
+ "error":{
+ "type":"server_error",
+ "parent":"parse \"can't parse this!\\n\": net/url: invalid control character in URL"
+ },
+ "oidc_error":{
+ "description":"sign in",
+ "type":"interaction_required"
+ }
+ }`,
+ },
+ {
+ name: "auth request redirect",
+ args: args{
+ ctx: context.Background(),
+ authReq: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"a", "b"},
+ ResponseType: "responseType",
+ ClientID: "123",
+ RedirectURI: "http://example.com/callback",
+ State: "state1",
+ ResponseMode: oidc.ResponseModeQuery,
+ },
+ parent: oidc.ErrInteractionRequired().WithDescription("sign in"),
+ },
+ want: &Redirect{
+ Header: make(http.Header),
+ URL: "http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1",
+ },
+ wantLog: `{
+ "level":"WARN",
+ "msg":"auth request redirect",
+ "time":"not",
+ "auth_request":{
+ "client_id":"123",
+ "redirect_uri":"http://example.com/callback",
+ "response_type":"responseType",
+ "scopes":"a b"
+ },
+ "oidc_error":{
+ "description":"sign in",
+ "type":"interaction_required"
+ },
+ "url":"http://example.com/callback?error=interaction_required&error_description=sign+in&state=state1"
+ }`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ logOut := new(strings.Builder)
+ logger := slog.New(
+ slog.NewJSONHandler(logOut, &slog.HandlerOptions{
+ Level: slog.LevelInfo,
+ }).WithAttrs([]slog.Attr{slog.String("time", "not")}),
+ )
+ encoder := schema.NewEncoder()
+
+ got, err := TryErrorRedirect(tt.args.ctx, tt.args.authReq, tt.args.parent, encoder, logger)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.want, got)
+
+ gotLog := logOut.String()
+ t.Log(gotLog)
+ assert.JSONEq(t, tt.wantLog, gotLog, "log output")
+ })
+ }
+}
+
+func TestNewStatusError(t *testing.T) {
+ err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
+
+ want := "Internal Server Error: io: read/write on closed pipe"
+ got := fmt.Sprint(err)
+ assert.Equal(t, want, got)
+}
+
+func TestAsStatusError(t *testing.T) {
+ type args struct {
+ err error
+ statusCode int
+ }
+ tests := []struct {
+ name string
+ args args
+ want string
+ }{
+ {
+ name: "already status error",
+ args: args{
+ err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
+ statusCode: http.StatusBadRequest,
+ },
+ want: "Internal Server Error: io: read/write on closed pipe",
+ },
+ {
+ name: "oidc error",
+ args: args{
+ err: oidc.ErrAcrInvalid,
+ statusCode: http.StatusBadRequest,
+ },
+ want: "Bad Request: acr is invalid",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := AsStatusError(tt.args.err, tt.args.statusCode)
+ got := fmt.Sprint(err)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestStatusError_Unwrap(t *testing.T) {
+ err := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
+ require.ErrorIs(t, err, io.ErrClosedPipe)
+}
+
+func TestStatusError_Is(t *testing.T) {
+ type args struct {
+ err error
+ }
+ tests := []struct {
+ name string
+ args args
+ want bool
+ }{
+ {
+ name: "nil error",
+ args: args{err: nil},
+ want: false,
+ },
+ {
+ name: "other error",
+ args: args{err: io.EOF},
+ want: false,
+ },
+ {
+ name: "other parent",
+ args: args{err: NewStatusError(io.EOF, http.StatusInternalServerError)},
+ want: false,
+ },
+ {
+ name: "other status",
+ args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInsufficientStorage)},
+ want: false,
+ },
+ {
+ name: "same",
+ args: args{err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)},
+ want: true,
+ },
+ {
+ name: "wrapped",
+ args: args{err: fmt.Errorf("wrap: %w", NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError))},
+ want: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ e := NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError)
+ if got := e.Is(tt.args.err); got != tt.want {
+ t.Errorf("StatusError.Is() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestWriteError(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ wantStatus int
+ wantBody string
+ wantLog string
+ }{
+ {
+ name: "not a status or oidc error",
+ err: io.ErrClosedPipe,
+ wantStatus: http.StatusInternalServerError,
+ wantBody: `{
+ "error":"server_error",
+ "error_description":"io: read/write on closed pipe"
+ }`,
+ wantLog: `{
+ "level":"ERROR",
+ "msg":"request error",
+ "oidc_error":{
+ "description":"io: read/write on closed pipe",
+ "parent":"io: read/write on closed pipe",
+ "type":"server_error"
+ },
+ "status_code":500,
+ "time":"not"
+ }`,
+ },
+ {
+ name: "status error w/o oidc",
+ err: NewStatusError(io.ErrClosedPipe, http.StatusInternalServerError),
+ wantStatus: http.StatusInternalServerError,
+ wantBody: `{
+ "error":"server_error",
+ "error_description":"io: read/write on closed pipe"
+ }`,
+ wantLog: `{
+ "level":"ERROR",
+ "msg":"request error",
+ "oidc_error":{
+ "description":"io: read/write on closed pipe",
+ "parent":"io: read/write on closed pipe",
+ "type":"server_error"
+ },
+ "status_code":500,
+ "time":"not"
+ }`,
+ },
+ {
+ name: "oidc error w/o status",
+ err: oidc.ErrInvalidRequest().WithDescription("oops"),
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{
+ "error":"invalid_request",
+ "error_description":"oops"
+ }`,
+ wantLog: `{
+ "level":"WARN",
+ "msg":"request error",
+ "oidc_error":{
+ "description":"oops",
+ "type":"invalid_request"
+ },
+ "status_code":400,
+ "time":"not"
+ }`,
+ },
+ {
+ name: "status with oidc error",
+ err: NewStatusError(
+ oidc.ErrUnauthorizedClient().WithDescription("oops"),
+ http.StatusUnauthorized,
+ ),
+ wantStatus: http.StatusUnauthorized,
+ wantBody: `{
+ "error":"unauthorized_client",
+ "error_description":"oops"
+ }`,
+ wantLog: `{
+ "level":"WARN",
+ "msg":"request error",
+ "oidc_error":{
+ "description":"oops",
+ "type":"unauthorized_client"
+ },
+ "status_code":401,
+ "time":"not"
+ }`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ logOut := new(strings.Builder)
+ logger := slog.New(
+ slog.NewJSONHandler(logOut, &slog.HandlerOptions{
+ Level: slog.LevelInfo,
+ }).WithAttrs([]slog.Attr{slog.String("time", "not")}),
+ )
+ r := httptest.NewRequest("GET", "/target", nil)
+ w := httptest.NewRecorder()
+
+ WriteError(w, r, tt.err, logger)
+ res := w.Result()
+ assert.Equal(t, tt.wantStatus, res.StatusCode, "status code")
+ gotBody, err := io.ReadAll(res.Body)
+ require.NoError(t, err)
+ assert.JSONEq(t, tt.wantBody, string(gotBody), "body")
+ assert.JSONEq(t, tt.wantLog, logOut.String())
+ })
+ }
+}
diff --git a/pkg/op/form_post.html.tmpl b/pkg/op/form_post.html.tmpl
new file mode 100644
index 0000000..7bc9ab3
--- /dev/null
+++ b/pkg/op/form_post.html.tmpl
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/pkg/op/keys.go b/pkg/op/keys.go
index 8e2052b..97e400b 100644
--- a/pkg/op/keys.go
+++ b/pkg/op/keys.go
@@ -1,19 +1,46 @@
package op
import (
+ "context"
"net/http"
- "github.com/caos/oidc/pkg/utils"
+ jose "github.com/go-jose/go-jose/v4"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
)
type KeyProvider interface {
- Storage() Storage
+ KeySet(context.Context) ([]Key, error)
+}
+
+func keysHandler(k KeyProvider) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ Keys(w, r, k)
+ }
}
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
- keySet, err := k.Storage().GetKeySet(r.Context())
- if err != nil {
+ ctx, span := tracer.Start(r.Context(), "Keys")
+ r = r.WithContext(ctx)
+ defer span.End()
+ keySet, err := k.KeySet(r.Context())
+ if err != nil {
+ httphelper.MarshalJSONWithStatus(w, err, http.StatusInternalServerError)
+ return
}
- utils.MarshalJSON(w, keySet)
+ httphelper.MarshalJSON(w, jsonWebKeySet(keySet))
+}
+
+func jsonWebKeySet(keys []Key) *jose.JSONWebKeySet {
+ webKeys := make([]jose.JSONWebKey, len(keys))
+ for i, key := range keys {
+ webKeys[i] = jose.JSONWebKey{
+ KeyID: key.ID(),
+ Algorithm: string(key.Algorithm()),
+ Use: key.Use(),
+ Key: key.Key(),
+ }
+ }
+ return &jose.JSONWebKeySet{Keys: webKeys}
}
diff --git a/pkg/op/keys_test.go b/pkg/op/keys_test.go
new file mode 100644
index 0000000..9c80878
--- /dev/null
+++ b/pkg/op/keys_test.go
@@ -0,0 +1,100 @@
+package op_test
+
+import (
+ "crypto/rsa"
+ "math/big"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ jose "github.com/go-jose/go-jose/v4"
+ "github.com/golang/mock/gomock"
+ "github.com/stretchr/testify/assert"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op/mock"
+)
+
+func TestKeys(t *testing.T) {
+ type args struct {
+ k op.KeyProvider
+ }
+ type res struct {
+ statusCode int
+ contentType string
+ body string
+ }
+ tests := []struct {
+ name string
+ args args
+ res res
+ }{
+ {
+ name: "error",
+ args: args{
+ k: func() op.KeyProvider {
+ m := mock.NewMockKeyProvider(gomock.NewController(t))
+ m.EXPECT().KeySet(gomock.Any()).Return(nil, oidc.ErrServerError())
+ return m
+ }(),
+ },
+ res: res{
+ statusCode: http.StatusInternalServerError,
+ contentType: "application/json",
+ body: `{"error":"server_error"}
+`,
+ },
+ },
+ {
+ name: "empty list",
+ args: args{
+ k: func() op.KeyProvider {
+ m := mock.NewMockKeyProvider(gomock.NewController(t))
+ m.EXPECT().KeySet(gomock.Any()).Return(nil, nil)
+ return m
+ }(),
+ },
+ res: res{
+ statusCode: http.StatusOK,
+ contentType: "application/json",
+ body: `{"keys":[]}
+`,
+ },
+ },
+ {
+ name: "list",
+ args: args{
+ k: func() op.KeyProvider {
+ ctrl := gomock.NewController(t)
+ m := mock.NewMockKeyProvider(ctrl)
+ k := mock.NewMockKey(ctrl)
+ k.EXPECT().Key().Return(&rsa.PublicKey{
+ N: big.NewInt(1),
+ E: 1,
+ })
+ k.EXPECT().ID().Return("id")
+ k.EXPECT().Algorithm().Return(jose.RS256)
+ k.EXPECT().Use().Return("sig")
+ m.EXPECT().KeySet(gomock.Any()).Return([]op.Key{k}, nil)
+ return m
+ }(),
+ },
+ res: res{
+ statusCode: http.StatusOK,
+ contentType: "application/json",
+ body: `{"keys":[{"use":"sig","kty":"RSA","kid":"id","alg":"RS256","n":"AQ","e":"AQ"}]}
+`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ w := httptest.NewRecorder()
+ op.Keys(w, httptest.NewRequest("GET", "/keys", nil), tt.args.k)
+ assert.Equal(t, tt.res.statusCode, w.Result().StatusCode)
+ assert.Equal(t, tt.res.contentType, w.Header().Get("content-type"))
+ assert.Equal(t, tt.res.body, w.Body.String())
+ })
+ }
+}
diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go
index dbfc2a6..56b28e0 100644
--- a/pkg/op/mock/authorizer.mock.go
+++ b/pkg/op/mock/authorizer.mock.go
@@ -1,41 +1,43 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/caos/oidc/pkg/op (interfaces: Authorizer)
+// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: Authorizer)
// Package mock is a generated GoMock package.
package mock
import (
- op "github.com/caos/oidc/pkg/op"
- rp "github.com/caos/oidc/pkg/rp"
- gomock "github.com/golang/mock/gomock"
- schema "github.com/gorilla/schema"
+ context "context"
+ slog "log/slog"
reflect "reflect"
+
+ http "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ gomock "github.com/golang/mock/gomock"
)
-// MockAuthorizer is a mock of Authorizer interface
+// MockAuthorizer is a mock of Authorizer interface.
type MockAuthorizer struct {
ctrl *gomock.Controller
recorder *MockAuthorizerMockRecorder
}
-// MockAuthorizerMockRecorder is the mock recorder for MockAuthorizer
+// MockAuthorizerMockRecorder is the mock recorder for MockAuthorizer.
type MockAuthorizerMockRecorder struct {
mock *MockAuthorizer
}
-// NewMockAuthorizer creates a new mock instance
+// NewMockAuthorizer creates a new mock instance.
func NewMockAuthorizer(ctrl *gomock.Controller) *MockAuthorizer {
mock := &MockAuthorizer{ctrl: ctrl}
mock.recorder = &MockAuthorizerMockRecorder{mock}
return mock
}
-// EXPECT returns an object that allows the caller to indicate expected use
+// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAuthorizer) EXPECT() *MockAuthorizerMockRecorder {
return m.recorder
}
-// Crypto mocks base method
+// Crypto mocks base method.
func (m *MockAuthorizer) Crypto() op.Crypto {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Crypto")
@@ -43,83 +45,83 @@ func (m *MockAuthorizer) Crypto() op.Crypto {
return ret0
}
-// Crypto indicates an expected call of Crypto
+// Crypto indicates an expected call of Crypto.
func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Crypto", reflect.TypeOf((*MockAuthorizer)(nil).Crypto))
}
-// Decoder mocks base method
-func (m *MockAuthorizer) Decoder() *schema.Decoder {
+// Decoder mocks base method.
+func (m *MockAuthorizer) Decoder() http.Decoder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Decoder")
- ret0, _ := ret[0].(*schema.Decoder)
+ ret0, _ := ret[0].(http.Decoder)
return ret0
}
-// Decoder indicates an expected call of Decoder
+// Decoder indicates an expected call of Decoder.
func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decoder", reflect.TypeOf((*MockAuthorizer)(nil).Decoder))
}
-// Encoder mocks base method
-func (m *MockAuthorizer) Encoder() *schema.Encoder {
+// Encoder mocks base method.
+func (m *MockAuthorizer) Encoder() http.Encoder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Encoder")
- ret0, _ := ret[0].(*schema.Encoder)
+ ret0, _ := ret[0].(http.Encoder)
return ret0
}
-// Encoder indicates an expected call of Encoder
+// Encoder indicates an expected call of Encoder.
func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
}
-// IDTokenVerifier mocks base method
-func (m *MockAuthorizer) IDTokenVerifier() rp.Verifier {
+// IDTokenHintVerifier mocks base method.
+func (m *MockAuthorizer) IDTokenHintVerifier(arg0 context.Context) *op.IDTokenHintVerifier {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "IDTokenVerifier")
- ret0, _ := ret[0].(rp.Verifier)
+ ret := m.ctrl.Call(m, "IDTokenHintVerifier", arg0)
+ ret0, _ := ret[0].(*op.IDTokenHintVerifier)
return ret0
}
-// IDTokenVerifier indicates an expected call of IDTokenVerifier
-func (mr *MockAuthorizerMockRecorder) IDTokenVerifier() *gomock.Call {
+// IDTokenHintVerifier indicates an expected call of IDTokenHintVerifier.
+func (mr *MockAuthorizerMockRecorder) IDTokenHintVerifier(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenVerifier))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenHintVerifier", reflect.TypeOf((*MockAuthorizer)(nil).IDTokenHintVerifier), arg0)
}
-// Issuer mocks base method
-func (m *MockAuthorizer) Issuer() string {
+// Logger mocks base method.
+func (m *MockAuthorizer) Logger() *slog.Logger {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Issuer")
- ret0, _ := ret[0].(string)
+ ret := m.ctrl.Call(m, "Logger")
+ ret0, _ := ret[0].(*slog.Logger)
return ret0
}
-// Issuer indicates an expected call of Issuer
-func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
+// Logger indicates an expected call of Logger.
+func (mr *MockAuthorizerMockRecorder) Logger() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockAuthorizer)(nil).Logger))
}
-// Signer mocks base method
-func (m *MockAuthorizer) Signer() op.Signer {
+// RequestObjectSupported mocks base method.
+func (m *MockAuthorizer) RequestObjectSupported() bool {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Signer")
- ret0, _ := ret[0].(op.Signer)
+ ret := m.ctrl.Call(m, "RequestObjectSupported")
+ ret0, _ := ret[0].(bool)
return ret0
}
-// Signer indicates an expected call of Signer
-func (mr *MockAuthorizerMockRecorder) Signer() *gomock.Call {
+// RequestObjectSupported indicates an expected call of RequestObjectSupported.
+func (mr *MockAuthorizerMockRecorder) RequestObjectSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockAuthorizer)(nil).Signer))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockAuthorizer)(nil).RequestObjectSupported))
}
-// Storage mocks base method
+// Storage mocks base method.
func (m *MockAuthorizer) Storage() op.Storage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Storage")
@@ -127,7 +129,7 @@ func (m *MockAuthorizer) Storage() op.Storage {
return ret0
}
-// Storage indicates an expected call of Storage
+// Storage indicates an expected call of Storage.
func (mr *MockAuthorizerMockRecorder) Storage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Storage", reflect.TypeOf((*MockAuthorizer)(nil).Storage))
diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go
index 29c9354..73c4154 100644
--- a/pkg/op/mock/authorizer.mock.impl.go
+++ b/pkg/op/mock/authorizer.mock.impl.go
@@ -4,13 +4,12 @@ import (
"context"
"testing"
+ jose "github.com/go-jose/go-jose/v4"
"github.com/golang/mock/gomock"
- "github.com/gorilla/schema"
- "gopkg.in/square/go-jose.v2"
+ "github.com/zitadel/schema"
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/op"
- "github.com/caos/oidc/pkg/rp"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
)
func NewAuthorizer(t *testing.T) op.Authorizer {
@@ -21,23 +20,13 @@ func NewAuthorizerExpectValid(t *testing.T, wantErr bool) op.Authorizer {
m := NewAuthorizer(t)
ExpectDecoder(m)
ExpectEncoder(m)
- ExpectSigner(m, t)
+ //ExpectSigner(m, t)
ExpectStorage(m, t)
ExpectVerifier(m, t)
// ExpectErrorHandler(m, t, wantErr)
return m
}
-// func NewAuthorizerExpectDecoderFails(t *testing.T) op.Authorizer {
-// m := NewAuthorizer(t)
-// ExpectDecoderFails(m)
-// ExpectEncoder(m)
-// ExpectSigner(m, t)
-// ExpectStorage(m, t)
-// ExpectErrorHandler(m, t)
-// return m
-// }
-
func ExpectDecoder(a op.Authorizer) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Decoder().AnyTimes().Return(schema.NewDecoder())
@@ -48,19 +37,20 @@ func ExpectEncoder(a op.Authorizer) {
mockA.EXPECT().Encoder().AnyTimes().Return(schema.NewEncoder())
}
-func ExpectSigner(a op.Authorizer, t *testing.T) {
- mockA := a.(*MockAuthorizer)
- mockA.EXPECT().Signer().DoAndReturn(
- func() op.Signer {
- return &Sig{}
- })
-}
+//
+//func ExpectSigner(a op.Authorizer, t *testing.T) {
+// mockA := a.(*MockAuthorizer)
+// mockA.EXPECT().Signer().DoAndReturn(
+// func() op.Signer {
+// return &Sig{}
+// })
+//}
func ExpectVerifier(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer)
- mockA.EXPECT().IDTokenVerifier().DoAndReturn(
- func() rp.Verifier {
- return &Verifier{}
+ mockA.EXPECT().IDTokenHintVerifier(gomock.Any()).DoAndReturn(
+ func() *op.IDTokenHintVerifier {
+ return op.NewIDTokenHintVerifier("", nil)
})
}
@@ -70,18 +60,22 @@ func (v *Verifier) Verify(ctx context.Context, accessToken, idToken string) (*oi
return nil, nil
}
-type Sig struct{}
+func (v *Verifier) VerifyIDToken(ctx context.Context, idToken string) (*oidc.IDTokenClaims, error) {
+ return nil, nil
+}
+
+type Sig struct {
+ signer jose.Signer
+}
+
+func (s *Sig) Signer() jose.Signer {
+ return s.signer
+}
func (s *Sig) Health(ctx context.Context) error {
return nil
}
-func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
- return "", nil
-}
-func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
- return "", nil
-}
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
return jose.HS256
}
@@ -90,9 +84,3 @@ func ExpectStorage(a op.Authorizer, t *testing.T) {
mockA := a.(*MockAuthorizer)
mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t))
}
-
-// func NewMockSignerAny(t *testing.T) op.Signer {
-// m := NewMockSigner(gomock.NewController(t))
-// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil)
-// return m
-// }
diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go
index 242eb13..e2a5e85 100644
--- a/pkg/op/mock/client.go
+++ b/pkg/op/mock/client.go
@@ -3,9 +3,10 @@ package mock
import (
"testing"
- gomock "github.com/golang/mock/gomock"
+ "github.com/golang/mock/gomock"
- op "github.com/caos/oidc/pkg/op"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
)
func NewClient(t *testing.T) op.Client {
@@ -19,11 +20,23 @@ func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
"https://registered.com/callback",
"http://registered.com/callback",
"http://localhost:9999/callback",
- "custom://callback"})
+ "custom://callback",
+ })
m.EXPECT().ApplicationType().AnyTimes().Return(appType)
m.EXPECT().LoginURL(gomock.Any()).AnyTimes().DoAndReturn(
func(id string) string {
return "login?id=" + id
})
+ m.EXPECT().IsScopeAllowed(gomock.Any()).AnyTimes().Return(false)
+ return c
+}
+
+func NewClientWithConfig(t *testing.T, uri []string, appType op.ApplicationType, responseTypes []oidc.ResponseType, devMode bool) op.Client {
+ c := NewClient(t)
+ m := c.(*MockClient)
+ m.EXPECT().RedirectURIs().AnyTimes().Return(uri)
+ m.EXPECT().ApplicationType().AnyTimes().Return(appType)
+ m.EXPECT().ResponseTypes().AnyTimes().Return(responseTypes)
+ m.EXPECT().DevMode().AnyTimes().Return(devMode)
return c
}
diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go
index e2f1c11..93eca67 100644
--- a/pkg/op/mock/client.mock.go
+++ b/pkg/op/mock/client.mock.go
@@ -1,40 +1,42 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/caos/oidc/pkg/op (interfaces: Client)
+// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: Client)
// Package mock is a generated GoMock package.
package mock
import (
- op "github.com/caos/oidc/pkg/op"
- gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
+
+ oidc "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ gomock "github.com/golang/mock/gomock"
)
-// MockClient is a mock of Client interface
+// MockClient is a mock of Client interface.
type MockClient struct {
ctrl *gomock.Controller
recorder *MockClientMockRecorder
}
-// MockClientMockRecorder is the mock recorder for MockClient
+// MockClientMockRecorder is the mock recorder for MockClient.
type MockClientMockRecorder struct {
mock *MockClient
}
-// NewMockClient creates a new mock instance
+// NewMockClient creates a new mock instance.
func NewMockClient(ctrl *gomock.Controller) *MockClient {
mock := &MockClient{ctrl: ctrl}
mock.recorder = &MockClientMockRecorder{mock}
return mock
}
-// EXPECT returns an object that allows the caller to indicate expected use
+// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockClient) EXPECT() *MockClientMockRecorder {
return m.recorder
}
-// AccessTokenType mocks base method
+// AccessTokenType mocks base method.
func (m *MockClient) AccessTokenType() op.AccessTokenType {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AccessTokenType")
@@ -42,13 +44,13 @@ func (m *MockClient) AccessTokenType() op.AccessTokenType {
return ret0
}
-// AccessTokenType indicates an expected call of AccessTokenType
+// AccessTokenType indicates an expected call of AccessTokenType.
func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
}
-// ApplicationType mocks base method
+// ApplicationType mocks base method.
func (m *MockClient) ApplicationType() op.ApplicationType {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ApplicationType")
@@ -56,27 +58,55 @@ func (m *MockClient) ApplicationType() op.ApplicationType {
return ret0
}
-// ApplicationType indicates an expected call of ApplicationType
+// ApplicationType indicates an expected call of ApplicationType.
func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
}
-// GetAuthMethod mocks base method
-func (m *MockClient) GetAuthMethod() op.AuthMethod {
+// AuthMethod mocks base method.
+func (m *MockClient) AuthMethod() oidc.AuthMethod {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "GetAuthMethod")
- ret0, _ := ret[0].(op.AuthMethod)
+ ret := m.ctrl.Call(m, "AuthMethod")
+ ret0, _ := ret[0].(oidc.AuthMethod)
return ret0
}
-// GetAuthMethod indicates an expected call of GetAuthMethod
-func (mr *MockClientMockRecorder) GetAuthMethod() *gomock.Call {
+// AuthMethod indicates an expected call of AuthMethod.
+func (mr *MockClientMockRecorder) AuthMethod() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthMethod", reflect.TypeOf((*MockClient)(nil).GetAuthMethod))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethod", reflect.TypeOf((*MockClient)(nil).AuthMethod))
}
-// GetID mocks base method
+// ClockSkew mocks base method.
+func (m *MockClient) ClockSkew() time.Duration {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ClockSkew")
+ ret0, _ := ret[0].(time.Duration)
+ return ret0
+}
+
+// ClockSkew indicates an expected call of ClockSkew.
+func (mr *MockClientMockRecorder) ClockSkew() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClockSkew", reflect.TypeOf((*MockClient)(nil).ClockSkew))
+}
+
+// DevMode mocks base method.
+func (m *MockClient) DevMode() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DevMode")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// DevMode indicates an expected call of DevMode.
+func (mr *MockClientMockRecorder) DevMode() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DevMode", reflect.TypeOf((*MockClient)(nil).DevMode))
+}
+
+// GetID mocks base method.
func (m *MockClient) GetID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetID")
@@ -84,13 +114,27 @@ func (m *MockClient) GetID() string {
return ret0
}
-// GetID indicates an expected call of GetID
+// GetID indicates an expected call of GetID.
func (mr *MockClientMockRecorder) GetID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
}
-// IDTokenLifetime mocks base method
+// GrantTypes mocks base method.
+func (m *MockClient) GrantTypes() []oidc.GrantType {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GrantTypes")
+ ret0, _ := ret[0].([]oidc.GrantType)
+ return ret0
+}
+
+// GrantTypes indicates an expected call of GrantTypes.
+func (mr *MockClientMockRecorder) GrantTypes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypes", reflect.TypeOf((*MockClient)(nil).GrantTypes))
+}
+
+// IDTokenLifetime mocks base method.
func (m *MockClient) IDTokenLifetime() time.Duration {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IDTokenLifetime")
@@ -98,13 +142,41 @@ func (m *MockClient) IDTokenLifetime() time.Duration {
return ret0
}
-// IDTokenLifetime indicates an expected call of IDTokenLifetime
+// IDTokenLifetime indicates an expected call of IDTokenLifetime.
func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime))
}
-// LoginURL mocks base method
+// IDTokenUserinfoClaimsAssertion mocks base method.
+func (m *MockClient) IDTokenUserinfoClaimsAssertion() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IDTokenUserinfoClaimsAssertion")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// IDTokenUserinfoClaimsAssertion indicates an expected call of IDTokenUserinfoClaimsAssertion.
+func (mr *MockClientMockRecorder) IDTokenUserinfoClaimsAssertion() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenUserinfoClaimsAssertion", reflect.TypeOf((*MockClient)(nil).IDTokenUserinfoClaimsAssertion))
+}
+
+// IsScopeAllowed mocks base method.
+func (m *MockClient) IsScopeAllowed(arg0 string) bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IsScopeAllowed", arg0)
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// IsScopeAllowed indicates an expected call of IsScopeAllowed.
+func (mr *MockClientMockRecorder) IsScopeAllowed(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsScopeAllowed", reflect.TypeOf((*MockClient)(nil).IsScopeAllowed), arg0)
+}
+
+// LoginURL mocks base method.
func (m *MockClient) LoginURL(arg0 string) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoginURL", arg0)
@@ -112,13 +184,13 @@ func (m *MockClient) LoginURL(arg0 string) string {
return ret0
}
-// LoginURL indicates an expected call of LoginURL
+// LoginURL indicates an expected call of LoginURL.
func (mr *MockClientMockRecorder) LoginURL(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockClient)(nil).LoginURL), arg0)
}
-// PostLogoutRedirectURIs mocks base method
+// PostLogoutRedirectURIs mocks base method.
func (m *MockClient) PostLogoutRedirectURIs() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PostLogoutRedirectURIs")
@@ -126,13 +198,13 @@ func (m *MockClient) PostLogoutRedirectURIs() []string {
return ret0
}
-// PostLogoutRedirectURIs indicates an expected call of PostLogoutRedirectURIs
+// PostLogoutRedirectURIs indicates an expected call of PostLogoutRedirectURIs.
func (mr *MockClientMockRecorder) PostLogoutRedirectURIs() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIs", reflect.TypeOf((*MockClient)(nil).PostLogoutRedirectURIs))
}
-// RedirectURIs mocks base method
+// RedirectURIs mocks base method.
func (m *MockClient) RedirectURIs() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RedirectURIs")
@@ -140,8 +212,50 @@ func (m *MockClient) RedirectURIs() []string {
return ret0
}
-// RedirectURIs indicates an expected call of RedirectURIs
+// RedirectURIs indicates an expected call of RedirectURIs.
func (mr *MockClientMockRecorder) RedirectURIs() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIs", reflect.TypeOf((*MockClient)(nil).RedirectURIs))
}
+
+// ResponseTypes mocks base method.
+func (m *MockClient) ResponseTypes() []oidc.ResponseType {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ResponseTypes")
+ ret0, _ := ret[0].([]oidc.ResponseType)
+ return ret0
+}
+
+// ResponseTypes indicates an expected call of ResponseTypes.
+func (mr *MockClientMockRecorder) ResponseTypes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseTypes", reflect.TypeOf((*MockClient)(nil).ResponseTypes))
+}
+
+// RestrictAdditionalAccessTokenScopes mocks base method.
+func (m *MockClient) RestrictAdditionalAccessTokenScopes() func([]string) []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RestrictAdditionalAccessTokenScopes")
+ ret0, _ := ret[0].(func([]string) []string)
+ return ret0
+}
+
+// RestrictAdditionalAccessTokenScopes indicates an expected call of RestrictAdditionalAccessTokenScopes.
+func (mr *MockClientMockRecorder) RestrictAdditionalAccessTokenScopes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalAccessTokenScopes", reflect.TypeOf((*MockClient)(nil).RestrictAdditionalAccessTokenScopes))
+}
+
+// RestrictAdditionalIdTokenScopes mocks base method.
+func (m *MockClient) RestrictAdditionalIdTokenScopes() func([]string) []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RestrictAdditionalIdTokenScopes")
+ ret0, _ := ret[0].(func([]string) []string)
+ return ret0
+}
+
+// RestrictAdditionalIdTokenScopes indicates an expected call of RestrictAdditionalIdTokenScopes.
+func (mr *MockClientMockRecorder) RestrictAdditionalIdTokenScopes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalIdTokenScopes", reflect.TypeOf((*MockClient)(nil).RestrictAdditionalIdTokenScopes))
+}
diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go
index c6174ff..bf51035 100644
--- a/pkg/op/mock/configuration.mock.go
+++ b/pkg/op/mock/configuration.mock.go
@@ -1,39 +1,42 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/caos/oidc/pkg/op (interfaces: Configuration)
+// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: Configuration)
// Package mock is a generated GoMock package.
package mock
import (
- op "github.com/caos/oidc/pkg/op"
- gomock "github.com/golang/mock/gomock"
+ http "net/http"
reflect "reflect"
+
+ op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ gomock "github.com/golang/mock/gomock"
+ language "golang.org/x/text/language"
)
-// MockConfiguration is a mock of Configuration interface
+// MockConfiguration is a mock of Configuration interface.
type MockConfiguration struct {
ctrl *gomock.Controller
recorder *MockConfigurationMockRecorder
}
-// MockConfigurationMockRecorder is the mock recorder for MockConfiguration
+// MockConfigurationMockRecorder is the mock recorder for MockConfiguration.
type MockConfigurationMockRecorder struct {
mock *MockConfiguration
}
-// NewMockConfiguration creates a new mock instance
+// NewMockConfiguration creates a new mock instance.
func NewMockConfiguration(ctrl *gomock.Controller) *MockConfiguration {
mock := &MockConfiguration{ctrl: ctrl}
mock.recorder = &MockConfigurationMockRecorder{mock}
return mock
}
-// EXPECT returns an object that allows the caller to indicate expected use
+// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockConfiguration) EXPECT() *MockConfigurationMockRecorder {
return m.recorder
}
-// AuthMethodPostSupported mocks base method
+// AuthMethodPostSupported mocks base method.
func (m *MockConfiguration) AuthMethodPostSupported() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthMethodPostSupported")
@@ -41,105 +44,413 @@ func (m *MockConfiguration) AuthMethodPostSupported() bool {
return ret0
}
-// AuthMethodPostSupported indicates an expected call of AuthMethodPostSupported
+// AuthMethodPostSupported indicates an expected call of AuthMethodPostSupported.
func (mr *MockConfigurationMockRecorder) AuthMethodPostSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethodPostSupported", reflect.TypeOf((*MockConfiguration)(nil).AuthMethodPostSupported))
}
-// AuthorizationEndpoint mocks base method
-func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
+// AuthMethodPrivateKeyJWTSupported mocks base method.
+func (m *MockConfiguration) AuthMethodPrivateKeyJWTSupported() bool {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "AuthorizationEndpoint")
- ret0, _ := ret[0].(op.Endpoint)
+ ret := m.ctrl.Call(m, "AuthMethodPrivateKeyJWTSupported")
+ ret0, _ := ret[0].(bool)
return ret0
}
-// AuthorizationEndpoint indicates an expected call of AuthorizationEndpoint
+// AuthMethodPrivateKeyJWTSupported indicates an expected call of AuthMethodPrivateKeyJWTSupported.
+func (mr *MockConfigurationMockRecorder) AuthMethodPrivateKeyJWTSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethodPrivateKeyJWTSupported", reflect.TypeOf((*MockConfiguration)(nil).AuthMethodPrivateKeyJWTSupported))
+}
+
+// AuthorizationEndpoint mocks base method.
+func (m *MockConfiguration) AuthorizationEndpoint() *op.Endpoint {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "AuthorizationEndpoint")
+ ret0, _ := ret[0].(*op.Endpoint)
+ return ret0
+}
+
+// AuthorizationEndpoint indicates an expected call of AuthorizationEndpoint.
func (mr *MockConfigurationMockRecorder) AuthorizationEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).AuthorizationEndpoint))
}
-// EndSessionEndpoint mocks base method
-func (m *MockConfiguration) EndSessionEndpoint() op.Endpoint {
+// BackChannelLogoutSessionSupported mocks base method.
+func (m *MockConfiguration) BackChannelLogoutSessionSupported() bool {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "EndSessionEndpoint")
- ret0, _ := ret[0].(op.Endpoint)
+ ret := m.ctrl.Call(m, "BackChannelLogoutSessionSupported")
+ ret0, _ := ret[0].(bool)
return ret0
}
-// EndSessionEndpoint indicates an expected call of EndSessionEndpoint
+// BackChannelLogoutSessionSupported indicates an expected call of BackChannelLogoutSessionSupported.
+func (mr *MockConfigurationMockRecorder) BackChannelLogoutSessionSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSessionSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSessionSupported))
+}
+
+// BackChannelLogoutSupported mocks base method.
+func (m *MockConfiguration) BackChannelLogoutSupported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "BackChannelLogoutSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// BackChannelLogoutSupported indicates an expected call of BackChannelLogoutSupported.
+func (mr *MockConfigurationMockRecorder) BackChannelLogoutSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSupported))
+}
+
+// CheckSessionIframe mocks base method.
+func (m *MockConfiguration) CheckSessionIframe() *op.Endpoint {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "CheckSessionIframe")
+ ret0, _ := ret[0].(*op.Endpoint)
+ return ret0
+}
+
+// CheckSessionIframe indicates an expected call of CheckSessionIframe.
+func (mr *MockConfigurationMockRecorder) CheckSessionIframe() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckSessionIframe", reflect.TypeOf((*MockConfiguration)(nil).CheckSessionIframe))
+}
+
+// CodeMethodS256Supported mocks base method.
+func (m *MockConfiguration) CodeMethodS256Supported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "CodeMethodS256Supported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// CodeMethodS256Supported indicates an expected call of CodeMethodS256Supported.
+func (mr *MockConfigurationMockRecorder) CodeMethodS256Supported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CodeMethodS256Supported", reflect.TypeOf((*MockConfiguration)(nil).CodeMethodS256Supported))
+}
+
+// DeviceAuthorization mocks base method.
+func (m *MockConfiguration) DeviceAuthorization() op.DeviceAuthorizationConfig {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DeviceAuthorization")
+ ret0, _ := ret[0].(op.DeviceAuthorizationConfig)
+ return ret0
+}
+
+// DeviceAuthorization indicates an expected call of DeviceAuthorization.
+func (mr *MockConfigurationMockRecorder) DeviceAuthorization() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorization", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorization))
+}
+
+// DeviceAuthorizationEndpoint mocks base method.
+func (m *MockConfiguration) DeviceAuthorizationEndpoint() *op.Endpoint {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DeviceAuthorizationEndpoint")
+ ret0, _ := ret[0].(*op.Endpoint)
+ return ret0
+}
+
+// DeviceAuthorizationEndpoint indicates an expected call of DeviceAuthorizationEndpoint.
+func (mr *MockConfigurationMockRecorder) DeviceAuthorizationEndpoint() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeviceAuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).DeviceAuthorizationEndpoint))
+}
+
+// EndSessionEndpoint mocks base method.
+func (m *MockConfiguration) EndSessionEndpoint() *op.Endpoint {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "EndSessionEndpoint")
+ ret0, _ := ret[0].(*op.Endpoint)
+ return ret0
+}
+
+// EndSessionEndpoint indicates an expected call of EndSessionEndpoint.
func (mr *MockConfigurationMockRecorder) EndSessionEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EndSessionEndpoint", reflect.TypeOf((*MockConfiguration)(nil).EndSessionEndpoint))
}
-// Issuer mocks base method
-func (m *MockConfiguration) Issuer() string {
+// GrantTypeClientCredentialsSupported mocks base method.
+func (m *MockConfiguration) GrantTypeClientCredentialsSupported() bool {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Issuer")
+ ret := m.ctrl.Call(m, "GrantTypeClientCredentialsSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// GrantTypeClientCredentialsSupported indicates an expected call of GrantTypeClientCredentialsSupported.
+func (mr *MockConfigurationMockRecorder) GrantTypeClientCredentialsSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeClientCredentialsSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeClientCredentialsSupported))
+}
+
+// GrantTypeDeviceCodeSupported mocks base method.
+func (m *MockConfiguration) GrantTypeDeviceCodeSupported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GrantTypeDeviceCodeSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// GrantTypeDeviceCodeSupported indicates an expected call of GrantTypeDeviceCodeSupported.
+func (mr *MockConfigurationMockRecorder) GrantTypeDeviceCodeSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeDeviceCodeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeDeviceCodeSupported))
+}
+
+// GrantTypeJWTAuthorizationSupported mocks base method.
+func (m *MockConfiguration) GrantTypeJWTAuthorizationSupported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GrantTypeJWTAuthorizationSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// GrantTypeJWTAuthorizationSupported indicates an expected call of GrantTypeJWTAuthorizationSupported.
+func (mr *MockConfigurationMockRecorder) GrantTypeJWTAuthorizationSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeJWTAuthorizationSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeJWTAuthorizationSupported))
+}
+
+// GrantTypeRefreshTokenSupported mocks base method.
+func (m *MockConfiguration) GrantTypeRefreshTokenSupported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GrantTypeRefreshTokenSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// GrantTypeRefreshTokenSupported indicates an expected call of GrantTypeRefreshTokenSupported.
+func (mr *MockConfigurationMockRecorder) GrantTypeRefreshTokenSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeRefreshTokenSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeRefreshTokenSupported))
+}
+
+// GrantTypeTokenExchangeSupported mocks base method.
+func (m *MockConfiguration) GrantTypeTokenExchangeSupported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GrantTypeTokenExchangeSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// GrantTypeTokenExchangeSupported indicates an expected call of GrantTypeTokenExchangeSupported.
+func (mr *MockConfigurationMockRecorder) GrantTypeTokenExchangeSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypeTokenExchangeSupported", reflect.TypeOf((*MockConfiguration)(nil).GrantTypeTokenExchangeSupported))
+}
+
+// Insecure mocks base method.
+func (m *MockConfiguration) Insecure() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Insecure")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// Insecure indicates an expected call of Insecure.
+func (mr *MockConfigurationMockRecorder) Insecure() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insecure", reflect.TypeOf((*MockConfiguration)(nil).Insecure))
+}
+
+// IntrospectionAuthMethodPrivateKeyJWTSupported mocks base method.
+func (m *MockConfiguration) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IntrospectionAuthMethodPrivateKeyJWTSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// IntrospectionAuthMethodPrivateKeyJWTSupported indicates an expected call of IntrospectionAuthMethodPrivateKeyJWTSupported.
+func (mr *MockConfigurationMockRecorder) IntrospectionAuthMethodPrivateKeyJWTSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionAuthMethodPrivateKeyJWTSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionAuthMethodPrivateKeyJWTSupported))
+}
+
+// IntrospectionEndpoint mocks base method.
+func (m *MockConfiguration) IntrospectionEndpoint() *op.Endpoint {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IntrospectionEndpoint")
+ ret0, _ := ret[0].(*op.Endpoint)
+ return ret0
+}
+
+// IntrospectionEndpoint indicates an expected call of IntrospectionEndpoint.
+func (mr *MockConfigurationMockRecorder) IntrospectionEndpoint() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpoint", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpoint))
+}
+
+// IntrospectionEndpointSigningAlgorithmsSupported mocks base method.
+func (m *MockConfiguration) IntrospectionEndpointSigningAlgorithmsSupported() []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IntrospectionEndpointSigningAlgorithmsSupported")
+ ret0, _ := ret[0].([]string)
+ return ret0
+}
+
+// IntrospectionEndpointSigningAlgorithmsSupported indicates an expected call of IntrospectionEndpointSigningAlgorithmsSupported.
+func (mr *MockConfigurationMockRecorder) IntrospectionEndpointSigningAlgorithmsSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntrospectionEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).IntrospectionEndpointSigningAlgorithmsSupported))
+}
+
+// IssuerFromRequest mocks base method.
+func (m *MockConfiguration) IssuerFromRequest(arg0 *http.Request) string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IssuerFromRequest", arg0)
ret0, _ := ret[0].(string)
return ret0
}
-// Issuer indicates an expected call of Issuer
-func (mr *MockConfigurationMockRecorder) Issuer() *gomock.Call {
+// IssuerFromRequest indicates an expected call of IssuerFromRequest.
+func (mr *MockConfigurationMockRecorder) IssuerFromRequest(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockConfiguration)(nil).Issuer))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IssuerFromRequest", reflect.TypeOf((*MockConfiguration)(nil).IssuerFromRequest), arg0)
}
-// KeysEndpoint mocks base method
-func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
+// KeysEndpoint mocks base method.
+func (m *MockConfiguration) KeysEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeysEndpoint")
- ret0, _ := ret[0].(op.Endpoint)
+ ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
-// KeysEndpoint indicates an expected call of KeysEndpoint
+// KeysEndpoint indicates an expected call of KeysEndpoint.
func (mr *MockConfigurationMockRecorder) KeysEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeysEndpoint", reflect.TypeOf((*MockConfiguration)(nil).KeysEndpoint))
}
-// Port mocks base method
-func (m *MockConfiguration) Port() string {
+// RequestObjectSigningAlgorithmsSupported mocks base method.
+func (m *MockConfiguration) RequestObjectSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Port")
- ret0, _ := ret[0].(string)
+ ret := m.ctrl.Call(m, "RequestObjectSigningAlgorithmsSupported")
+ ret0, _ := ret[0].([]string)
return ret0
}
-// Port indicates an expected call of Port
-func (mr *MockConfigurationMockRecorder) Port() *gomock.Call {
+// RequestObjectSigningAlgorithmsSupported indicates an expected call of RequestObjectSigningAlgorithmsSupported.
+func (mr *MockConfigurationMockRecorder) RequestObjectSigningAlgorithmsSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Port", reflect.TypeOf((*MockConfiguration)(nil).Port))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).RequestObjectSigningAlgorithmsSupported))
}
-// TokenEndpoint mocks base method
-func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
+// RequestObjectSupported mocks base method.
+func (m *MockConfiguration) RequestObjectSupported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RequestObjectSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// RequestObjectSupported indicates an expected call of RequestObjectSupported.
+func (mr *MockConfigurationMockRecorder) RequestObjectSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestObjectSupported", reflect.TypeOf((*MockConfiguration)(nil).RequestObjectSupported))
+}
+
+// RevocationAuthMethodPrivateKeyJWTSupported mocks base method.
+func (m *MockConfiguration) RevocationAuthMethodPrivateKeyJWTSupported() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RevocationAuthMethodPrivateKeyJWTSupported")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// RevocationAuthMethodPrivateKeyJWTSupported indicates an expected call of RevocationAuthMethodPrivateKeyJWTSupported.
+func (mr *MockConfigurationMockRecorder) RevocationAuthMethodPrivateKeyJWTSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationAuthMethodPrivateKeyJWTSupported", reflect.TypeOf((*MockConfiguration)(nil).RevocationAuthMethodPrivateKeyJWTSupported))
+}
+
+// RevocationEndpoint mocks base method.
+func (m *MockConfiguration) RevocationEndpoint() *op.Endpoint {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RevocationEndpoint")
+ ret0, _ := ret[0].(*op.Endpoint)
+ return ret0
+}
+
+// RevocationEndpoint indicates an expected call of RevocationEndpoint.
+func (mr *MockConfigurationMockRecorder) RevocationEndpoint() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).RevocationEndpoint))
+}
+
+// RevocationEndpointSigningAlgorithmsSupported mocks base method.
+func (m *MockConfiguration) RevocationEndpointSigningAlgorithmsSupported() []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RevocationEndpointSigningAlgorithmsSupported")
+ ret0, _ := ret[0].([]string)
+ return ret0
+}
+
+// RevocationEndpointSigningAlgorithmsSupported indicates an expected call of RevocationEndpointSigningAlgorithmsSupported.
+func (mr *MockConfigurationMockRecorder) RevocationEndpointSigningAlgorithmsSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevocationEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).RevocationEndpointSigningAlgorithmsSupported))
+}
+
+// SupportedUILocales mocks base method.
+func (m *MockConfiguration) SupportedUILocales() []language.Tag {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SupportedUILocales")
+ ret0, _ := ret[0].([]language.Tag)
+ return ret0
+}
+
+// SupportedUILocales indicates an expected call of SupportedUILocales.
+func (mr *MockConfigurationMockRecorder) SupportedUILocales() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SupportedUILocales", reflect.TypeOf((*MockConfiguration)(nil).SupportedUILocales))
+}
+
+// TokenEndpoint mocks base method.
+func (m *MockConfiguration) TokenEndpoint() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TokenEndpoint")
- ret0, _ := ret[0].(op.Endpoint)
+ ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
-// TokenEndpoint indicates an expected call of TokenEndpoint
+// TokenEndpoint indicates an expected call of TokenEndpoint.
func (mr *MockConfigurationMockRecorder) TokenEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpoint", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpoint))
}
-// UserinfoEndpoint mocks base method
-func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
+// TokenEndpointSigningAlgorithmsSupported mocks base method.
+func (m *MockConfiguration) TokenEndpointSigningAlgorithmsSupported() []string {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "UserinfoEndpoint")
- ret0, _ := ret[0].(op.Endpoint)
+ ret := m.ctrl.Call(m, "TokenEndpointSigningAlgorithmsSupported")
+ ret0, _ := ret[0].([]string)
return ret0
}
-// UserinfoEndpoint indicates an expected call of UserinfoEndpoint
+// TokenEndpointSigningAlgorithmsSupported indicates an expected call of TokenEndpointSigningAlgorithmsSupported.
+func (mr *MockConfigurationMockRecorder) TokenEndpointSigningAlgorithmsSupported() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpointSigningAlgorithmsSupported", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpointSigningAlgorithmsSupported))
+}
+
+// UserinfoEndpoint mocks base method.
+func (m *MockConfiguration) UserinfoEndpoint() *op.Endpoint {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "UserinfoEndpoint")
+ ret0, _ := ret[0].(*op.Endpoint)
+ return ret0
+}
+
+// UserinfoEndpoint indicates an expected call of UserinfoEndpoint.
func (mr *MockConfigurationMockRecorder) UserinfoEndpoint() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserinfoEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserinfoEndpoint))
diff --git a/pkg/op/mock/discovery.mock.go b/pkg/op/mock/discovery.mock.go
new file mode 100644
index 0000000..c85f91b
--- /dev/null
+++ b/pkg/op/mock/discovery.mock.go
@@ -0,0 +1,51 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: DiscoverStorage)
+
+// Package mock is a generated GoMock package.
+package mock
+
+import (
+ context "context"
+ reflect "reflect"
+
+ jose "github.com/go-jose/go-jose/v4"
+ gomock "github.com/golang/mock/gomock"
+)
+
+// MockDiscoverStorage is a mock of DiscoverStorage interface.
+type MockDiscoverStorage struct {
+ ctrl *gomock.Controller
+ recorder *MockDiscoverStorageMockRecorder
+}
+
+// MockDiscoverStorageMockRecorder is the mock recorder for MockDiscoverStorage.
+type MockDiscoverStorageMockRecorder struct {
+ mock *MockDiscoverStorage
+}
+
+// NewMockDiscoverStorage creates a new mock instance.
+func NewMockDiscoverStorage(ctrl *gomock.Controller) *MockDiscoverStorage {
+ mock := &MockDiscoverStorage{ctrl: ctrl}
+ mock.recorder = &MockDiscoverStorageMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockDiscoverStorage) EXPECT() *MockDiscoverStorageMockRecorder {
+ return m.recorder
+}
+
+// SignatureAlgorithms mocks base method.
+func (m *MockDiscoverStorage) SignatureAlgorithms(arg0 context.Context) ([]jose.SignatureAlgorithm, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SignatureAlgorithms", arg0)
+ ret0, _ := ret[0].([]jose.SignatureAlgorithm)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// SignatureAlgorithms indicates an expected call of SignatureAlgorithms.
+func (mr *MockDiscoverStorageMockRecorder) SignatureAlgorithms(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithms", reflect.TypeOf((*MockDiscoverStorage)(nil).SignatureAlgorithms), arg0)
+}
diff --git a/pkg/op/mock/generate.go b/pkg/op/mock/generate.go
index beb3132..3d58ab7 100644
--- a/pkg/op/mock/generate.go
+++ b/pkg/op/mock/generate.go
@@ -1,7 +1,11 @@
package mock
-//go:generate mockgen -package mock -destination ./storage.mock.go github.com/caos/oidc/pkg/op Storage
-//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/caos/oidc/pkg/op Authorizer
-//go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client
-//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration
-//go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer
+//go:generate go install github.com/golang/mock/mockgen@v1.6.0
+//go:generate mockgen -package mock -destination ./storage.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op Storage
+//go:generate mockgen -package mock -destination ./authorizer.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op Authorizer
+//go:generate mockgen -package mock -destination ./client.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op Client
+//go:generate mockgen -package mock -destination ./glob.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op HasRedirectGlobs
+//go:generate mockgen -package mock -destination ./configuration.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op Configuration
+//go:generate mockgen -package mock -destination ./discovery.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op DiscoverStorage
+//go:generate mockgen -package mock -destination ./signer.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op SigningKey,Key
+//go:generate mockgen -package mock -destination ./key.mock.go git.christmann.info/LARA/zitadel-oidc/v3/pkg/op KeyProvider
diff --git a/pkg/op/mock/glob.go b/pkg/op/mock/glob.go
new file mode 100644
index 0000000..8149c8f
--- /dev/null
+++ b/pkg/op/mock/glob.go
@@ -0,0 +1,24 @@
+package mock
+
+import (
+ "testing"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ gomock "github.com/golang/mock/gomock"
+)
+
+func NewHasRedirectGlobs(t *testing.T) op.HasRedirectGlobs {
+ return NewMockHasRedirectGlobs(gomock.NewController(t))
+}
+
+func NewHasRedirectGlobsWithConfig(t *testing.T, uri []string, appType op.ApplicationType, responseTypes []oidc.ResponseType, devMode bool) op.HasRedirectGlobs {
+ c := NewHasRedirectGlobs(t)
+ m := c.(*MockHasRedirectGlobs)
+ m.EXPECT().RedirectURIs().AnyTimes().Return(uri)
+ m.EXPECT().RedirectURIGlobs().AnyTimes().Return(uri)
+ m.EXPECT().ApplicationType().AnyTimes().Return(appType)
+ m.EXPECT().ResponseTypes().AnyTimes().Return(responseTypes)
+ m.EXPECT().DevMode().AnyTimes().Return(devMode)
+ return c
+}
diff --git a/pkg/op/mock/glob.mock.go b/pkg/op/mock/glob.mock.go
new file mode 100644
index 0000000..ebdc333
--- /dev/null
+++ b/pkg/op/mock/glob.mock.go
@@ -0,0 +1,289 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: HasRedirectGlobs)
+
+// Package mock is a generated GoMock package.
+package mock
+
+import (
+ reflect "reflect"
+ time "time"
+
+ oidc "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ gomock "github.com/golang/mock/gomock"
+)
+
+// MockHasRedirectGlobs is a mock of HasRedirectGlobs interface.
+type MockHasRedirectGlobs struct {
+ ctrl *gomock.Controller
+ recorder *MockHasRedirectGlobsMockRecorder
+}
+
+// MockHasRedirectGlobsMockRecorder is the mock recorder for MockHasRedirectGlobs.
+type MockHasRedirectGlobsMockRecorder struct {
+ mock *MockHasRedirectGlobs
+}
+
+// NewMockHasRedirectGlobs creates a new mock instance.
+func NewMockHasRedirectGlobs(ctrl *gomock.Controller) *MockHasRedirectGlobs {
+ mock := &MockHasRedirectGlobs{ctrl: ctrl}
+ mock.recorder = &MockHasRedirectGlobsMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockHasRedirectGlobs) EXPECT() *MockHasRedirectGlobsMockRecorder {
+ return m.recorder
+}
+
+// AccessTokenType mocks base method.
+func (m *MockHasRedirectGlobs) AccessTokenType() op.AccessTokenType {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "AccessTokenType")
+ ret0, _ := ret[0].(op.AccessTokenType)
+ return ret0
+}
+
+// AccessTokenType indicates an expected call of AccessTokenType.
+func (mr *MockHasRedirectGlobsMockRecorder) AccessTokenType() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockHasRedirectGlobs)(nil).AccessTokenType))
+}
+
+// ApplicationType mocks base method.
+func (m *MockHasRedirectGlobs) ApplicationType() op.ApplicationType {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ApplicationType")
+ ret0, _ := ret[0].(op.ApplicationType)
+ return ret0
+}
+
+// ApplicationType indicates an expected call of ApplicationType.
+func (mr *MockHasRedirectGlobsMockRecorder) ApplicationType() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ApplicationType))
+}
+
+// AuthMethod mocks base method.
+func (m *MockHasRedirectGlobs) AuthMethod() oidc.AuthMethod {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "AuthMethod")
+ ret0, _ := ret[0].(oidc.AuthMethod)
+ return ret0
+}
+
+// AuthMethod indicates an expected call of AuthMethod.
+func (mr *MockHasRedirectGlobsMockRecorder) AuthMethod() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethod", reflect.TypeOf((*MockHasRedirectGlobs)(nil).AuthMethod))
+}
+
+// ClockSkew mocks base method.
+func (m *MockHasRedirectGlobs) ClockSkew() time.Duration {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ClockSkew")
+ ret0, _ := ret[0].(time.Duration)
+ return ret0
+}
+
+// ClockSkew indicates an expected call of ClockSkew.
+func (mr *MockHasRedirectGlobsMockRecorder) ClockSkew() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClockSkew", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ClockSkew))
+}
+
+// DevMode mocks base method.
+func (m *MockHasRedirectGlobs) DevMode() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DevMode")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// DevMode indicates an expected call of DevMode.
+func (mr *MockHasRedirectGlobsMockRecorder) DevMode() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DevMode", reflect.TypeOf((*MockHasRedirectGlobs)(nil).DevMode))
+}
+
+// GetID mocks base method.
+func (m *MockHasRedirectGlobs) GetID() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetID")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// GetID indicates an expected call of GetID.
+func (mr *MockHasRedirectGlobsMockRecorder) GetID() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockHasRedirectGlobs)(nil).GetID))
+}
+
+// GrantTypes mocks base method.
+func (m *MockHasRedirectGlobs) GrantTypes() []oidc.GrantType {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GrantTypes")
+ ret0, _ := ret[0].([]oidc.GrantType)
+ return ret0
+}
+
+// GrantTypes indicates an expected call of GrantTypes.
+func (mr *MockHasRedirectGlobsMockRecorder) GrantTypes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).GrantTypes))
+}
+
+// IDTokenLifetime mocks base method.
+func (m *MockHasRedirectGlobs) IDTokenLifetime() time.Duration {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IDTokenLifetime")
+ ret0, _ := ret[0].(time.Duration)
+ return ret0
+}
+
+// IDTokenLifetime indicates an expected call of IDTokenLifetime.
+func (mr *MockHasRedirectGlobsMockRecorder) IDTokenLifetime() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IDTokenLifetime))
+}
+
+// IDTokenUserinfoClaimsAssertion mocks base method.
+func (m *MockHasRedirectGlobs) IDTokenUserinfoClaimsAssertion() bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IDTokenUserinfoClaimsAssertion")
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// IDTokenUserinfoClaimsAssertion indicates an expected call of IDTokenUserinfoClaimsAssertion.
+func (mr *MockHasRedirectGlobsMockRecorder) IDTokenUserinfoClaimsAssertion() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenUserinfoClaimsAssertion", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IDTokenUserinfoClaimsAssertion))
+}
+
+// IsScopeAllowed mocks base method.
+func (m *MockHasRedirectGlobs) IsScopeAllowed(arg0 string) bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "IsScopeAllowed", arg0)
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// IsScopeAllowed indicates an expected call of IsScopeAllowed.
+func (mr *MockHasRedirectGlobsMockRecorder) IsScopeAllowed(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsScopeAllowed", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IsScopeAllowed), arg0)
+}
+
+// LoginURL mocks base method.
+func (m *MockHasRedirectGlobs) LoginURL(arg0 string) string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LoginURL", arg0)
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// LoginURL indicates an expected call of LoginURL.
+func (mr *MockHasRedirectGlobsMockRecorder) LoginURL(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockHasRedirectGlobs)(nil).LoginURL), arg0)
+}
+
+// PostLogoutRedirectURIGlobs mocks base method.
+func (m *MockHasRedirectGlobs) PostLogoutRedirectURIGlobs() []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "PostLogoutRedirectURIGlobs")
+ ret0, _ := ret[0].([]string)
+ return ret0
+}
+
+// PostLogoutRedirectURIGlobs indicates an expected call of PostLogoutRedirectURIGlobs.
+func (mr *MockHasRedirectGlobsMockRecorder) PostLogoutRedirectURIGlobs() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIGlobs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).PostLogoutRedirectURIGlobs))
+}
+
+// PostLogoutRedirectURIs mocks base method.
+func (m *MockHasRedirectGlobs) PostLogoutRedirectURIs() []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "PostLogoutRedirectURIs")
+ ret0, _ := ret[0].([]string)
+ return ret0
+}
+
+// PostLogoutRedirectURIs indicates an expected call of PostLogoutRedirectURIs.
+func (mr *MockHasRedirectGlobsMockRecorder) PostLogoutRedirectURIs() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).PostLogoutRedirectURIs))
+}
+
+// RedirectURIGlobs mocks base method.
+func (m *MockHasRedirectGlobs) RedirectURIGlobs() []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RedirectURIGlobs")
+ ret0, _ := ret[0].([]string)
+ return ret0
+}
+
+// RedirectURIGlobs indicates an expected call of RedirectURIGlobs.
+func (mr *MockHasRedirectGlobsMockRecorder) RedirectURIGlobs() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIGlobs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RedirectURIGlobs))
+}
+
+// RedirectURIs mocks base method.
+func (m *MockHasRedirectGlobs) RedirectURIs() []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RedirectURIs")
+ ret0, _ := ret[0].([]string)
+ return ret0
+}
+
+// RedirectURIs indicates an expected call of RedirectURIs.
+func (mr *MockHasRedirectGlobsMockRecorder) RedirectURIs() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RedirectURIs))
+}
+
+// ResponseTypes mocks base method.
+func (m *MockHasRedirectGlobs) ResponseTypes() []oidc.ResponseType {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ResponseTypes")
+ ret0, _ := ret[0].([]oidc.ResponseType)
+ return ret0
+}
+
+// ResponseTypes indicates an expected call of ResponseTypes.
+func (mr *MockHasRedirectGlobsMockRecorder) ResponseTypes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseTypes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ResponseTypes))
+}
+
+// RestrictAdditionalAccessTokenScopes mocks base method.
+func (m *MockHasRedirectGlobs) RestrictAdditionalAccessTokenScopes() func([]string) []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RestrictAdditionalAccessTokenScopes")
+ ret0, _ := ret[0].(func([]string) []string)
+ return ret0
+}
+
+// RestrictAdditionalAccessTokenScopes indicates an expected call of RestrictAdditionalAccessTokenScopes.
+func (mr *MockHasRedirectGlobsMockRecorder) RestrictAdditionalAccessTokenScopes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalAccessTokenScopes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RestrictAdditionalAccessTokenScopes))
+}
+
+// RestrictAdditionalIdTokenScopes mocks base method.
+func (m *MockHasRedirectGlobs) RestrictAdditionalIdTokenScopes() func([]string) []string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RestrictAdditionalIdTokenScopes")
+ ret0, _ := ret[0].(func([]string) []string)
+ return ret0
+}
+
+// RestrictAdditionalIdTokenScopes indicates an expected call of RestrictAdditionalIdTokenScopes.
+func (mr *MockHasRedirectGlobsMockRecorder) RestrictAdditionalIdTokenScopes() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalIdTokenScopes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RestrictAdditionalIdTokenScopes))
+}
diff --git a/pkg/op/mock/key.mock.go b/pkg/op/mock/key.mock.go
new file mode 100644
index 0000000..d9ee857
--- /dev/null
+++ b/pkg/op/mock/key.mock.go
@@ -0,0 +1,51 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: KeyProvider)
+
+// Package mock is a generated GoMock package.
+package mock
+
+import (
+ context "context"
+ reflect "reflect"
+
+ op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ gomock "github.com/golang/mock/gomock"
+)
+
+// MockKeyProvider is a mock of KeyProvider interface.
+type MockKeyProvider struct {
+ ctrl *gomock.Controller
+ recorder *MockKeyProviderMockRecorder
+}
+
+// MockKeyProviderMockRecorder is the mock recorder for MockKeyProvider.
+type MockKeyProviderMockRecorder struct {
+ mock *MockKeyProvider
+}
+
+// NewMockKeyProvider creates a new mock instance.
+func NewMockKeyProvider(ctrl *gomock.Controller) *MockKeyProvider {
+ mock := &MockKeyProvider{ctrl: ctrl}
+ mock.recorder = &MockKeyProviderMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockKeyProvider) EXPECT() *MockKeyProviderMockRecorder {
+ return m.recorder
+}
+
+// KeySet mocks base method.
+func (m *MockKeyProvider) KeySet(arg0 context.Context) ([]op.Key, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "KeySet", arg0)
+ ret0, _ := ret[0].([]op.Key)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// KeySet indicates an expected call of KeySet.
+func (mr *MockKeyProviderMockRecorder) KeySet(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeySet", reflect.TypeOf((*MockKeyProvider)(nil).KeySet), arg0)
+}
diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go
index a7d909c..751ce60 100644
--- a/pkg/op/mock/signer.mock.go
+++ b/pkg/op/mock/signer.mock.go
@@ -1,94 +1,156 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/caos/oidc/pkg/op (interfaces: Signer)
+// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: SigningKey,Key)
// Package mock is a generated GoMock package.
package mock
import (
- context "context"
- oidc "github.com/caos/oidc/pkg/oidc"
- gomock "github.com/golang/mock/gomock"
- jose "gopkg.in/square/go-jose.v2"
reflect "reflect"
+
+ jose "github.com/go-jose/go-jose/v4"
+ gomock "github.com/golang/mock/gomock"
)
-// MockSigner is a mock of Signer interface
-type MockSigner struct {
+// MockSigningKey is a mock of SigningKey interface.
+type MockSigningKey struct {
ctrl *gomock.Controller
- recorder *MockSignerMockRecorder
+ recorder *MockSigningKeyMockRecorder
}
-// MockSignerMockRecorder is the mock recorder for MockSigner
-type MockSignerMockRecorder struct {
- mock *MockSigner
+// MockSigningKeyMockRecorder is the mock recorder for MockSigningKey.
+type MockSigningKeyMockRecorder struct {
+ mock *MockSigningKey
}
-// NewMockSigner creates a new mock instance
-func NewMockSigner(ctrl *gomock.Controller) *MockSigner {
- mock := &MockSigner{ctrl: ctrl}
- mock.recorder = &MockSignerMockRecorder{mock}
+// NewMockSigningKey creates a new mock instance.
+func NewMockSigningKey(ctrl *gomock.Controller) *MockSigningKey {
+ mock := &MockSigningKey{ctrl: ctrl}
+ mock.recorder = &MockSigningKeyMockRecorder{mock}
return mock
}
-// EXPECT returns an object that allows the caller to indicate expected use
-func (m *MockSigner) EXPECT() *MockSignerMockRecorder {
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockSigningKey) EXPECT() *MockSigningKeyMockRecorder {
return m.recorder
}
-// Health mocks base method
-func (m *MockSigner) Health(arg0 context.Context) error {
+// ID mocks base method.
+func (m *MockSigningKey) ID() string {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Health", arg0)
- ret0, _ := ret[0].(error)
+ ret := m.ctrl.Call(m, "ID")
+ ret0, _ := ret[0].(string)
return ret0
}
-// Health indicates an expected call of Health
-func (mr *MockSignerMockRecorder) Health(arg0 interface{}) *gomock.Call {
+// ID indicates an expected call of ID.
+func (mr *MockSigningKeyMockRecorder) ID() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockSigner)(nil).Health), arg0)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockSigningKey)(nil).ID))
}
-// SignAccessToken mocks base method
-func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
+// Key mocks base method.
+func (m *MockSigningKey) Key() interface{} {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "SignAccessToken", arg0)
- ret0, _ := ret[0].(string)
- ret1, _ := ret[1].(error)
- return ret0, ret1
+ ret := m.ctrl.Call(m, "Key")
+ ret0, _ := ret[0].(interface{})
+ return ret0
}
-// SignAccessToken indicates an expected call of SignAccessToken
-func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
+// Key indicates an expected call of Key.
+func (mr *MockSigningKeyMockRecorder) Key() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockSigningKey)(nil).Key))
}
-// SignIDToken mocks base method
-func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
- m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "SignIDToken", arg0)
- ret0, _ := ret[0].(string)
- ret1, _ := ret[1].(error)
- return ret0, ret1
-}
-
-// SignIDToken indicates an expected call of SignIDToken
-func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0)
-}
-
-// SignatureAlgorithm mocks base method
-func (m *MockSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
+// SignatureAlgorithm mocks base method.
+func (m *MockSigningKey) SignatureAlgorithm() jose.SignatureAlgorithm {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SignatureAlgorithm")
ret0, _ := ret[0].(jose.SignatureAlgorithm)
return ret0
}
-// SignatureAlgorithm indicates an expected call of SignatureAlgorithm
-func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
+// SignatureAlgorithm indicates an expected call of SignatureAlgorithm.
+func (mr *MockSigningKeyMockRecorder) SignatureAlgorithm() *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigningKey)(nil).SignatureAlgorithm))
+}
+
+// MockKey is a mock of Key interface.
+type MockKey struct {
+ ctrl *gomock.Controller
+ recorder *MockKeyMockRecorder
+}
+
+// MockKeyMockRecorder is the mock recorder for MockKey.
+type MockKeyMockRecorder struct {
+ mock *MockKey
+}
+
+// NewMockKey creates a new mock instance.
+func NewMockKey(ctrl *gomock.Controller) *MockKey {
+ mock := &MockKey{ctrl: ctrl}
+ mock.recorder = &MockKeyMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockKey) EXPECT() *MockKeyMockRecorder {
+ return m.recorder
+}
+
+// Algorithm mocks base method.
+func (m *MockKey) Algorithm() jose.SignatureAlgorithm {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Algorithm")
+ ret0, _ := ret[0].(jose.SignatureAlgorithm)
+ return ret0
+}
+
+// Algorithm indicates an expected call of Algorithm.
+func (mr *MockKeyMockRecorder) Algorithm() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockKey)(nil).Algorithm))
+}
+
+// ID mocks base method.
+func (m *MockKey) ID() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ID")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// ID indicates an expected call of ID.
+func (mr *MockKeyMockRecorder) ID() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockKey)(nil).ID))
+}
+
+// Key mocks base method.
+func (m *MockKey) Key() interface{} {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Key")
+ ret0, _ := ret[0].(interface{})
+ return ret0
+}
+
+// Key indicates an expected call of Key.
+func (mr *MockKeyMockRecorder) Key() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockKey)(nil).Key))
+}
+
+// Use mocks base method.
+func (m *MockKey) Use() string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Use")
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// Use indicates an expected call of Use.
+func (mr *MockKeyMockRecorder) Use() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Use", reflect.TypeOf((*MockKey)(nil).Use))
}
diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go
index ac8ba27..0df9830 100644
--- a/pkg/op/mock/storage.mock.go
+++ b/pkg/op/mock/storage.mock.go
@@ -1,43 +1,44 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/caos/oidc/pkg/op (interfaces: Storage)
+// Source: git.christmann.info/LARA/zitadel-oidc/v3/pkg/op (interfaces: Storage)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
- oidc "github.com/caos/oidc/pkg/oidc"
- op "github.com/caos/oidc/pkg/op"
- gomock "github.com/golang/mock/gomock"
- jose "gopkg.in/square/go-jose.v2"
reflect "reflect"
time "time"
+
+ oidc "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ op "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ jose "github.com/go-jose/go-jose/v4"
+ gomock "github.com/golang/mock/gomock"
)
-// MockStorage is a mock of Storage interface
+// MockStorage is a mock of Storage interface.
type MockStorage struct {
ctrl *gomock.Controller
recorder *MockStorageMockRecorder
}
-// MockStorageMockRecorder is the mock recorder for MockStorage
+// MockStorageMockRecorder is the mock recorder for MockStorage.
type MockStorageMockRecorder struct {
mock *MockStorage
}
-// NewMockStorage creates a new mock instance
+// NewMockStorage creates a new mock instance.
func NewMockStorage(ctrl *gomock.Controller) *MockStorage {
mock := &MockStorage{ctrl: ctrl}
mock.recorder = &MockStorageMockRecorder{mock}
return mock
}
-// EXPECT returns an object that allows the caller to indicate expected use
+// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
return m.recorder
}
-// AuthRequestByCode mocks base method
+// AuthRequestByCode mocks base method.
func (m *MockStorage) AuthRequestByCode(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthRequestByCode", arg0, arg1)
@@ -46,13 +47,13 @@ func (m *MockStorage) AuthRequestByCode(arg0 context.Context, arg1 string) (op.A
return ret0, ret1
}
-// AuthRequestByCode indicates an expected call of AuthRequestByCode
+// AuthRequestByCode indicates an expected call of AuthRequestByCode.
func (mr *MockStorageMockRecorder) AuthRequestByCode(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByCode", reflect.TypeOf((*MockStorage)(nil).AuthRequestByCode), arg0, arg1)
}
-// AuthRequestByID mocks base method
+// AuthRequestByID mocks base method.
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1)
@@ -61,13 +62,13 @@ func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.Aut
return ret0, ret1
}
-// AuthRequestByID indicates an expected call of AuthRequestByID
+// AuthRequestByID indicates an expected call of AuthRequestByID.
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1)
}
-// AuthorizeClientIDSecret mocks base method
+// AuthorizeClientIDSecret mocks base method.
func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2)
@@ -75,13 +76,46 @@ func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 s
return ret0
}
-// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
+// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret.
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2)
}
-// CreateAuthRequest mocks base method
+// CreateAccessAndRefreshTokens mocks base method.
+func (m *MockStorage) CreateAccessAndRefreshTokens(arg0 context.Context, arg1 op.TokenRequest, arg2 string) (string, string, time.Time, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "CreateAccessAndRefreshTokens", arg0, arg1, arg2)
+ ret0, _ := ret[0].(string)
+ ret1, _ := ret[1].(string)
+ ret2, _ := ret[2].(time.Time)
+ ret3, _ := ret[3].(error)
+ return ret0, ret1, ret2, ret3
+}
+
+// CreateAccessAndRefreshTokens indicates an expected call of CreateAccessAndRefreshTokens.
+func (mr *MockStorageMockRecorder) CreateAccessAndRefreshTokens(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAccessAndRefreshTokens", reflect.TypeOf((*MockStorage)(nil).CreateAccessAndRefreshTokens), arg0, arg1, arg2)
+}
+
+// CreateAccessToken mocks base method.
+func (m *MockStorage) CreateAccessToken(arg0 context.Context, arg1 op.TokenRequest) (string, time.Time, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "CreateAccessToken", arg0, arg1)
+ ret0, _ := ret[0].(string)
+ ret1, _ := ret[1].(time.Time)
+ ret2, _ := ret[2].(error)
+ return ret0, ret1, ret2
+}
+
+// CreateAccessToken indicates an expected call of CreateAccessToken.
+func (mr *MockStorageMockRecorder) CreateAccessToken(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAccessToken", reflect.TypeOf((*MockStorage)(nil).CreateAccessToken), arg0, arg1)
+}
+
+// CreateAuthRequest mocks base method.
func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest, arg2 string) (op.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1, arg2)
@@ -90,29 +124,13 @@ func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthReq
return ret0, ret1
}
-// CreateAuthRequest indicates an expected call of CreateAuthRequest
+// CreateAuthRequest indicates an expected call of CreateAuthRequest.
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1, arg2)
}
-// CreateToken mocks base method
-func (m *MockStorage) CreateToken(arg0 context.Context, arg1 op.AuthRequest) (string, time.Time, error) {
- m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "CreateToken", arg0, arg1)
- ret0, _ := ret[0].(string)
- ret1, _ := ret[1].(time.Time)
- ret2, _ := ret[2].(error)
- return ret0, ret1, ret2
-}
-
-// CreateToken indicates an expected call of CreateToken
-func (mr *MockStorageMockRecorder) CreateToken(arg0, arg1 interface{}) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateToken", reflect.TypeOf((*MockStorage)(nil).CreateToken), arg0, arg1)
-}
-
-// DeleteAuthRequest mocks base method
+// DeleteAuthRequest mocks base method.
func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
@@ -120,13 +138,13 @@ func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error
return ret0
}
-// DeleteAuthRequest indicates an expected call of DeleteAuthRequest
+// DeleteAuthRequest indicates an expected call of DeleteAuthRequest.
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1)
}
-// GetClientByClientID mocks base method
+// GetClientByClientID mocks base method.
func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1)
@@ -135,70 +153,59 @@ func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op
return ret0, ret1
}
-// GetClientByClientID indicates an expected call of GetClientByClientID
+// GetClientByClientID indicates an expected call of GetClientByClientID.
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
}
-// GetKeySet mocks base method
-func (m *MockStorage) GetKeySet(arg0 context.Context) (*jose.JSONWebKeySet, error) {
+// GetKeyByIDAndClientID mocks base method.
+func (m *MockStorage) GetKeyByIDAndClientID(arg0 context.Context, arg1, arg2 string) (*jose.JSONWebKey, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "GetKeySet", arg0)
- ret0, _ := ret[0].(*jose.JSONWebKeySet)
+ ret := m.ctrl.Call(m, "GetKeyByIDAndClientID", arg0, arg1, arg2)
+ ret0, _ := ret[0].(*jose.JSONWebKey)
ret1, _ := ret[1].(error)
return ret0, ret1
}
-// GetKeySet indicates an expected call of GetKeySet
-func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
+// GetKeyByIDAndClientID indicates an expected call of GetKeyByIDAndClientID.
+func (mr *MockStorageMockRecorder) GetKeyByIDAndClientID(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeyByIDAndClientID", reflect.TypeOf((*MockStorage)(nil).GetKeyByIDAndClientID), arg0, arg1, arg2)
}
-// GetSigningKey mocks base method
-func (m *MockStorage) GetSigningKey(arg0 context.Context, arg1 chan<- jose.SigningKey, arg2 chan<- error, arg3 <-chan time.Time) {
+// GetPrivateClaimsFromScopes mocks base method.
+func (m *MockStorage) GetPrivateClaimsFromScopes(arg0 context.Context, arg1, arg2 string, arg3 []string) (map[string]interface{}, error) {
m.ctrl.T.Helper()
- m.ctrl.Call(m, "GetSigningKey", arg0, arg1, arg2, arg3)
-}
-
-// GetSigningKey indicates an expected call of GetSigningKey
-func (mr *MockStorageMockRecorder) GetSigningKey(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0, arg1, arg2, arg3)
-}
-
-// GetUserinfoFromScopes mocks base method
-func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 string, arg2 []string) (*oidc.Userinfo, error) {
- m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1, arg2)
- ret0, _ := ret[0].(*oidc.Userinfo)
+ ret := m.ctrl.Call(m, "GetPrivateClaimsFromScopes", arg0, arg1, arg2, arg3)
+ ret0, _ := ret[0].(map[string]interface{})
ret1, _ := ret[1].(error)
return ret0, ret1
}
-// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
-func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1, arg2 interface{}) *gomock.Call {
+// GetPrivateClaimsFromScopes indicates an expected call of GetPrivateClaimsFromScopes.
+func (mr *MockStorageMockRecorder) GetPrivateClaimsFromScopes(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1, arg2)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateClaimsFromScopes", reflect.TypeOf((*MockStorage)(nil).GetPrivateClaimsFromScopes), arg0, arg1, arg2, arg3)
}
-// GetUserinfoFromToken mocks base method
-func (m *MockStorage) GetUserinfoFromToken(arg0 context.Context, arg1 string) (*oidc.Userinfo, error) {
+// GetRefreshTokenInfo mocks base method.
+func (m *MockStorage) GetRefreshTokenInfo(arg0 context.Context, arg1, arg2 string) (string, string, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "GetUserinfoFromToken", arg0, arg1)
- ret0, _ := ret[0].(*oidc.Userinfo)
- ret1, _ := ret[1].(error)
- return ret0, ret1
+ ret := m.ctrl.Call(m, "GetRefreshTokenInfo", arg0, arg1, arg2)
+ ret0, _ := ret[0].(string)
+ ret1, _ := ret[1].(string)
+ ret2, _ := ret[2].(error)
+ return ret0, ret1, ret2
}
-// GetUserinfoFromToken indicates an expected call of GetUserinfoFromToken
-func (mr *MockStorageMockRecorder) GetUserinfoFromToken(arg0, arg1 interface{}) *gomock.Call {
+// GetRefreshTokenInfo indicates an expected call of GetRefreshTokenInfo.
+func (mr *MockStorageMockRecorder) GetRefreshTokenInfo(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromToken), arg0, arg1)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRefreshTokenInfo", reflect.TypeOf((*MockStorage)(nil).GetRefreshTokenInfo), arg0, arg1, arg2)
}
-// Health mocks base method
+// Health mocks base method.
func (m *MockStorage) Health(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Health", arg0)
@@ -206,13 +213,42 @@ func (m *MockStorage) Health(arg0 context.Context) error {
return ret0
}
-// Health indicates an expected call of Health
+// Health indicates an expected call of Health.
func (mr *MockStorageMockRecorder) Health(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockStorage)(nil).Health), arg0)
}
-// SaveAuthCode mocks base method
+// KeySet mocks base method.
+func (m *MockStorage) KeySet(arg0 context.Context) ([]op.Key, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "KeySet", arg0)
+ ret0, _ := ret[0].([]op.Key)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// KeySet indicates an expected call of KeySet.
+func (mr *MockStorageMockRecorder) KeySet(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeySet", reflect.TypeOf((*MockStorage)(nil).KeySet), arg0)
+}
+
+// RevokeToken mocks base method.
+func (m *MockStorage) RevokeToken(arg0 context.Context, arg1, arg2, arg3 string) *oidc.Error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2, arg3)
+ ret0, _ := ret[0].(*oidc.Error)
+ return ret0
+}
+
+// RevokeToken indicates an expected call of RevokeToken.
+func (mr *MockStorageMockRecorder) RevokeToken(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeToken", reflect.TypeOf((*MockStorage)(nil).RevokeToken), arg0, arg1, arg2, arg3)
+}
+
+// SaveAuthCode mocks base method.
func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveAuthCode", arg0, arg1, arg2)
@@ -220,27 +256,85 @@ func (m *MockStorage) SaveAuthCode(arg0 context.Context, arg1, arg2 string) erro
return ret0
}
-// SaveAuthCode indicates an expected call of SaveAuthCode
+// SaveAuthCode indicates an expected call of SaveAuthCode.
func (mr *MockStorageMockRecorder) SaveAuthCode(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthCode", reflect.TypeOf((*MockStorage)(nil).SaveAuthCode), arg0, arg1, arg2)
}
-// SaveNewKeyPair mocks base method
-func (m *MockStorage) SaveNewKeyPair(arg0 context.Context) error {
+// SetIntrospectionFromToken mocks base method.
+func (m *MockStorage) SetIntrospectionFromToken(arg0 context.Context, arg1 *oidc.IntrospectionResponse, arg2, arg3, arg4 string) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "SaveNewKeyPair", arg0)
+ ret := m.ctrl.Call(m, "SetIntrospectionFromToken", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error)
return ret0
}
-// SaveNewKeyPair indicates an expected call of SaveNewKeyPair
-func (mr *MockStorageMockRecorder) SaveNewKeyPair(arg0 interface{}) *gomock.Call {
+// SetIntrospectionFromToken indicates an expected call of SetIntrospectionFromToken.
+func (mr *MockStorageMockRecorder) SetIntrospectionFromToken(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveNewKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveNewKeyPair), arg0)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntrospectionFromToken", reflect.TypeOf((*MockStorage)(nil).SetIntrospectionFromToken), arg0, arg1, arg2, arg3, arg4)
}
-// TerminateSession mocks base method
+// SetUserinfoFromScopes mocks base method.
+func (m *MockStorage) SetUserinfoFromScopes(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3 string, arg4 []string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SetUserinfoFromScopes", arg0, arg1, arg2, arg3, arg4)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// SetUserinfoFromScopes indicates an expected call of SetUserinfoFromScopes.
+func (mr *MockStorageMockRecorder) SetUserinfoFromScopes(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).SetUserinfoFromScopes), arg0, arg1, arg2, arg3, arg4)
+}
+
+// SetUserinfoFromToken mocks base method.
+func (m *MockStorage) SetUserinfoFromToken(arg0 context.Context, arg1 *oidc.UserInfo, arg2, arg3, arg4 string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SetUserinfoFromToken", arg0, arg1, arg2, arg3, arg4)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// SetUserinfoFromToken indicates an expected call of SetUserinfoFromToken.
+func (mr *MockStorageMockRecorder) SetUserinfoFromToken(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserinfoFromToken", reflect.TypeOf((*MockStorage)(nil).SetUserinfoFromToken), arg0, arg1, arg2, arg3, arg4)
+}
+
+// SignatureAlgorithms mocks base method.
+func (m *MockStorage) SignatureAlgorithms(arg0 context.Context) ([]jose.SignatureAlgorithm, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SignatureAlgorithms", arg0)
+ ret0, _ := ret[0].([]jose.SignatureAlgorithm)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// SignatureAlgorithms indicates an expected call of SignatureAlgorithms.
+func (mr *MockStorageMockRecorder) SignatureAlgorithms(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithms", reflect.TypeOf((*MockStorage)(nil).SignatureAlgorithms), arg0)
+}
+
+// SigningKey mocks base method.
+func (m *MockStorage) SigningKey(arg0 context.Context) (op.SigningKey, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SigningKey", arg0)
+ ret0, _ := ret[0].(op.SigningKey)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// SigningKey indicates an expected call of SigningKey.
+func (mr *MockStorageMockRecorder) SigningKey(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SigningKey", reflect.TypeOf((*MockStorage)(nil).SigningKey), arg0)
+}
+
+// TerminateSession mocks base method.
func (m *MockStorage) TerminateSession(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TerminateSession", arg0, arg1, arg2)
@@ -248,8 +342,38 @@ func (m *MockStorage) TerminateSession(arg0 context.Context, arg1, arg2 string)
return ret0
}
-// TerminateSession indicates an expected call of TerminateSession
+// TerminateSession indicates an expected call of TerminateSession.
func (mr *MockStorageMockRecorder) TerminateSession(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TerminateSession", reflect.TypeOf((*MockStorage)(nil).TerminateSession), arg0, arg1, arg2)
}
+
+// TokenRequestByRefreshToken mocks base method.
+func (m *MockStorage) TokenRequestByRefreshToken(arg0 context.Context, arg1 string) (op.RefreshTokenRequest, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "TokenRequestByRefreshToken", arg0, arg1)
+ ret0, _ := ret[0].(op.RefreshTokenRequest)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// TokenRequestByRefreshToken indicates an expected call of TokenRequestByRefreshToken.
+func (mr *MockStorageMockRecorder) TokenRequestByRefreshToken(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenRequestByRefreshToken", reflect.TypeOf((*MockStorage)(nil).TokenRequestByRefreshToken), arg0, arg1)
+}
+
+// ValidateJWTProfileScopes mocks base method.
+func (m *MockStorage) ValidateJWTProfileScopes(arg0 context.Context, arg1 string, arg2 []string) ([]string, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ValidateJWTProfileScopes", arg0, arg1, arg2)
+ ret0, _ := ret[0].([]string)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// ValidateJWTProfileScopes indicates an expected call of ValidateJWTProfileScopes.
+func (mr *MockStorageMockRecorder) ValidateJWTProfileScopes(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateJWTProfileScopes", reflect.TypeOf((*MockStorage)(nil).ValidateJWTProfileScopes), arg0, arg1, arg2)
+}
diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go
index c9c63c6..96e08a9 100644
--- a/pkg/op/mock/storage.mock.impl.go
+++ b/pkg/op/mock/storage.mock.impl.go
@@ -6,11 +6,10 @@ import (
"testing"
"time"
- "gopkg.in/square/go-jose.v2"
-
"github.com/golang/mock/gomock"
- "github.com/caos/oidc/pkg/op"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
)
func NewStorage(t *testing.T) op.Storage {
@@ -37,20 +36,15 @@ func NewMockStorageAny(t *testing.T) op.Storage {
return m
}
-func NewMockStorageSigningKeyError(t *testing.T) op.Storage {
+func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage {
m := NewStorage(t)
- ExpectSigningKeyError(m)
+ //ExpectSigningKeyInvalid(m)
return m
}
-func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage {
- m := NewStorage(t)
- ExpectSigningKeyInvalid(m)
- return m
-}
func NewMockStorageSigningKey(t *testing.T) op.Storage {
m := NewStorage(t)
- ExpectSigningKey(m)
+ //ExpectSigningKey(m)
return m
}
@@ -64,58 +58,38 @@ func ExpectValidClientID(s op.Storage) {
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, id string) (op.Client, error) {
var appType op.ApplicationType
- var authMethod op.AuthMethod
+ var authMethod oidc.AuthMethod
var accessTokenType op.AccessTokenType
+ var responseTypes []oidc.ResponseType
switch id {
case "web_client":
appType = op.ApplicationTypeWeb
- authMethod = op.AuthMethodBasic
+ authMethod = oidc.AuthMethodBasic
accessTokenType = op.AccessTokenTypeBearer
+ responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode}
case "native_client":
appType = op.ApplicationTypeNative
- authMethod = op.AuthMethodNone
+ authMethod = oidc.AuthMethodNone
accessTokenType = op.AccessTokenTypeBearer
+ responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode}
case "useragent_client":
appType = op.ApplicationTypeUserAgent
- authMethod = op.AuthMethodBasic
+ authMethod = oidc.AuthMethodBasic
accessTokenType = op.AccessTokenTypeJWT
+ responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken}
}
- return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
+ return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType, responseTypes: responseTypes}, nil
})
}
-func ExpectSigningKeyError(s op.Storage) {
- mockS := s.(*MockStorage)
- mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(_ context.Context, keyCh chan<- jose.SigningKey, errCh chan<- error, _ <-chan bool) {
- errCh <- errors.New("error")
- },
- )
-}
-
-func ExpectSigningKeyInvalid(s op.Storage) {
- mockS := s.(*MockStorage)
- mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(_ context.Context, keyCh chan<- jose.SigningKey, errCh chan<- error, _ <-chan bool) {
- keyCh <- jose.SigningKey{}
- },
- )
-}
-
-func ExpectSigningKey(s op.Storage) {
- mockS := s.(*MockStorage)
- mockS.EXPECT().GetSigningKey(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
- func(_ context.Context, keyCh chan<- jose.SigningKey, errCh chan<- error, _ <-chan bool) {
- keyCh <- jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}
- },
- )
-}
-
type ConfClient struct {
id string
appType op.ApplicationType
- authMethod op.AuthMethod
+ authMethod oidc.AuthMethod
accessTokenType op.AccessTokenType
+ responseTypes []oidc.ResponseType
+ grantTypes []oidc.GrantType
+ devMode bool
}
func (c *ConfClient) RedirectURIs() []string {
@@ -126,6 +100,7 @@ func (c *ConfClient) RedirectURIs() []string {
"custom://callback",
}
}
+
func (c *ConfClient) PostLogoutRedirectURIs() []string {
return []string{}
}
@@ -138,7 +113,7 @@ func (c *ConfClient) ApplicationType() op.ApplicationType {
return c.appType
}
-func (c *ConfClient) GetAuthMethod() op.AuthMethod {
+func (c *ConfClient) AuthMethod() oidc.AuthMethod {
return c.authMethod
}
@@ -147,11 +122,53 @@ func (c *ConfClient) GetID() string {
}
func (c *ConfClient) AccessTokenLifetime() time.Duration {
- return time.Duration(5 * time.Minute)
+ return 5 * time.Minute
}
+
func (c *ConfClient) IDTokenLifetime() time.Duration {
- return time.Duration(5 * time.Minute)
+ return 5 * time.Minute
}
+
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
return c.accessTokenType
}
+
+func (c *ConfClient) ResponseTypes() []oidc.ResponseType {
+ return c.responseTypes
+}
+
+func (c *ConfClient) GrantTypes() []oidc.GrantType {
+ return c.grantTypes
+}
+
+func (c *ConfClient) DevMode() bool {
+ return c.devMode
+}
+
+func (c *ConfClient) AllowedScopes() []string {
+ return nil
+}
+
+func (c *ConfClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string {
+ return func(scopes []string) []string {
+ return scopes
+ }
+}
+
+func (c *ConfClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
+ return func(scopes []string) []string {
+ return scopes
+ }
+}
+
+func (c *ConfClient) IsScopeAllowed(scope string) bool {
+ return false
+}
+
+func (c *ConfClient) IDTokenUserinfoClaimsAssertion() bool {
+ return false
+}
+
+func (c *ConfClient) ClockSkew() time.Duration {
+ return 0
+}
diff --git a/pkg/op/op.go b/pkg/op/op.go
index 732a933..76c2c89 100644
--- a/pkg/op/op.go
+++ b/pkg/op/op.go
@@ -1,54 +1,672 @@
package op
import (
+ "context"
+ "fmt"
+ "log/slog"
"net/http"
+ "time"
- "github.com/gorilla/handlers"
- "github.com/gorilla/mux"
+ "github.com/go-chi/chi/v5"
+ jose "github.com/go-jose/go-jose/v4"
+ "github.com/rs/cors"
+ "github.com/zitadel/schema"
+ "go.opentelemetry.io/otel"
+ "golang.org/x/text/language"
- "github.com/caos/oidc/pkg/oidc"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
)
const (
- healthzEndpoint = "/healthz"
- readinessEndpoint = "/ready"
+ healthEndpoint = "/healthz"
+ readinessEndpoint = "/ready"
+ authCallbackPathSuffix = "/callback"
+ defaultAuthorizationEndpoint = "authorize"
+ defaultTokenEndpoint = "oauth/token"
+ defaultIntrospectEndpoint = "oauth/introspect"
+ defaultUserinfoEndpoint = "userinfo"
+ defaultRevocationEndpoint = "revoke"
+ defaultEndSessionEndpoint = "end_session"
+ defaultKeysEndpoint = "keys"
+ defaultDeviceAuthzEndpoint = "/device_authorization"
)
+var (
+ DefaultEndpoints = &Endpoints{
+ Authorization: NewEndpoint(defaultAuthorizationEndpoint),
+ Token: NewEndpoint(defaultTokenEndpoint),
+ Introspection: NewEndpoint(defaultIntrospectEndpoint),
+ Userinfo: NewEndpoint(defaultUserinfoEndpoint),
+ Revocation: NewEndpoint(defaultRevocationEndpoint),
+ EndSession: NewEndpoint(defaultEndSessionEndpoint),
+ JwksURI: NewEndpoint(defaultKeysEndpoint),
+ DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
+ }
+
+ DefaultSupportedClaims = []string{
+ "sub",
+ "aud",
+ "exp",
+ "iat",
+ "iss",
+ "auth_time",
+ "nonce",
+ "acr",
+ "amr",
+ "c_hash",
+ "at_hash",
+ "act",
+ "scopes",
+ "client_id",
+ "azp",
+ "preferred_username",
+ "name",
+ "family_name",
+ "given_name",
+ "locale",
+ "email",
+ "email_verified",
+ "phone_number",
+ "phone_number_verified",
+ }
+
+ defaultCORSOptions = cors.Options{
+ AllowCredentials: true,
+ AllowedHeaders: []string{
+ "Origin",
+ "Accept",
+ "Accept-Language",
+ "Authorization",
+ "Content-Type",
+ "X-Requested-With",
+ },
+ AllowedMethods: []string{
+ http.MethodGet,
+ http.MethodHead,
+ http.MethodPost,
+ },
+ ExposedHeaders: []string{
+ "Location",
+ "Content-Length",
+ },
+ AllowOriginFunc: func(_ string) bool {
+ return true
+ },
+ }
+)
+
+var tracer = otel.Tracer("github.com/zitadel/oidc/pkg/op")
+
type OpenIDProvider interface {
+ http.Handler
Configuration
- HandleReady(w http.ResponseWriter, r *http.Request)
- HandleDiscovery(w http.ResponseWriter, r *http.Request)
- HandleAuthorize(w http.ResponseWriter, r *http.Request)
- HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
- HandleExchange(w http.ResponseWriter, r *http.Request)
- HandleUserinfo(w http.ResponseWriter, r *http.Request)
- HandleEndSession(w http.ResponseWriter, r *http.Request)
- HandleKeys(w http.ResponseWriter, r *http.Request)
+ Storage() Storage
+ Decoder() httphelper.Decoder
+ Encoder() httphelper.Encoder
+ IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
+ AccessTokenVerifier(context.Context) *AccessTokenVerifier
+ Crypto() Crypto
+ DefaultLogoutRedirectURI() string
+ Probes() []ProbesFn
+ Logger() *slog.Logger
+
+ // Deprecated: Provider now implements http.Handler directly.
HttpHandler() http.Handler
}
-type HttpInterceptor func(http.HandlerFunc) http.HandlerFunc
+type HttpInterceptor func(http.Handler) http.Handler
-var DefaultInterceptor = func(h http.HandlerFunc) http.HandlerFunc {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- h(w, r)
- })
+type corsOptioner interface {
+ CORSOptions() *cors.Options
}
-func CreateRouter(o OpenIDProvider, h HttpInterceptor) *mux.Router {
- if h == nil {
- h = DefaultInterceptor
+func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
+ router := chi.NewRouter()
+ if co, ok := o.(corsOptioner); ok {
+ if opts := co.CORSOptions(); opts != nil {
+ router.Use(cors.New(*opts).Handler)
+ }
+ } else {
+ router.Use(cors.New(defaultCORSOptions).Handler)
}
- router := mux.NewRouter()
- router.Use(handlers.CORS())
- router.HandleFunc(healthzEndpoint, Healthz)
- router.HandleFunc(readinessEndpoint, o.HandleReady)
- router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery)
- router.HandleFunc(o.AuthorizationEndpoint().Relative(), h(o.HandleAuthorize))
- router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", h(o.HandleAuthorizeCallback))
- router.HandleFunc(o.TokenEndpoint().Relative(), h(o.HandleExchange))
- router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo)
- router.HandleFunc(o.EndSessionEndpoint().Relative(), h(o.HandleEndSession))
- router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
+ router.Use(intercept(o.IssuerFromRequest, interceptors...))
+ router.HandleFunc(healthEndpoint, healthHandler)
+ router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
+ router.HandleFunc(oidc.DiscoveryEndpoint, discoveryHandler(o, o.Storage()))
+ router.HandleFunc(o.AuthorizationEndpoint().Relative(), authorizeHandler(o))
+ router.HandleFunc(authCallbackPath(o), AuthorizeCallbackHandler(o))
+ router.HandleFunc(o.TokenEndpoint().Relative(), tokenHandler(o))
+ router.HandleFunc(o.IntrospectionEndpoint().Relative(), introspectionHandler(o))
+ router.HandleFunc(o.UserinfoEndpoint().Relative(), userinfoHandler(o))
+ router.HandleFunc(o.RevocationEndpoint().Relative(), revocationHandler(o))
+ router.HandleFunc(o.EndSessionEndpoint().Relative(), endSessionHandler(o))
+ router.HandleFunc(o.KeysEndpoint().Relative(), keysHandler(o.Storage()))
+ router.HandleFunc(o.DeviceAuthorizationEndpoint().Relative(), DeviceAuthorizationHandler(o))
return router
}
+
+// AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login
+func AuthCallbackURL(o OpenIDProvider) func(context.Context, string) string {
+ return func(ctx context.Context, requestID string) string {
+ return o.AuthorizationEndpoint().Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID
+ }
+}
+
+func authCallbackPath(o OpenIDProvider) string {
+ return o.AuthorizationEndpoint().Relative() + authCallbackPathSuffix
+}
+
+type Config struct {
+ CryptoKey [32]byte
+ DefaultLogoutRedirectURI string
+ CodeMethodS256 bool
+ AuthMethodPost bool
+ AuthMethodPrivateKeyJWT bool
+ GrantTypeRefreshToken bool
+ RequestObjectSupported bool
+ SupportedUILocales []language.Tag
+ SupportedClaims []string
+ SupportedScopes []string
+ DeviceAuthorization DeviceAuthorizationConfig
+ BackChannelLogoutSupported bool
+ BackChannelLogoutSessionSupported bool
+}
+
+// Endpoints defines endpoint routes.
+type Endpoints struct {
+ Authorization *Endpoint
+ Token *Endpoint
+ Introspection *Endpoint
+ Userinfo *Endpoint
+ Revocation *Endpoint
+ EndSession *Endpoint
+ CheckSessionIframe *Endpoint
+ JwksURI *Endpoint
+ DeviceAuthorization *Endpoint
+}
+
+// NewOpenIDProvider creates a provider. The provider provides (with HttpHandler())
+// a http.Router that handles a suite of endpoints (some paths can be overridden):
+//
+// /healthz
+// /ready
+// /.well-known/openid-configuration
+// /oauth/token
+// /oauth/introspect
+// /callback
+// /authorize
+// /userinfo
+// /revoke
+// /end_session
+// /keys
+// /device_authorization
+//
+// This does not include login. Login is handled with a redirect that includes the
+// request ID. The redirect for logins is specified per-client by Client.LoginURL().
+// Successful logins should mark the request as authorized and redirect back to to
+// op.AuthCallbackURL(provider) which is probably /callback. On the redirect back
+// to the AuthCallbackURL, the request id should be passed as the "id" parameter.
+//
+// Deprecated: use [NewProvider] with an issuer function direct.
+func NewOpenIDProvider(issuer string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
+ return NewProvider(config, storage, StaticIssuer(issuer), opOpts...)
+}
+
+// NewForwardedOpenIDProvider tries to establishes the issuer from the request Host.
+//
+// Deprecated: use [NewProvider] with an issuer function direct.
+func NewDynamicOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
+ return NewProvider(config, storage, IssuerFromHost(path), opOpts...)
+}
+
+// NewForwardedOpenIDProvider tries to establish the Issuer from a Forwarded request header, if it is set.
+// See [IssuerFromForwardedOrHost] for details.
+//
+// Deprecated: use [NewProvider] with an issuer function direct.
+func NewForwardedOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
+ return NewProvider(config, storage, IssuerFromForwardedOrHost(path), opOpts...)
+}
+
+// NewProvider creates a provider with a router on it's embedded http.Handler.
+// Issuer is a function that must return the issuer on every request.
+// Typically [StaticIssuer], [IssuerFromHost] or [IssuerFromForwardedOrHost] can be used.
+//
+// The router handles a suite of endpoints (some paths can be overridden):
+//
+// /healthz
+// /ready
+// /.well-known/openid-configuration
+// /oauth/token
+// /oauth/introspect
+// /callback
+// /authorize
+// /userinfo
+// /revoke
+// /end_session
+// /keys
+// /device_authorization
+//
+// This does not include login. Login is handled with a redirect that includes the
+// request ID. The redirect for logins is specified per-client by Client.LoginURL().
+// Successful logins should mark the request as authorized and redirect back to to
+// op.AuthCallbackURL(provider) which is probably /callback. On the redirect back
+// to the AuthCallbackURL, the request id should be passed as the "id" parameter.
+func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) {
+ keySet := &OpenIDKeySet{storage}
+ o := &Provider{
+ config: config,
+ storage: storage,
+ accessTokenKeySet: keySet,
+ idTokenHinKeySet: keySet,
+ endpoints: DefaultEndpoints,
+ timer: make(<-chan time.Time),
+ corsOpts: &defaultCORSOptions,
+ logger: slog.Default(),
+ }
+
+ for _, optFunc := range opOpts {
+ if err := optFunc(o); err != nil {
+ return nil, err
+ }
+ }
+
+ o.issuer, err = issuer(o.insecure)
+ if err != nil {
+ return nil, err
+ }
+ o.Handler = CreateRouter(o, o.interceptors...)
+ o.decoder = schema.NewDecoder()
+ o.decoder.IgnoreUnknownKeys(true)
+ o.encoder = oidc.NewEncoder()
+ o.crypto = NewAESCrypto(config.CryptoKey)
+ return o, nil
+}
+
+type Provider struct {
+ http.Handler
+ config *Config
+ issuer IssuerFromRequest
+ insecure bool
+ endpoints *Endpoints
+ storage Storage
+ accessTokenKeySet oidc.KeySet
+ idTokenHinKeySet oidc.KeySet
+ crypto Crypto
+ decoder *schema.Decoder
+ encoder *schema.Encoder
+ interceptors []HttpInterceptor
+ timer <-chan time.Time
+ accessTokenVerifierOpts []AccessTokenVerifierOpt
+ idTokenHintVerifierOpts []IDTokenHintVerifierOpt
+ corsOpts *cors.Options
+ logger *slog.Logger
+}
+
+func (o *Provider) IssuerFromRequest(r *http.Request) string {
+ return o.issuer(r)
+}
+
+func (o *Provider) Insecure() bool {
+ return o.insecure
+}
+
+func (o *Provider) AuthorizationEndpoint() *Endpoint {
+ return o.endpoints.Authorization
+}
+
+func (o *Provider) TokenEndpoint() *Endpoint {
+ return o.endpoints.Token
+}
+
+func (o *Provider) IntrospectionEndpoint() *Endpoint {
+ return o.endpoints.Introspection
+}
+
+func (o *Provider) UserinfoEndpoint() *Endpoint {
+ return o.endpoints.Userinfo
+}
+
+func (o *Provider) RevocationEndpoint() *Endpoint {
+ return o.endpoints.Revocation
+}
+
+func (o *Provider) EndSessionEndpoint() *Endpoint {
+ return o.endpoints.EndSession
+}
+
+func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
+ return o.endpoints.DeviceAuthorization
+}
+
+func (o *Provider) CheckSessionIframe() *Endpoint {
+ return o.endpoints.CheckSessionIframe
+}
+
+func (o *Provider) KeysEndpoint() *Endpoint {
+ return o.endpoints.JwksURI
+}
+
+func (o *Provider) AuthMethodPostSupported() bool {
+ return o.config.AuthMethodPost
+}
+
+func (o *Provider) CodeMethodS256Supported() bool {
+ return o.config.CodeMethodS256
+}
+
+func (o *Provider) AuthMethodPrivateKeyJWTSupported() bool {
+ return o.config.AuthMethodPrivateKeyJWT
+}
+
+func (o *Provider) TokenEndpointSigningAlgorithmsSupported() []string {
+ return []string{"RS256"}
+}
+
+func (o *Provider) GrantTypeRefreshTokenSupported() bool {
+ return o.config.GrantTypeRefreshToken
+}
+
+func (o *Provider) GrantTypeTokenExchangeSupported() bool {
+ _, ok := o.storage.(TokenExchangeStorage)
+ return ok
+}
+
+func (o *Provider) GrantTypeJWTAuthorizationSupported() bool {
+ return true
+}
+
+func (o *Provider) GrantTypeDeviceCodeSupported() bool {
+ _, ok := o.storage.(DeviceAuthorizationStorage)
+ return ok
+}
+
+func (o *Provider) IntrospectionAuthMethodPrivateKeyJWTSupported() bool {
+ return true
+}
+
+func (o *Provider) IntrospectionEndpointSigningAlgorithmsSupported() []string {
+ return []string{"RS256"}
+}
+
+func (o *Provider) GrantTypeClientCredentialsSupported() bool {
+ _, ok := o.storage.(ClientCredentialsStorage)
+ return ok
+}
+
+func (o *Provider) RevocationAuthMethodPrivateKeyJWTSupported() bool {
+ return true
+}
+
+func (o *Provider) RevocationEndpointSigningAlgorithmsSupported() []string {
+ return []string{"RS256"}
+}
+
+func (o *Provider) RequestObjectSupported() bool {
+ return o.config.RequestObjectSupported
+}
+
+func (o *Provider) RequestObjectSigningAlgorithmsSupported() []string {
+ return []string{"RS256"}
+}
+
+func (o *Provider) SupportedUILocales() []language.Tag {
+ return o.config.SupportedUILocales
+}
+
+func (o *Provider) DeviceAuthorization() DeviceAuthorizationConfig {
+ return o.config.DeviceAuthorization
+}
+
+func (o *Provider) BackChannelLogoutSupported() bool {
+ return o.config.BackChannelLogoutSupported
+}
+
+func (o *Provider) BackChannelLogoutSessionSupported() bool {
+ return o.config.BackChannelLogoutSessionSupported
+}
+
+func (o *Provider) Storage() Storage {
+ return o.storage
+}
+
+func (o *Provider) Decoder() httphelper.Decoder {
+ return o.decoder
+}
+
+func (o *Provider) Encoder() httphelper.Encoder {
+ return o.encoder
+}
+
+func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier {
+ return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.idTokenHinKeySet, o.idTokenHintVerifierOpts...)
+}
+
+func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier {
+ return NewJWTProfileVerifier(o.Storage(), IssuerFromContext(ctx), 1*time.Hour, time.Second)
+}
+
+func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier {
+ return NewAccessTokenVerifier(IssuerFromContext(ctx), o.accessTokenKeySet, o.accessTokenVerifierOpts...)
+}
+
+func (o *Provider) Crypto() Crypto {
+ return o.crypto
+}
+
+func (o *Provider) DefaultLogoutRedirectURI() string {
+ return o.config.DefaultLogoutRedirectURI
+}
+
+func (o *Provider) Probes() []ProbesFn {
+ return []ProbesFn{
+ ReadyStorage(o.Storage()),
+ }
+}
+
+func (o *Provider) CORSOptions() *cors.Options {
+ return o.corsOpts
+}
+
+func (o *Provider) Logger() *slog.Logger {
+ return o.logger
+}
+
+// Deprecated: Provider now implements http.Handler directly.
+func (o *Provider) HttpHandler() http.Handler {
+ return o
+}
+
+type OpenIDKeySet struct {
+ Storage
+}
+
+// VerifySignature implements the oidc.KeySet interface
+// providing an implementation for the keys stored in the OP Storage interface
+func (o *OpenIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
+ keySet, err := o.Storage.KeySet(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("error fetching keys: %w", err)
+ }
+ keyID, alg := oidc.GetKeyIDAndAlg(jws)
+ key, err := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, jsonWebKeySet(keySet).Keys...)
+ if err != nil {
+ return nil, fmt.Errorf("invalid signature: %w", err)
+ }
+ return jws.Verify(&key)
+}
+
+type Option func(o *Provider) error
+
+// WithAllowInsecure allows the use of http (instead of https) for issuers
+// this is not recommended for production use and violates the OIDC specification
+func WithAllowInsecure() Option {
+ return func(o *Provider) error {
+ o.insecure = true
+ return nil
+ }
+}
+
+func WithCustomAuthEndpoint(endpoint *Endpoint) Option {
+ return func(o *Provider) error {
+ if err := endpoint.Validate(); err != nil {
+ return err
+ }
+ o.endpoints.Authorization = endpoint
+ return nil
+ }
+}
+
+func WithCustomTokenEndpoint(endpoint *Endpoint) Option {
+ return func(o *Provider) error {
+ if err := endpoint.Validate(); err != nil {
+ return err
+ }
+ o.endpoints.Token = endpoint
+ return nil
+ }
+}
+
+func WithCustomIntrospectionEndpoint(endpoint *Endpoint) Option {
+ return func(o *Provider) error {
+ if err := endpoint.Validate(); err != nil {
+ return err
+ }
+ o.endpoints.Introspection = endpoint
+ return nil
+ }
+}
+
+func WithCustomUserinfoEndpoint(endpoint *Endpoint) Option {
+ return func(o *Provider) error {
+ if err := endpoint.Validate(); err != nil {
+ return err
+ }
+ o.endpoints.Userinfo = endpoint
+ return nil
+ }
+}
+
+func WithCustomRevocationEndpoint(endpoint *Endpoint) Option {
+ return func(o *Provider) error {
+ if err := endpoint.Validate(); err != nil {
+ return err
+ }
+ o.endpoints.Revocation = endpoint
+ return nil
+ }
+}
+
+func WithCustomEndSessionEndpoint(endpoint *Endpoint) Option {
+ return func(o *Provider) error {
+ if err := endpoint.Validate(); err != nil {
+ return err
+ }
+ o.endpoints.EndSession = endpoint
+ return nil
+ }
+}
+
+func WithCustomKeysEndpoint(endpoint *Endpoint) Option {
+ return func(o *Provider) error {
+ if err := endpoint.Validate(); err != nil {
+ return err
+ }
+ o.endpoints.JwksURI = endpoint
+ return nil
+ }
+}
+
+func WithCustomDeviceAuthorizationEndpoint(endpoint *Endpoint) Option {
+ return func(o *Provider) error {
+ if err := endpoint.Validate(); err != nil {
+ return err
+ }
+ o.endpoints.DeviceAuthorization = endpoint
+ return nil
+ }
+}
+
+// WithCustomEndpoints sets multiple endpoints at once.
+// Non of the endpoints may be nil, or an error will
+// be returned when the Option used by the Provider.
+func WithCustomEndpoints(auth, token, userInfo, revocation, endSession, keys *Endpoint) Option {
+ return func(o *Provider) error {
+ for _, e := range []*Endpoint{auth, token, userInfo, revocation, endSession, keys} {
+ if err := e.Validate(); err != nil {
+ return err
+ }
+ }
+ o.endpoints.Authorization = auth
+ o.endpoints.Token = token
+ o.endpoints.Userinfo = userInfo
+ o.endpoints.Revocation = revocation
+ o.endpoints.EndSession = endSession
+ o.endpoints.JwksURI = keys
+ return nil
+ }
+}
+
+func WithHttpInterceptors(interceptors ...HttpInterceptor) Option {
+ return func(o *Provider) error {
+ o.interceptors = append(o.interceptors, interceptors...)
+ return nil
+ }
+}
+
+// WithAccessTokenKeySet allows passing a KeySet with public keys for Access Token verification.
+// The default KeySet uses the [Storage] interface
+func WithAccessTokenKeySet(keySet oidc.KeySet) Option {
+ return func(o *Provider) error {
+ o.accessTokenKeySet = keySet
+ return nil
+ }
+}
+
+func WithAccessTokenVerifierOpts(opts ...AccessTokenVerifierOpt) Option {
+ return func(o *Provider) error {
+ o.accessTokenVerifierOpts = opts
+ return nil
+ }
+}
+
+// WithIDTokenHintKeySet allows passing a KeySet with public keys for ID Token Hint verification.
+// The default KeySet uses the [Storage] interface.
+func WithIDTokenHintKeySet(keySet oidc.KeySet) Option {
+ return func(o *Provider) error {
+ o.idTokenHinKeySet = keySet
+ return nil
+ }
+}
+
+func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
+ return func(o *Provider) error {
+ o.idTokenHintVerifierOpts = opts
+ return nil
+ }
+}
+
+func WithCORSOptions(opts *cors.Options) Option {
+ return func(o *Provider) error {
+ o.corsOpts = opts
+ return nil
+ }
+}
+
+// WithLogger lets a logger other than slog.Default().
+func WithLogger(logger *slog.Logger) Option {
+ return func(o *Provider) error {
+ o.logger = logger
+ return nil
+ }
+}
+
+func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
+ issuerInterceptor := NewIssuerInterceptor(i)
+ return func(handler http.Handler) http.Handler {
+ for i := len(interceptors) - 1; i >= 0; i-- {
+ handler = interceptors[i](handler)
+ }
+ return issuerInterceptor.Handler(handler)
+ }
+}
diff --git a/pkg/op/op_test.go b/pkg/op/op_test.go
new file mode 100644
index 0000000..e1ac0bd
--- /dev/null
+++ b/pkg/op/op_test.go
@@ -0,0 +1,454 @@
+package op_test
+
+import (
+ "context"
+ "crypto/sha256"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/example/server/storage"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "github.com/muhlemmer/gu"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/text/language"
+)
+
+var (
+ testProvider op.OpenIDProvider
+ testConfig = &op.Config{
+ CryptoKey: sha256.Sum256([]byte("test")),
+ DefaultLogoutRedirectURI: pathLoggedOut,
+ CodeMethodS256: true,
+ AuthMethodPost: true,
+ AuthMethodPrivateKeyJWT: true,
+ GrantTypeRefreshToken: true,
+ RequestObjectSupported: true,
+ SupportedClaims: op.DefaultSupportedClaims,
+ SupportedUILocales: []language.Tag{language.English},
+ DeviceAuthorization: op.DeviceAuthorizationConfig{
+ Lifetime: 5 * time.Minute,
+ PollInterval: 5 * time.Second,
+ UserFormPath: "/device",
+ UserCode: op.UserCodeBase20,
+ },
+ }
+)
+
+const (
+ testIssuer = "https://localhost:9998/"
+ pathLoggedOut = "/logged-out"
+)
+
+func init() {
+ storage.RegisterClients(
+ storage.NativeClient("native"),
+ storage.WebClient("web", "secret", "https://example.com"),
+ storage.DeviceClient("device", "secret"),
+ storage.WebClient("api", "secret"),
+ )
+
+ testProvider = newTestProvider(testConfig)
+}
+
+func newTestProvider(config *op.Config) op.OpenIDProvider {
+ storage := storage.NewStorage(storage.NewUserStore(testIssuer))
+ keySet := &op.OpenIDKeySet{storage}
+ provider, err := op.NewOpenIDProvider(testIssuer, config, storage,
+ op.WithAllowInsecure(),
+ op.WithAccessTokenKeySet(keySet),
+ op.WithIDTokenHintKeySet(keySet),
+ )
+ if err != nil {
+ panic(err)
+ }
+ return provider
+}
+
+type routesTestStorage interface {
+ op.Storage
+ AuthRequestDone(id string) error
+}
+
+func mapAsValues(m map[string]string) string {
+ values := make(url.Values, len(m))
+ for k, v := range m {
+ values.Set(k, v)
+ }
+ return values.Encode()
+}
+
+func TestRoutes(t *testing.T) {
+ storage := testProvider.Storage().(routesTestStorage)
+ ctx := op.ContextWithIssuer(context.Background(), testIssuer)
+
+ client, err := storage.GetClientByClientID(ctx, "web")
+ require.NoError(t, err)
+
+ oidcAuthReq := &oidc.AuthRequest{
+ ClientID: client.GetID(),
+ RedirectURI: "https://example.com",
+ MaxAge: gu.Ptr[uint](300),
+ Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
+ ResponseType: oidc.ResponseTypeCode,
+ }
+
+ authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
+ require.NoError(t, err)
+ storage.AuthRequestDone(authReq.GetID())
+
+ accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
+ require.NoError(t, err)
+ accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
+ require.NoError(t, err)
+ idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
+ require.NoError(t, err)
+ jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
+ require.NoError(t, err)
+
+ oidcAuthReq.IDTokenHint = idToken
+
+ serverURL, err := url.Parse(testIssuer)
+ require.NoError(t, err)
+
+ type basicAuth struct {
+ username, password string
+ }
+
+ tests := []struct {
+ name string
+ method string
+ path string
+ basicAuth *basicAuth
+ header map[string]string
+ values map[string]string
+ body map[string]string
+ wantCode int
+ headerContains map[string]string
+ json string // test for exact json output
+ contains []string // when the body output is not constant, we just check for snippets to be present in the response
+ }{
+ {
+ name: "health",
+ method: http.MethodGet,
+ path: "/healthz",
+ wantCode: http.StatusOK,
+ json: `{"status":"ok"}`,
+ },
+ {
+ name: "ready",
+ method: http.MethodGet,
+ path: "/ready",
+ wantCode: http.StatusOK,
+ json: `{"status":"ok"}`,
+ },
+ {
+ name: "discovery",
+ method: http.MethodGet,
+ path: oidc.DiscoveryEndpoint,
+ wantCode: http.StatusOK,
+ json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
+ },
+ {
+ name: "authorization",
+ method: http.MethodGet,
+ path: testProvider.AuthorizationEndpoint().Relative(),
+ values: map[string]string{
+ "client_id": client.GetID(),
+ "redirect_uri": "https://example.com",
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ "response_type": string(oidc.ResponseTypeCode),
+ },
+ wantCode: http.StatusFound,
+ headerContains: map[string]string{"Location": "/login/username?authRequestID="},
+ },
+ {
+ name: "authorization callback",
+ method: http.MethodGet,
+ path: testProvider.AuthorizationEndpoint().Relative() + "/callback",
+ values: map[string]string{"id": authReq.GetID()},
+ wantCode: http.StatusFound,
+ headerContains: map[string]string{"Location": "https://example.com?code="},
+ contains: []string{
+ `Found .",
+ },
+ },
+ {
+ // This call will fail. A successful test is already
+ // part of client/integration_test.go
+ name: "code exchange",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeCode),
+ "code": "123",
+ },
+ wantCode: http.StatusUnauthorized,
+ json: `{"error":"invalid_client"}`,
+ },
+ {
+ name: "JWT authorization",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeBearer),
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ "assertion": jwtToken,
+ },
+ wantCode: http.StatusBadRequest,
+ json: "{\"error\":\"server_error\",\"error_description\":\"audience is not valid: Audience must contain client_id \\\"https://localhost:9998/\\\"\"}",
+ },
+ {
+ name: "Token exchange",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ basicAuth: &basicAuth{"web", "secret"},
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeTokenExchange),
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ "subject_token": jwtToken,
+ "subject_token_type": string(oidc.AccessTokenType),
+ },
+ wantCode: http.StatusOK,
+ contains: []string{
+ `{"access_token":"`,
+ `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
+ },
+ },
+ {
+ name: "Client credentials exchange",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ basicAuth: &basicAuth{"sid1", "verysecret"},
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeClientCredentials),
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ },
+ wantCode: http.StatusOK,
+ contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299,"scope":"openid offline_access"}`},
+ },
+ {
+ // This call will fail. A successful test is already
+ // part of device_test.go
+ name: "device token",
+ method: http.MethodPost,
+ path: testProvider.TokenEndpoint().Relative(),
+ basicAuth: &basicAuth{"web", "secret"},
+ header: map[string]string{
+ "Content-Type": "application/x-www-form-urlencoded",
+ },
+ body: map[string]string{
+ "grant_type": string(oidc.GrantTypeDeviceCode),
+ "device_code": "123",
+ },
+ wantCode: http.StatusBadRequest,
+ json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
+ },
+ {
+ name: "missing grant type",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ wantCode: http.StatusBadRequest,
+ json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
+ },
+ {
+ name: "unsupported grant type",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ values: map[string]string{
+ "grant_type": "foo",
+ },
+ wantCode: http.StatusBadRequest,
+ json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
+ },
+ {
+ name: "introspection",
+ method: http.MethodGet,
+ path: testProvider.IntrospectionEndpoint().Relative(),
+ basicAuth: &basicAuth{"web", "secret"},
+ values: map[string]string{
+ "token": accessToken,
+ },
+ wantCode: http.StatusOK,
+ json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
+ },
+ {
+ name: "user info",
+ method: http.MethodGet,
+ path: testProvider.UserinfoEndpoint().Relative(),
+ header: map[string]string{
+ "authorization": "Bearer " + accessToken,
+ },
+ wantCode: http.StatusOK,
+ json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
+ },
+ {
+ name: "refresh token",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeRefreshToken),
+ "refresh_token": refreshToken,
+ "client_id": client.GetID(),
+ "client_secret": "secret",
+ },
+ wantCode: http.StatusOK,
+ contains: []string{
+ `{"access_token":"`,
+ `","token_type":"Bearer","refresh_token":"`,
+ `","expires_in":299,"id_token":"`,
+ },
+ },
+ {
+ name: "revoke",
+ method: http.MethodGet,
+ path: testProvider.RevocationEndpoint().Relative(),
+ basicAuth: &basicAuth{"web", "secret"},
+ values: map[string]string{
+ "token": accessTokenRevoke,
+ },
+ wantCode: http.StatusOK,
+ },
+ {
+ name: "end session",
+ method: http.MethodGet,
+ path: testProvider.EndSessionEndpoint().Relative(),
+ values: map[string]string{
+ "id_token_hint": idToken,
+ "client_id": "web",
+ },
+ wantCode: http.StatusFound,
+ headerContains: map[string]string{"Location": "/logged-out"},
+ contains: []string{`Found .`},
+ },
+ {
+ name: "keys",
+ method: http.MethodGet,
+ path: testProvider.KeysEndpoint().Relative(),
+ wantCode: http.StatusOK,
+ contains: []string{
+ `{"keys":[{"use":"sig","kty":"RSA","kid":"`,
+ `","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
+ },
+ },
+ {
+ name: "device authorization",
+ method: http.MethodGet,
+ path: testProvider.DeviceAuthorizationEndpoint().Relative(),
+ basicAuth: &basicAuth{"device", "secret"},
+ values: map[string]string{
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ },
+ wantCode: http.StatusOK,
+ contains: []string{
+ `{"device_code":"`, `","user_code":"`,
+ `","verification_uri":"https://localhost:9998/device"`,
+ `"verification_uri_complete":"https://localhost:9998/device?user_code=`,
+ `","expires_in":300,"interval":5}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ u := gu.PtrCopy(serverURL)
+ u.Path = tt.path
+ if tt.values != nil {
+ u.RawQuery = mapAsValues(tt.values)
+ }
+ var body io.Reader
+ if tt.body != nil {
+ body = strings.NewReader(mapAsValues(tt.body))
+ }
+
+ req := httptest.NewRequest(tt.method, u.String(), body)
+ for k, v := range tt.header {
+ req.Header.Set(k, v)
+ }
+ if tt.basicAuth != nil {
+ req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
+ }
+
+ rec := httptest.NewRecorder()
+ testProvider.ServeHTTP(rec, req)
+
+ resp := rec.Result()
+ require.NoError(t, err)
+ assert.Equal(t, tt.wantCode, resp.StatusCode)
+
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ respBodyString := string(respBody)
+ t.Log(respBodyString)
+ t.Log(resp.Header)
+
+ if tt.json != "" {
+ assert.JSONEq(t, tt.json, respBodyString)
+ }
+ for _, c := range tt.contains {
+ assert.Contains(t, respBodyString, c)
+ }
+ for k, v := range tt.headerContains {
+ assert.Contains(t, resp.Header.Get(k), v)
+ }
+ })
+ }
+}
+
+func TestWithCustomEndpoints(t *testing.T) {
+ type args struct {
+ auth *op.Endpoint
+ token *op.Endpoint
+ userInfo *op.Endpoint
+ revocation *op.Endpoint
+ endSession *op.Endpoint
+ keys *op.Endpoint
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr error
+ }{
+ {
+ name: "all nil",
+ args: args{},
+ wantErr: op.ErrNilEndpoint,
+ },
+ {
+ name: "all set",
+ args: args{
+ auth: op.NewEndpoint("/authorize"),
+ token: op.NewEndpoint("/oauth/token"),
+ userInfo: op.NewEndpoint("/userinfo"),
+ revocation: op.NewEndpoint("/revoke"),
+ endSession: op.NewEndpoint("/end_session"),
+ keys: op.NewEndpoint("/keys"),
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ provider, err := op.NewOpenIDProvider(testIssuer, testConfig,
+ storage.NewStorage(storage.NewUserStore(testIssuer)),
+ op.WithCustomEndpoints(tt.args.auth, tt.args.token, tt.args.userInfo, tt.args.revocation, tt.args.endSession, tt.args.keys),
+ )
+ require.ErrorIs(t, err, tt.wantErr)
+ if tt.wantErr != nil {
+ return
+ }
+ assert.Equal(t, tt.args.auth, provider.AuthorizationEndpoint())
+ assert.Equal(t, tt.args.token, provider.TokenEndpoint())
+ assert.Equal(t, tt.args.userInfo, provider.UserinfoEndpoint())
+ assert.Equal(t, tt.args.revocation, provider.RevocationEndpoint())
+ assert.Equal(t, tt.args.endSession, provider.EndSessionEndpoint())
+ assert.Equal(t, tt.args.keys, provider.KeysEndpoint())
+ })
+ }
+}
diff --git a/pkg/op/probes.go b/pkg/op/probes.go
index 50e8a0f..fa713da 100644
--- a/pkg/op/probes.go
+++ b/pkg/op/probes.go
@@ -5,15 +5,21 @@ import (
"errors"
"net/http"
- "github.com/caos/oidc/pkg/utils"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
)
type ProbesFn func(context.Context) error
-func Healthz(w http.ResponseWriter, r *http.Request) {
+func healthHandler(w http.ResponseWriter, r *http.Request) {
ok(w)
}
+func readyHandler(probes []ProbesFn) func(w http.ResponseWriter, r *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ Readiness(w, r, probes...)
+ }
+}
+
func Readiness(w http.ResponseWriter, r *http.Request, probes ...ProbesFn) {
ctx := r.Context()
for _, probe := range probes {
@@ -25,14 +31,6 @@ func Readiness(w http.ResponseWriter, r *http.Request, probes ...ProbesFn) {
ok(w)
}
-func ReadySigner(s Signer) ProbesFn {
- return func(ctx context.Context) error {
- if s == nil {
- return errors.New("no signer")
- }
- return s.Health(ctx)
- }
-}
func ReadyStorage(s Storage) ProbesFn {
return func(ctx context.Context) error {
if s == nil {
@@ -43,9 +41,9 @@ func ReadyStorage(s Storage) ProbesFn {
}
func ok(w http.ResponseWriter) {
- utils.MarshalJSON(w, status{"ok"})
+ httphelper.MarshalJSON(w, Status{"ok"})
}
-type status struct {
+type Status struct {
Status string `json:"status,omitempty"`
}
diff --git a/pkg/op/server.go b/pkg/op/server.go
new file mode 100644
index 0000000..d45b734
--- /dev/null
+++ b/pkg/op/server.go
@@ -0,0 +1,350 @@
+package op
+
+import (
+ "context"
+ "net/http"
+ "net/url"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/muhlemmer/gu"
+)
+
+// Server describes the interface that needs to be implemented to serve
+// OpenID Connect and Oauth2 standard requests.
+//
+// Methods are called after the HTTP route is resolved and
+// the request body is parsed into the Request's Data field.
+// When a method is called, it can be assumed that required fields,
+// as described in their relevant standard, are validated already.
+// The Response Data field may be of any type to allow flexibility
+// to extend responses with custom fields. There are however requirements
+// in the standards regarding the response models. Where applicable
+// the method documentation gives a recommended type which can be used
+// directly or extended upon.
+//
+// The addition of new methods is not considered a breaking change
+// as defined by semver rules.
+// Implementations MUST embed [UnimplementedServer] to maintain
+// forward compatibility.
+//
+// EXPERIMENTAL: may change until v4
+type Server interface {
+ // Health returns a status of "ok" once the Server is listening.
+ // The recommended Response Data type is [Status].
+ Health(context.Context, *Request[struct{}]) (*Response, error)
+
+ // Ready returns a status of "ok" once all dependencies,
+ // such as database storage, are ready.
+ // An error can be returned to explain what is not ready.
+ // The recommended Response Data type is [Status].
+ Ready(context.Context, *Request[struct{}]) (*Response, error)
+
+ // Discovery returns the OpenID Provider Configuration Information for this server.
+ // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
+ // The recommended Response Data type is [oidc.DiscoveryConfiguration].
+ Discovery(context.Context, *Request[struct{}]) (*Response, error)
+
+ // Keys serves the JWK set which the client can use verify signatures from the op.
+ // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata `jwks_uri` key.
+ // The recommended Response Data type is [jose.JSONWebKeySet].
+ Keys(context.Context, *Request[struct{}]) (*Response, error)
+
+ // VerifyAuthRequest verifies the Auth Request and
+ // adds the Client to the request.
+ //
+ // When the `request` field is populated with a
+ // "Request Object" JWT, it needs to be Validated
+ // and its claims overwrite any fields in the AuthRequest.
+ // If the implementation does not support "Request Object",
+ // it MUST return an [oidc.ErrRequestNotSupported].
+ // https://openid.net/specs/openid-connect-core-1_0.html#RequestObject
+ VerifyAuthRequest(context.Context, *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error)
+
+ // Authorize initiates the authorization flow and redirects to a login page.
+ // See the various https://openid.net/specs/openid-connect-core-1_0.html
+ // authorize endpoint sections (one for each type of flow).
+ Authorize(context.Context, *ClientRequest[oidc.AuthRequest]) (*Redirect, error)
+
+ // DeviceAuthorization initiates the device authorization flow.
+ // https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
+ // The recommended Response Data type is [oidc.DeviceAuthorizationResponse].
+ DeviceAuthorization(context.Context, *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error)
+
+ // VerifyClient is called on most oauth/token handlers to authenticate,
+ // using either a secret (POST, Basic) or assertion (JWT).
+ // If no secrets are provided, the client must be public.
+ // This method is called before each method that takes a
+ // [ClientRequest] argument.
+ VerifyClient(context.Context, *Request[ClientCredentials]) (Client, error)
+
+ // CodeExchange returns Tokens after an authorization code
+ // is obtained in a successful Authorize flow.
+ // It is called by the Token endpoint handler when
+ // grant_type has the value authorization_code
+ // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
+ // The recommended Response Data type is [oidc.AccessTokenResponse].
+ CodeExchange(context.Context, *ClientRequest[oidc.AccessTokenRequest]) (*Response, error)
+
+ // RefreshToken returns new Tokens after verifying a Refresh token.
+ // It is called by the Token endpoint handler when
+ // grant_type has the value refresh_token
+ // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
+ // The recommended Response Data type is [oidc.AccessTokenResponse].
+ RefreshToken(context.Context, *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error)
+
+ // JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant
+ // It is called by the Token endpoint handler when
+ // grant_type has the value urn:ietf:params:oauth:grant-type:jwt-bearer
+ // https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
+ // The recommended Response Data type is [oidc.AccessTokenResponse].
+ JWTProfile(context.Context, *Request[oidc.JWTProfileGrantRequest]) (*Response, error)
+
+ // TokenExchange handles the OAuth 2.0 token exchange grant
+ // It is called by the Token endpoint handler when
+ // grant_type has the value urn:ietf:params:oauth:grant-type:token-exchange
+ // https://datatracker.ietf.org/doc/html/rfc8693
+ // The recommended Response Data type is [oidc.AccessTokenResponse].
+ TokenExchange(context.Context, *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error)
+
+ // ClientCredentialsExchange handles the OAuth 2.0 client credentials grant
+ // It is called by the Token endpoint handler when
+ // grant_type has the value client_credentials
+ // https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
+ // The recommended Response Data type is [oidc.AccessTokenResponse].
+ ClientCredentialsExchange(context.Context, *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error)
+
+ // DeviceToken handles the OAuth 2.0 Device Authorization Grant
+ // It is called by the Token endpoint handler when
+ // grant_type has the value urn:ietf:params:oauth:grant-type:device_code.
+ // It is typically called in a polling fashion and appropriate errors
+ // should be returned to signal authorization_pending or access_denied etc.
+ // https://datatracker.ietf.org/doc/html/rfc8628#section-3.4,
+ // https://datatracker.ietf.org/doc/html/rfc8628#section-3.5.
+ // The recommended Response Data type is [oidc.AccessTokenResponse].
+ DeviceToken(context.Context, *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error)
+
+ // Introspect handles the OAuth 2.0 Token Introspection endpoint.
+ // https://datatracker.ietf.org/doc/html/rfc7662
+ // The recommended Response Data type is [oidc.IntrospectionResponse].
+ Introspect(context.Context, *Request[IntrospectionRequest]) (*Response, error)
+
+ // UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User.
+ // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
+ // The recommended Response Data type is [oidc.UserInfo].
+ UserInfo(context.Context, *Request[oidc.UserInfoRequest]) (*Response, error)
+
+ // Revocation handles token revocation using an access or refresh token.
+ // https://datatracker.ietf.org/doc/html/rfc7009
+ // There are no response requirements. Data may remain empty.
+ Revocation(context.Context, *ClientRequest[oidc.RevocationRequest]) (*Response, error)
+
+ // EndSession handles the OpenID Connect RP-Initiated Logout.
+ // https://openid.net/specs/openid-connect-rpinitiated-1_0.html
+ // There are no response requirements. Data may remain empty.
+ EndSession(context.Context, *Request[oidc.EndSessionRequest]) (*Redirect, error)
+
+ // mustImpl forces implementations to embed the UnimplementedServer for forward
+ // compatibility with the interface.
+ mustImpl()
+}
+
+// Request contains the [http.Request] informational fields
+// and parsed Data from the request body (POST) or URL parameters (GET).
+// Data can be assumed to be validated according to the applicable
+// standard for the specific endpoints.
+//
+// EXPERIMENTAL: may change until v4
+type Request[T any] struct {
+ Method string
+ URL *url.URL
+ Header http.Header
+ Form url.Values
+ PostForm url.Values
+ Data *T
+}
+
+func (r *Request[_]) path() string {
+ return r.URL.Path
+}
+
+func newRequest[T any](r *http.Request, data *T) *Request[T] {
+ return &Request[T]{
+ Method: r.Method,
+ URL: r.URL,
+ Header: r.Header,
+ Form: r.Form,
+ PostForm: r.PostForm,
+ Data: data,
+ }
+}
+
+// ClientRequest is a Request with a verified client attached to it.
+// Methods that receive this argument may assume the client was authenticated,
+// or verified to be a public client.
+//
+// EXPERIMENTAL: may change until v4
+type ClientRequest[T any] struct {
+ *Request[T]
+ Client Client
+}
+
+func newClientRequest[T any](r *http.Request, data *T, client Client) *ClientRequest[T] {
+ return &ClientRequest[T]{
+ Request: newRequest[T](r, data),
+ Client: client,
+ }
+}
+
+// Response object for most [Server] methods.
+//
+// EXPERIMENTAL: may change until v4
+type Response struct {
+ // Header map will be merged with the
+ // header on the [http.ResponseWriter].
+ Header http.Header
+
+ // Data will be JSON marshaled to
+ // the response body.
+ // We allow any type, so that implementations
+ // can extend the standard types as they wish.
+ // However, each method will recommend which
+ // (base) type to use as model, in order to
+ // be compliant with the standards.
+ Data any
+}
+
+// NewResponse creates a new response for data,
+// without custom headers.
+func NewResponse(data any) *Response {
+ return &Response{
+ Header: make(http.Header),
+ Data: data,
+ }
+}
+
+func (resp *Response) writeOut(w http.ResponseWriter) {
+ gu.MapMerge(resp.Header, w.Header())
+ httphelper.MarshalJSON(w, resp.Data)
+}
+
+// Redirect is a special response type which will
+// initiate a [http.StatusFound] redirect.
+// The Params field will be encoded and set to the
+// URL's RawQuery field before building the URL.
+//
+// EXPERIMENTAL: may change until v4
+type Redirect struct {
+ // Header map will be merged with the
+ // header on the [http.ResponseWriter].
+ Header http.Header
+
+ URL string
+}
+
+func NewRedirect(url string) *Redirect {
+ return &Redirect{
+ Header: make(http.Header),
+ URL: url,
+ }
+}
+
+func (red *Redirect) writeOut(w http.ResponseWriter, r *http.Request) {
+ gu.MapMerge(red.Header, w.Header())
+ http.Redirect(w, r, red.URL, http.StatusFound)
+}
+
+type UnimplementedServer struct{}
+
+// UnimplementedStatusCode is the status code returned for methods
+// that are not yet implemented.
+// Note that this means methods in the sense of the Go interface,
+// and not http methods covered by "501 Not Implemented".
+var UnimplementedStatusCode = http.StatusNotFound
+
+func unimplementedError(r interface{ path() string }) StatusError {
+ err := oidc.ErrServerError().WithDescription("%s not implemented on this server", r.path())
+ return NewStatusError(err, UnimplementedStatusCode)
+}
+
+func unimplementedGrantError(gt oidc.GrantType) StatusError {
+ err := oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", gt)
+ return NewStatusError(err, http.StatusBadRequest) // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
+}
+
+func (UnimplementedServer) mustImpl() {}
+
+func (UnimplementedServer) Health(ctx context.Context, r *Request[struct{}]) (*Response, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
+ if r.Data.RequestParam != "" {
+ return nil, oidc.ErrRequestNotSupported()
+ }
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (*Redirect, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
+ return nil, unimplementedGrantError(oidc.GrantTypeCode)
+}
+
+func (UnimplementedServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
+ return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
+}
+
+func (UnimplementedServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
+ return nil, unimplementedGrantError(oidc.GrantTypeBearer)
+}
+
+func (UnimplementedServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
+ return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
+}
+
+func (UnimplementedServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
+ return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
+}
+
+func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
+ return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
+}
+
+func (UnimplementedServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
+ return nil, unimplementedError(r)
+}
+
+func (UnimplementedServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
+ return nil, unimplementedError(r)
+}
diff --git a/pkg/op/server_http.go b/pkg/op/server_http.go
new file mode 100644
index 0000000..d71a354
--- /dev/null
+++ b/pkg/op/server_http.go
@@ -0,0 +1,524 @@
+package op
+
+import (
+ "context"
+ "log/slog"
+ "net/http"
+ "net/url"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/go-chi/chi/v5"
+ "github.com/rs/cors"
+ "github.com/zitadel/logging"
+ "github.com/zitadel/schema"
+)
+
+// RegisterServer registers an implementation of Server.
+// The resulting handler takes care of routing and request parsing,
+// with some basic validation of required fields.
+// The routes can be customized with [WithEndpoints].
+//
+// EXPERIMENTAL: may change until v4
+func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption) http.Handler {
+ decoder := schema.NewDecoder()
+ decoder.IgnoreUnknownKeys(true)
+
+ ws := &webServer{
+ router: chi.NewRouter(),
+ server: server,
+ endpoints: endpoints,
+ decoder: decoder,
+ corsOpts: &defaultCORSOptions,
+ logger: slog.Default(),
+ }
+
+ for _, option := range options {
+ option(ws)
+ }
+
+ ws.createRouter()
+ ws.handler = ws.router
+ if ws.corsOpts != nil {
+ ws.handler = cors.New(*ws.corsOpts).Handler(ws.router)
+ }
+ return ws
+}
+
+type ServerOption func(s *webServer)
+
+// WithHTTPMiddleware sets the passed middleware chain to the root of
+// the Server's router.
+func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
+ return func(s *webServer) {
+ s.router.Use(m...)
+ }
+}
+
+// WithSetRouter allows customization or the Server's router.
+func WithSetRouter(set func(chi.Router)) ServerOption {
+ return func(s *webServer) {
+ set(s.router)
+ }
+}
+
+// WithDecoder overrides the default decoder,
+// which is a [schema.Decoder] with IgnoreUnknownKeys set to true.
+func WithDecoder(decoder httphelper.Decoder) ServerOption {
+ return func(s *webServer) {
+ s.decoder = decoder
+ }
+}
+
+// WithServerCORSOptions sets the CORS policy for the Server's router.
+func WithServerCORSOptions(opts *cors.Options) ServerOption {
+ return func(s *webServer) {
+ s.corsOpts = opts
+ }
+}
+
+// WithFallbackLogger overrides the fallback logger, which
+// is used when no logger was found in the context.
+// Defaults to [slog.Default].
+func WithFallbackLogger(logger *slog.Logger) ServerOption {
+ return func(s *webServer) {
+ s.logger = logger
+ }
+}
+
+type webServer struct {
+ server Server
+ router *chi.Mux
+ handler http.Handler
+ endpoints Endpoints
+ decoder httphelper.Decoder
+ corsOpts *cors.Options
+ logger *slog.Logger
+}
+
+func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ s.handler.ServeHTTP(w, r)
+}
+
+func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
+ if logger, ok := logging.FromContext(ctx); ok {
+ return logger
+ }
+ return s.logger
+}
+
+func (s *webServer) createRouter() {
+ s.router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
+ s.router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
+ s.router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
+
+ s.endpointRoute(s.endpoints.Authorization, s.authorizeHandler)
+ s.endpointRoute(s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
+ s.endpointRoute(s.endpoints.Token, s.tokensHandler)
+ s.endpointRoute(s.endpoints.Introspection, s.introspectionHandler)
+ s.endpointRoute(s.endpoints.Userinfo, s.userInfoHandler)
+ s.endpointRoute(s.endpoints.Revocation, s.withClient(s.revocationHandler))
+ s.endpointRoute(s.endpoints.EndSession, s.endSessionHandler)
+ s.endpointRoute(s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
+}
+
+func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) {
+ if e != nil {
+ traceHandler := func(w http.ResponseWriter, r *http.Request) {
+ ctx, span := tracer.Start(r.Context(), e.Relative())
+ r = r.WithContext(ctx)
+ hf(w, r)
+ defer span.End()
+ }
+ s.router.HandleFunc(e.Relative(), traceHandler)
+ s.logger.Info("registered route", "endpoint", e.Relative())
+ }
+}
+
+type clientHandler func(w http.ResponseWriter, r *http.Request, client Client)
+
+func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ ctx, span := tracer.Start(r.Context(), r.URL.Path)
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ client, err := s.verifyRequestClient(r)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType != "" {
+ if !ValidateGrantType(client, grantType) {
+ WriteError(w, r, oidc.ErrUnauthorizedClient().WithDescription("grant_type %q not allowed", grantType), s.getLogger(r.Context()))
+ return
+ }
+ }
+ handler(w, r, client)
+ }
+}
+
+func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
+ cc, err := s.parseClientCredentials(r)
+ if err != nil {
+ return nil, err
+ }
+ return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
+ Method: r.Method,
+ URL: r.URL,
+ Header: r.Header,
+ Form: r.Form,
+ Data: cc,
+ })
+}
+
+func (s *webServer) parseClientCredentials(r *http.Request) (_ *ClientCredentials, err error) {
+ if err := r.ParseForm(); err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
+ }
+ cc := new(ClientCredentials)
+ if err = s.decoder.Decode(cc, r.Form); err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
+ }
+ // Basic auth takes precedence, so if set it overwrites the form data.
+ if clientID, clientSecret, ok := r.BasicAuth(); ok {
+ cc.ClientID, err = url.QueryUnescape(clientID)
+ if err != nil {
+ return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+ cc.ClientSecret, err = url.QueryUnescape(clientSecret)
+ if err != nil {
+ return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+ }
+ if cc.ClientID == "" && cc.ClientAssertion == "" {
+ return nil, oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided")
+ }
+ if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion {
+ return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType)
+ }
+ return cc, nil
+}
+
+func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
+ request, err := decodeRequest[oidc.AuthRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ redirect, err := s.authorize(r.Context(), newRequest(r, request))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ redirect.writeOut(w, r)
+}
+
+func (s *webServer) authorize(ctx context.Context, r *Request[oidc.AuthRequest]) (_ *Redirect, err error) {
+ cr, err := s.server.VerifyAuthRequest(ctx, r)
+ if err != nil {
+ return nil, err
+ }
+ authReq := cr.Data
+ if authReq.RedirectURI == "" {
+ return nil, ErrAuthReqMissingRedirectURI
+ }
+ authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
+ if err != nil {
+ return nil, err
+ }
+ authReq.Scopes, err = ValidateAuthReqScopes(cr.Client, authReq.Scopes)
+ if err != nil {
+ return nil, err
+ }
+ if err := ValidateAuthReqRedirectURI(cr.Client, authReq.RedirectURI, authReq.ResponseType); err != nil {
+ return nil, err
+ }
+ if err := ValidateAuthReqResponseType(cr.Client, authReq.ResponseType); err != nil {
+ return nil, err
+ }
+ return s.server.Authorize(ctx, cr)
+}
+
+func (s *webServer) deviceAuthorizationHandler(w http.ResponseWriter, r *http.Request, client Client) {
+ request, err := decodeRequest[oidc.DeviceAuthorizationRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.DeviceAuthorization(r.Context(), newClientRequest(r, request, client))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) tokensHandler(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
+ return
+ }
+
+ switch grantType := oidc.GrantType(r.Form.Get("grant_type")); grantType {
+ case oidc.GrantTypeCode:
+ s.withClient(s.codeExchangeHandler)(w, r)
+ case oidc.GrantTypeRefreshToken:
+ s.withClient(s.refreshTokenHandler)(w, r)
+ case oidc.GrantTypeClientCredentials:
+ s.withClient(s.clientCredentialsHandler)(w, r)
+ case oidc.GrantTypeBearer:
+ s.jwtProfileHandler(w, r)
+ case oidc.GrantTypeTokenExchange:
+ s.withClient(s.tokenExchangeHandler)(w, r)
+ case oidc.GrantTypeDeviceCode:
+ s.withClient(s.deviceTokenHandler)(w, r)
+ case "":
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), s.getLogger(r.Context()))
+ default:
+ WriteError(w, r, unimplementedGrantError(grantType), s.getLogger(r.Context()))
+ }
+}
+
+func (s *webServer) jwtProfileHandler(w http.ResponseWriter, r *http.Request) {
+ request, err := decodeRequest[oidc.JWTProfileGrantRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if request.Assertion == "" {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("assertion missing"), s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.JWTProfile(r.Context(), newRequest(r, request))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) codeExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
+ request, err := decodeRequest[oidc.AccessTokenRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if request.Code == "" {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), s.getLogger(r.Context()))
+ return
+ }
+ if request.RedirectURI == "" {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("redirect_uri missing"), s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.CodeExchange(r.Context(), newClientRequest(r, request, client))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) refreshTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
+ request, err := decodeRequest[oidc.RefreshTokenRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if request.RefreshToken == "" {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("refresh_token missing"), s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.RefreshToken(r.Context(), newClientRequest(r, request, client))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) tokenExchangeHandler(w http.ResponseWriter, r *http.Request, client Client) {
+ request, err := decodeRequest[oidc.TokenExchangeRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if request.SubjectToken == "" {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token missing"), s.getLogger(r.Context()))
+ return
+ }
+ if request.SubjectTokenType == "" {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing"), s.getLogger(r.Context()))
+ return
+ }
+ if !request.SubjectTokenType.IsSupported() {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported"), s.getLogger(r.Context()))
+ return
+ }
+ if request.RequestedTokenType != "" && !request.RequestedTokenType.IsSupported() {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported"), s.getLogger(r.Context()))
+ return
+ }
+ if request.ActorTokenType != "" && !request.ActorTokenType.IsSupported() {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported"), s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.TokenExchange(r.Context(), newClientRequest(r, request, client))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) clientCredentialsHandler(w http.ResponseWriter, r *http.Request, client Client) {
+ if client.AuthMethod() == oidc.AuthMethodNone {
+ WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
+ return
+ }
+
+ request, err := decodeRequest[oidc.ClientCredentialsRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.ClientCredentialsExchange(r.Context(), newClientRequest(r, request, client))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, client Client) {
+ request, err := decodeRequest[oidc.DeviceAccessTokenRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if request.DeviceCode == "" {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("device_code missing"), s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.DeviceToken(r.Context(), newClientRequest(r, request, client))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request) {
+ cc, err := s.parseClientCredentials(r)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if cc.ClientSecret == "" && cc.ClientAssertion == "" {
+ WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
+ return
+ }
+ request, err := decodeRequest[oidc.IntrospectionRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if request.Token == "" {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.Introspect(r.Context(), newRequest(r, &IntrospectionRequest{cc, request}))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) userInfoHandler(w http.ResponseWriter, r *http.Request) {
+ request, err := decodeRequest[oidc.UserInfoRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if token, err := getAccessToken(r); err == nil {
+ request.AccessToken = token
+ }
+ if request.AccessToken == "" {
+ err = NewStatusError(
+ oidc.ErrInvalidRequest().WithDescription("access token missing"),
+ http.StatusUnauthorized,
+ )
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.UserInfo(r.Context(), newRequest(r, request))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) revocationHandler(w http.ResponseWriter, r *http.Request, client Client) {
+ request, err := decodeRequest[oidc.RevocationRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ if request.Token == "" {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.Revocation(r.Context(), newClientRequest(r, request, client))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+}
+
+func (s *webServer) endSessionHandler(w http.ResponseWriter, r *http.Request) {
+ request, err := decodeRequest[oidc.EndSessionRequest](s.decoder, r, false)
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp, err := s.server.EndSession(r.Context(), newRequest(r, request))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w, r)
+}
+
+func simpleHandler(s *webServer, method func(context.Context, *Request[struct{}]) (*Response, error)) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err), s.getLogger(r.Context()))
+ return
+ }
+ resp, err := method(r.Context(), newRequest(r, &struct{}{}))
+ if err != nil {
+ WriteError(w, r, err, s.getLogger(r.Context()))
+ return
+ }
+ resp.writeOut(w)
+ }
+}
+
+func decodeRequest[R any](decoder httphelper.Decoder, r *http.Request, postOnly bool) (*R, error) {
+ dst := new(R)
+ if err := r.ParseForm(); err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
+ }
+ form := r.Form
+ if postOnly {
+ form = r.PostForm
+ }
+ if err := decoder.Decode(dst, form); err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
+ }
+ return dst, nil
+}
diff --git a/pkg/op/server_http_routes_test.go b/pkg/op/server_http_routes_test.go
new file mode 100644
index 0000000..02200ee
--- /dev/null
+++ b/pkg/op/server_http_routes_test.go
@@ -0,0 +1,345 @@
+package op_test
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/muhlemmer/gu"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/client"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+func jwtProfile() (string, error) {
+ keyData, err := client.ConfigFromKeyFile("../../example/server/service-key1.json")
+ if err != nil {
+ return "", err
+ }
+ signer, err := client.NewSignerFromPrivateKeyByte([]byte(keyData.Key), keyData.KeyID)
+ if err != nil {
+ return "", err
+ }
+ return client.SignedJWTProfileAssertion(keyData.UserID, []string{testIssuer}, time.Hour, signer)
+}
+
+func TestServerRoutes(t *testing.T) {
+ server := op.RegisterLegacyServer(op.NewLegacyServer(testProvider, *op.DefaultEndpoints), op.AuthorizeCallbackHandler(testProvider))
+
+ storage := testProvider.Storage().(routesTestStorage)
+ ctx := op.ContextWithIssuer(context.Background(), testIssuer)
+
+ client, err := storage.GetClientByClientID(ctx, "web")
+ require.NoError(t, err)
+
+ oidcAuthReq := &oidc.AuthRequest{
+ ClientID: client.GetID(),
+ RedirectURI: "https://example.com",
+ MaxAge: gu.Ptr[uint](300),
+ Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
+ ResponseType: oidc.ResponseTypeCode,
+ }
+
+ authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
+ require.NoError(t, err)
+ storage.AuthRequestDone(authReq.GetID())
+
+ accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
+ require.NoError(t, err)
+ accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
+ require.NoError(t, err)
+ idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
+ require.NoError(t, err)
+ jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
+ require.NoError(t, err)
+ jwtProfileToken, err := jwtProfile()
+ require.NoError(t, err)
+
+ oidcAuthReq.IDTokenHint = idToken
+
+ serverURL, err := url.Parse(testIssuer)
+ require.NoError(t, err)
+
+ type basicAuth struct {
+ username, password string
+ }
+
+ tests := []struct {
+ name string
+ method string
+ path string
+ basicAuth *basicAuth
+ header map[string]string
+ values map[string]string
+ body map[string]string
+ wantCode int
+ headerContains map[string]string
+ json string // test for exact json output
+ contains []string // when the body output is not constant, we just check for snippets to be present in the response
+ }{
+ {
+ name: "health",
+ method: http.MethodGet,
+ path: "/healthz",
+ wantCode: http.StatusOK,
+ json: `{"status":"ok"}`,
+ },
+ {
+ name: "ready",
+ method: http.MethodGet,
+ path: "/ready",
+ wantCode: http.StatusOK,
+ json: `{"status":"ok"}`,
+ },
+ {
+ name: "discovery",
+ method: http.MethodGet,
+ path: oidc.DiscoveryEndpoint,
+ wantCode: http.StatusOK,
+ json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
+ },
+ {
+ name: "authorization",
+ method: http.MethodGet,
+ path: testProvider.AuthorizationEndpoint().Relative(),
+ values: map[string]string{
+ "client_id": client.GetID(),
+ "redirect_uri": "https://example.com",
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ "response_type": string(oidc.ResponseTypeCode),
+ },
+ wantCode: http.StatusFound,
+ headerContains: map[string]string{"Location": "/login/username?authRequestID="},
+ },
+ {
+ // This call will fail. A successfull test is already
+ // part of client/integration_test.go
+ name: "code exchange",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeCode),
+ "client_id": client.GetID(),
+ "client_secret": "secret",
+ "redirect_uri": "https://example.com",
+ "code": "123",
+ },
+ wantCode: http.StatusBadRequest,
+ json: `{"error":"invalid_grant", "error_description":"invalid code"}`,
+ },
+ {
+ name: "JWT authorization",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeBearer),
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ "assertion": jwtProfileToken,
+ },
+ wantCode: http.StatusOK,
+ contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299,"scope":"openid"}`},
+ },
+ {
+ name: "Token exchange",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ basicAuth: &basicAuth{"web", "secret"},
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeTokenExchange),
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ "subject_token": jwtToken,
+ "subject_token_type": string(oidc.AccessTokenType),
+ },
+ wantCode: http.StatusOK,
+ contains: []string{
+ `{"access_token":"`,
+ `","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
+ },
+ },
+ {
+ name: "Client credentials exchange",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ basicAuth: &basicAuth{"sid1", "verysecret"},
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeClientCredentials),
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ },
+ wantCode: http.StatusOK,
+ contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299,"scope":"openid offline_access"}`},
+ },
+ {
+ // This call will fail. A successful test is already
+ // part of device_test.go
+ name: "device token",
+ method: http.MethodPost,
+ path: testProvider.TokenEndpoint().Relative(),
+ basicAuth: &basicAuth{"device", "secret"},
+ header: map[string]string{
+ "Content-Type": "application/x-www-form-urlencoded",
+ },
+ body: map[string]string{
+ "grant_type": string(oidc.GrantTypeDeviceCode),
+ "device_code": "123",
+ },
+ wantCode: http.StatusBadRequest,
+ json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
+ },
+ {
+ name: "missing grant type",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ wantCode: http.StatusBadRequest,
+ json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
+ },
+ {
+ name: "unsupported grant type",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ values: map[string]string{
+ "grant_type": "foo",
+ },
+ wantCode: http.StatusBadRequest,
+ json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
+ },
+ {
+ name: "introspection",
+ method: http.MethodGet,
+ path: testProvider.IntrospectionEndpoint().Relative(),
+ basicAuth: &basicAuth{"web", "secret"},
+ values: map[string]string{
+ "token": accessToken,
+ },
+ wantCode: http.StatusOK,
+ json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
+ },
+ {
+ name: "user info",
+ method: http.MethodGet,
+ path: testProvider.UserinfoEndpoint().Relative(),
+ header: map[string]string{
+ "authorization": "Bearer " + accessToken,
+ },
+ wantCode: http.StatusOK,
+ json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
+ },
+ {
+ name: "refresh token",
+ method: http.MethodGet,
+ path: testProvider.TokenEndpoint().Relative(),
+ values: map[string]string{
+ "grant_type": string(oidc.GrantTypeRefreshToken),
+ "refresh_token": refreshToken,
+ "client_id": client.GetID(),
+ "client_secret": "secret",
+ },
+ wantCode: http.StatusOK,
+ contains: []string{
+ `{"access_token":"`,
+ `","token_type":"Bearer","refresh_token":"`,
+ `","expires_in":299,"id_token":"`,
+ },
+ },
+ {
+ name: "revoke",
+ method: http.MethodGet,
+ path: testProvider.RevocationEndpoint().Relative(),
+ basicAuth: &basicAuth{"web", "secret"},
+ values: map[string]string{
+ "token": accessTokenRevoke,
+ },
+ wantCode: http.StatusOK,
+ },
+ {
+ name: "end session",
+ method: http.MethodGet,
+ path: testProvider.EndSessionEndpoint().Relative(),
+ values: map[string]string{
+ "id_token_hint": idToken,
+ "client_id": "web",
+ },
+ wantCode: http.StatusFound,
+ headerContains: map[string]string{"Location": "/logged-out"},
+ contains: []string{`Found .`},
+ },
+ {
+ name: "keys",
+ method: http.MethodGet,
+ path: testProvider.KeysEndpoint().Relative(),
+ wantCode: http.StatusOK,
+ contains: []string{
+ `{"keys":[{"use":"sig","kty":"RSA","kid":"`,
+ `","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
+ },
+ },
+ {
+ name: "device authorization",
+ method: http.MethodGet,
+ path: testProvider.DeviceAuthorizationEndpoint().Relative(),
+ basicAuth: &basicAuth{"device", "secret"},
+ values: map[string]string{
+ "scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
+ },
+ wantCode: http.StatusOK,
+ contains: []string{
+ `{"device_code":"`, `","user_code":"`,
+ `","verification_uri":"https://localhost:9998/device"`,
+ `"verification_uri_complete":"https://localhost:9998/device?user_code=`,
+ `","expires_in":300,"interval":5}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ u := gu.PtrCopy(serverURL)
+ u.Path = tt.path
+ if tt.values != nil {
+ u.RawQuery = mapAsValues(tt.values)
+ }
+ var body io.Reader
+ if tt.body != nil {
+ body = strings.NewReader(mapAsValues(tt.body))
+ }
+
+ req := httptest.NewRequest(tt.method, u.String(), body)
+ for k, v := range tt.header {
+ req.Header.Set(k, v)
+ }
+ if tt.basicAuth != nil {
+ req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
+ }
+
+ rec := httptest.NewRecorder()
+ server.ServeHTTP(rec, req)
+
+ resp := rec.Result()
+ require.NoError(t, err)
+ assert.Equal(t, tt.wantCode, resp.StatusCode)
+
+ respBody, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ respBodyString := string(respBody)
+ t.Log(respBodyString)
+ t.Log(resp.Header)
+
+ if tt.json != "" {
+ assert.JSONEq(t, tt.json, respBodyString)
+ }
+ for _, c := range tt.contains {
+ assert.Contains(t, respBodyString, c)
+ }
+ for k, v := range tt.headerContains {
+ assert.Contains(t, resp.Header.Get(k), v)
+ }
+ })
+ }
+}
diff --git a/pkg/op/server_http_test.go b/pkg/op/server_http_test.go
new file mode 100644
index 0000000..75d02ca
--- /dev/null
+++ b/pkg/op/server_http_test.go
@@ -0,0 +1,1328 @@
+package op
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/muhlemmer/gu"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/zitadel/schema"
+)
+
+func TestRegisterServer(t *testing.T) {
+ server := UnimplementedServer{}
+ endpoints := Endpoints{
+ Authorization: &Endpoint{
+ path: "/auth",
+ },
+ }
+ decoder := schema.NewDecoder()
+ logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
+
+ h := RegisterServer(server, endpoints,
+ WithDecoder(decoder),
+ WithFallbackLogger(logger),
+ )
+ got := h.(*webServer)
+ assert.Equal(t, got.server, server)
+ assert.Equal(t, got.endpoints, endpoints)
+ assert.Equal(t, got.decoder, decoder)
+ assert.Equal(t, got.logger, logger)
+}
+
+type testClient struct {
+ id string
+ appType ApplicationType
+ authMethod oidc.AuthMethod
+ accessTokenType AccessTokenType
+ responseTypes []oidc.ResponseType
+ grantTypes []oidc.GrantType
+ devMode bool
+}
+
+type clientType string
+
+const (
+ clientTypeWeb clientType = "web"
+ clientTypeNative clientType = "native"
+ clientTypeUserAgent clientType = "useragent"
+)
+
+func newClient(kind clientType) *testClient {
+ client := &testClient{
+ id: string(kind),
+ }
+
+ switch kind {
+ case clientTypeWeb:
+ client.appType = ApplicationTypeWeb
+ client.authMethod = oidc.AuthMethodBasic
+ client.accessTokenType = AccessTokenTypeBearer
+ client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode}
+ case clientTypeNative:
+ client.appType = ApplicationTypeNative
+ client.authMethod = oidc.AuthMethodNone
+ client.accessTokenType = AccessTokenTypeBearer
+ client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeCode}
+ case clientTypeUserAgent:
+ client.appType = ApplicationTypeUserAgent
+ client.authMethod = oidc.AuthMethodBasic
+ client.accessTokenType = AccessTokenTypeJWT
+ client.responseTypes = []oidc.ResponseType{oidc.ResponseTypeIDToken}
+ default:
+ panic(fmt.Errorf("invalid client type %s", kind))
+ }
+ return client
+}
+
+func (c *testClient) RedirectURIs() []string {
+ return []string{
+ "https://registered.com/callback",
+ "http://registered.com/callback",
+ "http://localhost:9999/callback",
+ "custom://callback",
+ }
+}
+
+func (c *testClient) PostLogoutRedirectURIs() []string {
+ return []string{}
+}
+
+func (c *testClient) LoginURL(id string) string {
+ return "login?id=" + id
+}
+
+func (c *testClient) ApplicationType() ApplicationType {
+ return c.appType
+}
+
+func (c *testClient) AuthMethod() oidc.AuthMethod {
+ return c.authMethod
+}
+
+func (c *testClient) GetID() string {
+ return c.id
+}
+
+func (c *testClient) AccessTokenLifetime() time.Duration {
+ return 5 * time.Minute
+}
+
+func (c *testClient) IDTokenLifetime() time.Duration {
+ return 5 * time.Minute
+}
+
+func (c *testClient) AccessTokenType() AccessTokenType {
+ return c.accessTokenType
+}
+
+func (c *testClient) ResponseTypes() []oidc.ResponseType {
+ return c.responseTypes
+}
+
+func (c *testClient) GrantTypes() []oidc.GrantType {
+ return c.grantTypes
+}
+
+func (c *testClient) DevMode() bool {
+ return c.devMode
+}
+
+func (c *testClient) AllowedScopes() []string {
+ return nil
+}
+
+func (c *testClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string {
+ return func(scopes []string) []string {
+ return scopes
+ }
+}
+
+func (c *testClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
+ return func(scopes []string) []string {
+ return scopes
+ }
+}
+
+func (c *testClient) IsScopeAllowed(scope string) bool {
+ return false
+}
+
+func (c *testClient) IDTokenUserinfoClaimsAssertion() bool {
+ return false
+}
+
+func (c *testClient) ClockSkew() time.Duration {
+ return 0
+}
+
+type requestVerifier struct {
+ UnimplementedServer
+ client Client
+}
+
+func (s *requestVerifier) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
+ if s.client == nil {
+ return nil, oidc.ErrServerError()
+ }
+ return &ClientRequest[oidc.AuthRequest]{
+ Request: r,
+ Client: s.client,
+ }, nil
+}
+
+func (s *requestVerifier) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
+ if s.client == nil {
+ return nil, oidc.ErrServerError()
+ }
+ return s.client, nil
+}
+
+var testDecoder = func() *schema.Decoder {
+ decoder := schema.NewDecoder()
+ decoder.IgnoreUnknownKeys(true)
+ return decoder
+}()
+
+type webServerResult struct {
+ wantStatus int
+ wantBody string
+}
+
+func runWebServerTest(t *testing.T, handler http.HandlerFunc, r *http.Request, want webServerResult) {
+ t.Helper()
+ if r.Method == http.MethodPost {
+ r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ }
+ w := httptest.NewRecorder()
+ handler(w, r)
+ res := w.Result()
+ assert.Equal(t, want.wantStatus, res.StatusCode)
+ body, err := io.ReadAll(res.Body)
+ require.NoError(t, err)
+ assert.JSONEq(t, want.wantBody, string(body))
+}
+
+func Test_webServer_withClient(t *testing.T) {
+ tests := []struct {
+ name string
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "parse error",
+ r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`,
+ },
+ },
+ {
+ name: "invalid grant type",
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&grant_type=bad&foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"unauthorized_client", "error_description":"grant_type \"bad\" not allowed"}`,
+ },
+ },
+ {
+ name: "no grant type",
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native&foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusOK,
+ wantBody: `{"foo":"bar"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: &requestVerifier{
+ client: newClient(clientTypeNative),
+ },
+ decoder: testDecoder,
+ logger: slog.Default(),
+ }
+ handler := func(w http.ResponseWriter, r *http.Request, client Client) {
+ fmt.Fprintf(w, `{"foo":%q}`, r.FormValue("foo"))
+ }
+ runWebServerTest(t, s.withClient(handler), tt.r, tt.want)
+ })
+ }
+}
+
+func Test_webServer_verifyRequestClient(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ r *http.Request
+ want Client
+ wantErr error
+ }{
+ {
+ name: "parse form error",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
+ wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"),
+ },
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"),
+ },
+ {
+ name: "basic auth, client_id error",
+ decoder: testDecoder,
+ r: func() *http.Request {
+ r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
+ r.SetBasicAuth(`%%%`, "secret")
+ return r
+ }(),
+ wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"),
+ },
+ {
+ name: "basic auth, client_secret error",
+ decoder: testDecoder,
+ r: func() *http.Request {
+ r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
+ r.SetBasicAuth("web", `%%%`)
+ return r
+ }(),
+ wantErr: oidc.ErrInvalidClient().WithDescription("invalid basic auth header"),
+ },
+ {
+ name: "missing client id and assertion",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ wantErr: oidc.ErrInvalidRequest().WithDescription("client_id or client_assertion must be provided"),
+ },
+ {
+ name: "wrong assertion type",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_assertion=xxx&client_assertion_type=wrong")),
+ wantErr: oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type wrong"),
+ },
+ {
+ name: "unimplemented verify client called",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar&client_id=web")),
+ wantErr: StatusError{
+ parent: oidc.ErrServerError().WithDescription("/ not implemented on this server"),
+ statusCode: UnimplementedStatusCode,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ tt.r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ got, err := s.verifyRequestClient(tt.r)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_webServer_authorizeHandler(t *testing.T) {
+ type fields struct {
+ server Server
+ decoder httphelper.Decoder
+ }
+ tests := []struct {
+ name string
+ fields fields
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ fields: fields{
+ server: &requestVerifier{},
+ decoder: schema.NewDecoder(),
+ },
+ r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "server error",
+ fields: fields{
+ server: &requestVerifier{},
+ decoder: testDecoder,
+ },
+ r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusInternalServerError,
+ wantBody: `{"error":"server_error"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: tt.fields.server,
+ decoder: tt.fields.decoder,
+ logger: slog.Default(),
+ }
+ runWebServerTest(t, s.authorizeHandler, tt.r, tt.want)
+ })
+ }
+}
+
+func Test_webServer_authorize(t *testing.T) {
+ type args struct {
+ ctx context.Context
+ r *Request[oidc.AuthRequest]
+ }
+ tests := []struct {
+ name string
+ server Server
+ args args
+ want *Redirect
+ wantErr error
+ }{
+ {
+ name: "verify error",
+ server: &requestVerifier{},
+ args: args{
+ ctx: context.Background(),
+ r: &Request[oidc.AuthRequest]{
+ Data: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"openid"},
+ ResponseType: oidc.ResponseTypeCode,
+ ClientID: "web",
+ RedirectURI: "https://registered.com/callback",
+ MaxAge: gu.Ptr[uint](300),
+ },
+ },
+ },
+ wantErr: oidc.ErrServerError(),
+ },
+ {
+ name: "missing redirect",
+ server: &requestVerifier{
+ client: newClient(clientTypeWeb),
+ },
+ args: args{
+ ctx: context.Background(),
+ r: &Request[oidc.AuthRequest]{
+ Data: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"openid"},
+ ResponseType: oidc.ResponseTypeCode,
+ ClientID: "web",
+ MaxAge: gu.Ptr[uint](300),
+ },
+ },
+ },
+ wantErr: ErrAuthReqMissingRedirectURI,
+ },
+ {
+ name: "invalid prompt",
+ server: &requestVerifier{
+ client: newClient(clientTypeWeb),
+ },
+ args: args{
+ ctx: context.Background(),
+ r: &Request[oidc.AuthRequest]{
+ Data: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"openid"},
+ ResponseType: oidc.ResponseTypeCode,
+ ClientID: "web",
+ RedirectURI: "https://registered.com/callback",
+ MaxAge: gu.Ptr[uint](300),
+ Prompt: []string{oidc.PromptNone, oidc.PromptLogin},
+ },
+ },
+ },
+ wantErr: oidc.ErrInvalidRequest().WithDescription("The prompt parameter `none` must only be used as a single value"),
+ },
+ {
+ name: "missing scopes",
+ server: &requestVerifier{
+ client: newClient(clientTypeWeb),
+ },
+ args: args{
+ ctx: context.Background(),
+ r: &Request[oidc.AuthRequest]{
+ Data: &oidc.AuthRequest{
+ ResponseType: oidc.ResponseTypeCode,
+ ClientID: "web",
+ RedirectURI: "https://registered.com/callback",
+ MaxAge: gu.Ptr[uint](300),
+ Prompt: []string{oidc.PromptNone},
+ },
+ },
+ },
+ wantErr: oidc.ErrInvalidRequest().
+ WithDescription("The scope of your request is missing. Please ensure some scopes are requested. " +
+ "If you have any questions, you may contact the administrator of the application."),
+ },
+ {
+ name: "invalid redirect",
+ server: &requestVerifier{
+ client: newClient(clientTypeWeb),
+ },
+ args: args{
+ ctx: context.Background(),
+ r: &Request[oidc.AuthRequest]{
+ Data: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"openid"},
+ ResponseType: oidc.ResponseTypeCode,
+ ClientID: "web",
+ RedirectURI: "https://example.com/callback",
+ MaxAge: gu.Ptr[uint](300),
+ Prompt: []string{oidc.PromptNone},
+ },
+ },
+ },
+ wantErr: oidc.ErrInvalidRequestRedirectURI().
+ WithDescription("The requested redirect_uri is missing in the client configuration. " +
+ "If you have any questions, you may contact the administrator of the application."),
+ },
+ {
+ name: "invalid response type",
+ server: &requestVerifier{
+ client: newClient(clientTypeWeb),
+ },
+ args: args{
+ ctx: context.Background(),
+ r: &Request[oidc.AuthRequest]{
+ Data: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"openid"},
+ ResponseType: oidc.ResponseTypeIDToken,
+ ClientID: "web",
+ RedirectURI: "https://registered.com/callback",
+ MaxAge: gu.Ptr[uint](300),
+ Prompt: []string{oidc.PromptNone},
+ },
+ },
+ },
+ wantErr: oidc.ErrUnauthorizedClient().WithDescription("The requested response type is missing in the client configuration. " +
+ "If you have any questions, you may contact the administrator of the application."),
+ },
+ {
+ name: "unimplemented Authorize called",
+ server: &requestVerifier{
+ client: newClient(clientTypeWeb),
+ },
+ args: args{
+ ctx: context.Background(),
+ r: &Request[oidc.AuthRequest]{
+ URL: &url.URL{
+ Path: "/authorize",
+ },
+ Data: &oidc.AuthRequest{
+ Scopes: oidc.SpaceDelimitedArray{"openid"},
+ ResponseType: oidc.ResponseTypeCode,
+ ClientID: "web",
+ RedirectURI: "https://registered.com/callback",
+ MaxAge: gu.Ptr[uint](300),
+ Prompt: []string{oidc.PromptNone},
+ },
+ },
+ },
+ wantErr: StatusError{
+ parent: oidc.ErrServerError().WithDescription("/authorize not implemented on this server"),
+ statusCode: UnimplementedStatusCode,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: tt.server,
+ decoder: testDecoder,
+ logger: slog.Default(),
+ }
+ got, err := s.authorize(tt.args.ctx, tt.args.r)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_webServer_deviceAuthorizationHandler(t *testing.T) {
+ type fields struct {
+ server Server
+ decoder httphelper.Decoder
+ }
+ tests := []struct {
+ name string
+ fields fields
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ fields: fields{
+ server: &requestVerifier{},
+ decoder: schema.NewDecoder(),
+ },
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "unimplemented DeviceAuthorization called",
+ fields: fields{
+ server: &requestVerifier{
+ client: newClient(clientTypeNative),
+ },
+ decoder: testDecoder,
+ },
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=native_client")),
+ want: webServerResult{
+ wantStatus: UnimplementedStatusCode,
+ wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: tt.fields.server,
+ decoder: tt.fields.decoder,
+ logger: slog.Default(),
+ }
+ client := newClient(clientTypeUserAgent)
+ runWebServerClientTest(t, s.deviceAuthorizationHandler, tt.r, client, tt.want)
+ })
+ }
+}
+
+func Test_webServer_tokensHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "parse form error",
+ r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`,
+ },
+ },
+ {
+ name: "missing grant type",
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"grant_type missing"}`,
+ },
+ },
+ {
+ name: "invalid grant type",
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("grant_type=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"unsupported_grant_type", "error_description":"bar not supported"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ logger: slog.Default(),
+ }
+ runWebServerTest(t, s.tokensHandler, tt.r, tt.want)
+ })
+ }
+}
+
+func Test_webServer_jwtProfileHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "assertion missing",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"assertion missing"}`,
+ },
+ },
+ {
+ name: "unimplemented JWTProfile called",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("assertion=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:jwt-bearer not supported"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ runWebServerTest(t, s.jwtProfileHandler, tt.r, tt.want)
+ })
+ }
+}
+
+func runWebServerClientTest(t *testing.T, handler func(http.ResponseWriter, *http.Request, Client), r *http.Request, client Client, want webServerResult) {
+ t.Helper()
+ runWebServerTest(t, func(client Client) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ handler(w, r, client)
+ }
+ }(client), r, want)
+}
+
+func Test_webServer_codeExchangeHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "code missing",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"code missing"}`,
+ },
+ },
+ {
+ name: "redirect missing",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"redirect_uri missing"}`,
+ },
+ },
+ {
+ name: "unimplemented CodeExchange called",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("code=123&redirect_uri=https://example.com/callback")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"unsupported_grant_type", "error_description":"authorization_code not supported"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ client := newClient(clientTypeUserAgent)
+ runWebServerClientTest(t, s.codeExchangeHandler, tt.r, client, tt.want)
+ })
+ }
+}
+
+func Test_webServer_refreshTokenHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "refresh token missing",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"refresh_token missing"}`,
+ },
+ },
+ {
+ name: "unimplemented RefreshToken called",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("refresh_token=xxx")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"unsupported_grant_type", "error_description":"refresh_token not supported"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ client := newClient(clientTypeUserAgent)
+ runWebServerClientTest(t, s.refreshTokenHandler, tt.r, client, tt.want)
+ })
+ }
+}
+
+func Test_webServer_tokenExchangeHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "subject token missing",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"subject_token missing"}`,
+ },
+ },
+ {
+ name: "subject token type missing",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"subject_token_type missing"}`,
+ },
+ },
+ {
+ name: "subject token type unsupported",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=foo")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"subject_token_type is not supported"}`,
+ },
+ },
+ {
+ name: "unsupported requested token type",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=foo")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"requested_token_type is not supported"}`,
+ },
+ },
+ {
+ name: "unsupported actor token type",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=foo")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"actor_token_type is not supported"}`,
+ },
+ },
+ {
+ name: "unimplemented TokenExchange called",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("subject_token=xxx&subject_token_type=urn:ietf:params:oauth:token-type:access_token&requested_token_type=urn:ietf:params:oauth:token-type:access_token&actor_token_type=urn:ietf:params:oauth:token-type:access_token")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:token-exchange not supported"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ client := newClient(clientTypeUserAgent)
+ runWebServerClientTest(t, s.tokenExchangeHandler, tt.r, client, tt.want)
+ })
+ }
+}
+
+func Test_webServer_clientCredentialsHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ client Client
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ client: newClient(clientTypeUserAgent),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "public client",
+ decoder: testDecoder,
+ client: newClient(clientTypeNative),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`,
+ },
+ },
+ {
+ name: "unimplemented ClientCredentialsExchange called",
+ decoder: testDecoder,
+ client: newClient(clientTypeUserAgent),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"unsupported_grant_type", "error_description":"client_credentials not supported"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ runWebServerClientTest(t, s.clientCredentialsHandler, tt.r, tt.client, tt.want)
+ })
+ }
+}
+
+func Test_webServer_deviceTokenHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "device code missing",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"device_code missing"}`,
+ },
+ },
+ {
+ name: "unimplemented DeviceToken called",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("device_code=xxx")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"unsupported_grant_type", "error_description":"urn:ietf:params:oauth:grant-type:device_code not supported"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ client := newClient(clientTypeUserAgent)
+ runWebServerClientTest(t, s.deviceTokenHandler, tt.r, client, tt.want)
+ })
+ }
+}
+
+func Test_webServer_introspectionHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "public client",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`,
+ },
+ },
+ {
+ name: "token missing",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"token missing"}`,
+ },
+ },
+ {
+ name: "unimplemented Introspect called",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET&token=xxx")),
+ want: webServerResult{
+ wantStatus: UnimplementedStatusCode,
+ wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ runWebServerTest(t, s.introspectionHandler, tt.r, tt.want)
+ })
+ }
+}
+
+func Test_webServer_userInfoHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "access token missing",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusUnauthorized,
+ wantBody: `{"error":"invalid_request", "error_description":"access token missing"}`,
+ },
+ },
+ {
+ name: "unimplemented UserInfo called",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("access_token=xxx")),
+ want: webServerResult{
+ wantStatus: UnimplementedStatusCode,
+ wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
+ },
+ },
+ {
+ name: "bearer",
+ decoder: testDecoder,
+ r: func() *http.Request {
+ r := httptest.NewRequest(http.MethodGet, "/", nil)
+ r.Header.Set("authorization", strings.Join([]string{"Bearer", "xxx"}, " "))
+ return r
+ }(),
+ want: webServerResult{
+ wantStatus: UnimplementedStatusCode,
+ wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ runWebServerTest(t, s.userInfoHandler, tt.r, tt.want)
+ })
+ }
+}
+
+func Test_webServer_revocationHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ client Client
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ client: newClient(clientTypeWeb),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "token missing",
+ decoder: testDecoder,
+ client: newClient(clientTypeWeb),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"token missing"}`,
+ },
+ },
+ {
+ name: "unimplemented Revocation called, confidential client",
+ decoder: testDecoder,
+ client: newClient(clientTypeWeb),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")),
+ want: webServerResult{
+ wantStatus: UnimplementedStatusCode,
+ wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
+ },
+ },
+ {
+ name: "unimplemented Revocation called, public client",
+ decoder: testDecoder,
+ client: newClient(clientTypeNative),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")),
+ want: webServerResult{
+ wantStatus: UnimplementedStatusCode,
+ wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ runWebServerClientTest(t, s.revocationHandler, tt.r, tt.client, tt.want)
+ })
+ }
+}
+
+func Test_webServer_endSessionHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "decoder error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error decoding form"}`,
+ },
+ },
+ {
+ name: "unimplemented EndSession called",
+ decoder: testDecoder,
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("id_token_hint=xxx")),
+ want: webServerResult{
+ wantStatus: UnimplementedStatusCode,
+ wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ runWebServerTest(t, s.endSessionHandler, tt.r, tt.want)
+ })
+ }
+}
+
+func Test_webServer_simpleHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ decoder httphelper.Decoder
+ method func(context.Context, *Request[struct{}]) (*Response, error)
+ r *http.Request
+ want webServerResult
+ }{
+ {
+ name: "parse error",
+ decoder: schema.NewDecoder(),
+ r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
+ want: webServerResult{
+ wantStatus: http.StatusBadRequest,
+ wantBody: `{"error":"invalid_request", "error_description":"error parsing form"}`,
+ },
+ },
+ {
+ name: "method error",
+ decoder: schema.NewDecoder(),
+ method: func(ctx context.Context, r *Request[struct{}]) (*Response, error) {
+ return nil, io.ErrClosedPipe
+ },
+ r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))),
+ want: webServerResult{
+ wantStatus: http.StatusInternalServerError,
+ wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := &webServer{
+ server: UnimplementedServer{},
+ decoder: tt.decoder,
+ logger: slog.Default(),
+ }
+ runWebServerTest(t, simpleHandler(s, tt.method), tt.r, tt.want)
+ })
+ }
+}
+
+func Test_decodeRequest(t *testing.T) {
+ type dst struct {
+ A string `schema:"a"`
+ B string `schema:"b"`
+ }
+ type args struct {
+ r *http.Request
+ postOnly bool
+ }
+ tests := []struct {
+ name string
+ args args
+ want *dst
+ wantErr error
+ }{
+ {
+ name: "parse error",
+ args: args{
+ r: httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(make([]byte, 11<<20))),
+ },
+ wantErr: oidc.ErrInvalidRequest().WithDescription("error parsing form"),
+ },
+ {
+ name: "decode error",
+ args: args{
+ r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
+ },
+ wantErr: oidc.ErrInvalidRequest().WithDescription("error decoding form"),
+ },
+ {
+ name: "success, get",
+ args: args{
+ r: httptest.NewRequest(http.MethodGet, "/?a=b&b=a", nil),
+ },
+ want: &dst{
+ A: "b",
+ B: "a",
+ },
+ },
+ {
+ name: "success, post only",
+ args: args{
+ r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")),
+ postOnly: true,
+ },
+ want: &dst{
+ A: "b",
+ },
+ },
+ {
+ name: "success, post mixed",
+ args: args{
+ r: httptest.NewRequest(http.MethodPost, "/?b=a", strings.NewReader("a=b&")),
+ postOnly: false,
+ },
+ want: &dst{
+ A: "b",
+ B: "a",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.args.r.Method == http.MethodPost {
+ tt.args.r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ }
+ got, err := decodeRequest[dst](schema.NewDecoder(), tt.args.r, tt.args.postOnly)
+ require.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
diff --git a/pkg/op/server_legacy.go b/pkg/op/server_legacy.go
new file mode 100644
index 0000000..06e4e93
--- /dev/null
+++ b/pkg/op/server_legacy.go
@@ -0,0 +1,457 @@
+package op
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "time"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/go-chi/chi/v5"
+)
+
+// ExtendedLegacyServer allows embedding [LegacyServer] in a struct,
+// so that its methods can be individually overridden.
+//
+// EXPERIMENTAL: may change until v4
+type ExtendedLegacyServer interface {
+ Server
+ Provider() OpenIDProvider
+ Endpoints() Endpoints
+ AuthCallbackURL() func(context.Context, string) string
+}
+
+// RegisterLegacyServer registers a [LegacyServer] or an extension thereof.
+// It takes care of registering the IssuerFromRequest middleware.
+// The authorizeCallbackHandler is registered on `/callback` under the authorization endpoint.
+// Neither are part of the bare [Server] interface.
+//
+// EXPERIMENTAL: may change until v4
+func RegisterLegacyServer(s ExtendedLegacyServer, authorizeCallbackHandler http.HandlerFunc, options ...ServerOption) http.Handler {
+ options = append(options,
+ WithHTTPMiddleware(intercept(s.Provider().IssuerFromRequest)),
+ WithSetRouter(func(r chi.Router) {
+ r.HandleFunc(s.Endpoints().Authorization.Relative()+authCallbackPathSuffix, authorizeCallbackHandler)
+ }),
+ )
+ return RegisterServer(s, s.Endpoints(), options...)
+}
+
+// LegacyServer is an implementation of [Server] that
+// simply wraps an [OpenIDProvider].
+// It can be used to transition from the former Provider/Storage
+// interfaces to the new Server interface.
+//
+// EXPERIMENTAL: may change until v4
+type LegacyServer struct {
+ UnimplementedServer
+ provider OpenIDProvider
+ endpoints Endpoints
+}
+
+// NewLegacyServer wraps provider in a `Server` implementation
+//
+// Only non-nil endpoints will be registered on the router.
+// Nil endpoints are disabled.
+//
+// The passed endpoints is also used for the discovery config,
+// and endpoints already set to the provider are ignored.
+// Any `With*Endpoint()` option used on the provider is
+// therefore ineffective.
+//
+// EXPERIMENTAL: may change until v4
+func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) *LegacyServer {
+ return &LegacyServer{
+ provider: provider,
+ endpoints: endpoints,
+ }
+}
+
+func (s *LegacyServer) Provider() OpenIDProvider {
+ return s.provider
+}
+
+func (s *LegacyServer) Endpoints() Endpoints {
+ return s.endpoints
+}
+
+// AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login
+func (s *LegacyServer) AuthCallbackURL() func(context.Context, string) string {
+ return func(ctx context.Context, requestID string) string {
+ ctx, span := tracer.Start(ctx, "LegacyServer.AuthCallbackURL")
+ defer span.End()
+
+ return s.endpoints.Authorization.Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID
+ }
+}
+
+func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
+ return NewResponse(Status{Status: "ok"}), nil
+}
+
+func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Response, error) {
+ for _, probe := range s.provider.Probes() {
+ // shouldn't we run probes in Go routines?
+ if err := probe(ctx); err != nil {
+ return nil, AsStatusError(err, http.StatusInternalServerError)
+ }
+ }
+ return NewResponse(Status{Status: "ok"}), nil
+}
+
+func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.Discovery")
+ defer span.End()
+
+ return NewResponse(
+ createDiscoveryConfigV2(ctx, s.provider, s.provider.Storage(), &s.endpoints),
+ ), nil
+}
+
+func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.Keys")
+ defer span.End()
+
+ keys, err := s.provider.Storage().KeySet(ctx)
+ if err != nil {
+ return nil, AsStatusError(err, http.StatusInternalServerError)
+ }
+ return NewResponse(jsonWebKeySet(keys)), nil
+}
+
+var (
+ ErrAuthReqMissingClientID = errors.New("auth request is missing client_id")
+ ErrAuthReqMissingRedirectURI = errors.New("auth request is missing redirect_uri")
+)
+
+func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.AuthRequest]) (*ClientRequest[oidc.AuthRequest], error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.VerifyAuthRequest")
+ defer span.End()
+
+ if r.Data.RequestParam != "" {
+ if !s.provider.RequestObjectSupported() {
+ return nil, oidc.ErrRequestNotSupported()
+ }
+ err := ParseRequestObject(ctx, r.Data, s.provider.Storage(), IssuerFromContext(ctx))
+ if err != nil {
+ return nil, err
+ }
+ }
+ if r.Data.ClientID == "" {
+ return nil, oidc.ErrInvalidRequest().WithParent(ErrAuthReqMissingClientID).WithDescription(ErrAuthReqMissingClientID.Error())
+ }
+ client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
+ if err != nil {
+ return nil, oidc.DefaultToServerError(err, "unable to retrieve client by id")
+ }
+
+ return &ClientRequest[oidc.AuthRequest]{
+ Request: r,
+ Client: client,
+ }, nil
+}
+
+func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.AuthRequest]) (_ *Redirect, err error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.Authorize")
+ defer span.End()
+
+ userID, err := ValidateAuthReqIDTokenHint(ctx, r.Data.IDTokenHint, s.provider.IDTokenHintVerifier(ctx))
+ if err != nil {
+ return nil, err
+ }
+ req, err := s.provider.Storage().CreateAuthRequest(ctx, r.Data, userID)
+ if err != nil {
+ return TryErrorRedirect(ctx, r.Data, oidc.DefaultToServerError(err, "unable to save auth request"), s.provider.Encoder(), s.provider.Logger())
+ }
+ return NewRedirect(r.Client.LoginURL(req.GetID())), nil
+}
+
+func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.DeviceAuthorization")
+ defer span.End()
+
+ response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
+ if err != nil {
+ return nil, AsStatusError(err, http.StatusInternalServerError)
+ }
+ return NewResponse(response), nil
+}
+
+func (s *LegacyServer) VerifyClient(ctx context.Context, r *Request[ClientCredentials]) (Client, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.VerifyClient")
+ defer span.End()
+
+ if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials {
+ storage, ok := s.provider.Storage().(ClientCredentialsStorage)
+ if !ok {
+ return nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
+ }
+ return storage.ClientCredentials(ctx, r.Data.ClientID, r.Data.ClientSecret)
+ }
+
+ if r.Data.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
+ jwtExchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
+ if !ok || !s.provider.AuthMethodPrivateKeyJWTSupported() {
+ return nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
+ }
+ return AuthorizePrivateJWTKey(ctx, r.Data.ClientAssertion, jwtExchanger)
+ }
+ client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
+ if err != nil {
+ return nil, oidc.ErrInvalidClient().WithParent(err)
+ }
+
+ switch client.AuthMethod() {
+ case oidc.AuthMethodNone:
+ return client, nil
+ case oidc.AuthMethodPrivateKeyJWT:
+ return nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
+ case oidc.AuthMethodPost:
+ if !s.provider.AuthMethodPostSupported() {
+ return nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
+ }
+ }
+
+ err = AuthorizeClientIDSecret(ctx, r.Data.ClientID, r.Data.ClientSecret, s.provider.Storage())
+ if err != nil {
+ return nil, err
+ }
+
+ return client, nil
+}
+
+func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.AccessTokenRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.CodeExchange")
+ defer span.End()
+
+ authReq, err := AuthRequestByCode(ctx, s.provider.Storage(), r.Data.Code)
+ if err != nil {
+ return nil, err
+ }
+ if r.Client.AuthMethod() == oidc.AuthMethodNone || r.Data.CodeVerifier != "" {
+ if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil {
+ return nil, err
+ }
+ }
+ if r.Data.RedirectURI != authReq.GetRedirectURI() {
+ return nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
+ }
+ resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "")
+ if err != nil {
+ return nil, err
+ }
+ return NewResponse(resp), nil
+}
+
+func (s *LegacyServer) RefreshToken(ctx context.Context, r *ClientRequest[oidc.RefreshTokenRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.RefreshToken")
+ defer span.End()
+
+ if !s.provider.GrantTypeRefreshTokenSupported() {
+ return nil, unimplementedGrantError(oidc.GrantTypeRefreshToken)
+ }
+ request, err := RefreshTokenRequestByRefreshToken(ctx, s.provider.Storage(), r.Data.RefreshToken)
+ if err != nil {
+ return nil, err
+ }
+ if r.Client.GetID() != request.GetClientID() {
+ return nil, oidc.ErrInvalidGrant()
+ }
+ if err = ValidateRefreshTokenScopes(r.Data.Scopes, request); err != nil {
+ return nil, err
+ }
+ resp, err := CreateTokenResponse(ctx, request, r.Client, s.provider, true, "", r.Data.RefreshToken)
+ if err != nil {
+ return nil, err
+ }
+ return NewResponse(resp), nil
+}
+
+func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfileGrantRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.JWTProfile")
+ defer span.End()
+
+ exchanger, ok := s.provider.(JWTAuthorizationGrantExchanger)
+ if !ok {
+ return nil, unimplementedGrantError(oidc.GrantTypeBearer)
+ }
+ tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
+ if err != nil {
+ return nil, oidc.ErrInvalidRequest().WithParent(err).WithDescription("assertion invalid")
+ }
+
+ tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := CreateJWTTokenResponse(ctx, tokenRequest, exchanger)
+ if err != nil {
+ return nil, err
+ }
+ return NewResponse(resp), nil
+}
+
+func (s *LegacyServer) TokenExchange(ctx context.Context, r *ClientRequest[oidc.TokenExchangeRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.TokenExchange")
+ defer span.End()
+
+ if !s.provider.GrantTypeTokenExchangeSupported() {
+ return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
+ }
+ tokenExchangeRequest, err := CreateTokenExchangeRequest(ctx, r.Data, r.Client, s.provider)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := CreateTokenExchangeResponse(ctx, tokenExchangeRequest, r.Client, s.provider)
+ if err != nil {
+ return nil, err
+ }
+ return NewResponse(resp), nil
+}
+
+func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientRequest[oidc.ClientCredentialsRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.ClientCredentialsExchange")
+ defer span.End()
+
+ storage, ok := s.provider.Storage().(ClientCredentialsStorage)
+ if !ok {
+ return nil, unimplementedGrantError(oidc.GrantTypeClientCredentials)
+ }
+ tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, r.Client.GetID(), r.Data.Scope)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := CreateClientCredentialsTokenResponse(ctx, tokenRequest, s.provider, r.Client)
+ if err != nil {
+ return nil, err
+ }
+ return NewResponse(resp), nil
+}
+
+func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.DeviceToken")
+ defer span.End()
+
+ if !s.provider.GrantTypeDeviceCodeSupported() {
+ return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
+ }
+ // use a limited context timeout shorter as the default
+ // poll interval of 5 seconds.
+ ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
+ defer cancel()
+
+ tokenRequest, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
+ if err != nil {
+ return nil, err
+ }
+ resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client)
+ if err != nil {
+ return nil, err
+ }
+ return NewResponse(resp), nil
+}
+
+func (s *LegacyServer) authenticateResourceClient(ctx context.Context, cc *ClientCredentials) (string, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.authenticateResourceClient")
+ defer span.End()
+
+ if cc.ClientAssertion != "" {
+ if jp, ok := s.provider.(ClientJWTProfile); ok {
+ return ClientJWTAuth(ctx, oidc.ClientAssertionParams{ClientAssertion: cc.ClientAssertion}, jp)
+ }
+ return "", oidc.ErrInvalidClient().WithDescription("client_assertion not supported")
+ }
+ if err := s.provider.Storage().AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
+ return "", oidc.ErrUnauthorizedClient().WithParent(err)
+ }
+ return cc.ClientID, nil
+}
+
+func (s *LegacyServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.Introspect")
+ defer span.End()
+
+ clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials)
+ if err != nil {
+ return nil, err
+ }
+ response := new(oidc.IntrospectionResponse)
+ tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token)
+ if !ok {
+ return NewResponse(response), nil
+ }
+ err = s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID)
+ if err != nil {
+ return NewResponse(response), nil
+ }
+ response.Active = true
+ return NewResponse(response), nil
+}
+
+func (s *LegacyServer) UserInfo(ctx context.Context, r *Request[oidc.UserInfoRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.UserInfo")
+ defer span.End()
+
+ tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.AccessToken)
+ if !ok {
+ return nil, NewStatusError(oidc.ErrAccessDenied().WithDescription("access token invalid"), http.StatusUnauthorized)
+ }
+ info := new(oidc.UserInfo)
+ err := s.provider.Storage().SetUserinfoFromToken(ctx, info, tokenID, subject, r.Header.Get("origin"))
+ if err != nil {
+ return nil, NewStatusError(err, http.StatusForbidden)
+ }
+ return NewResponse(info), nil
+}
+
+func (s *LegacyServer) Revocation(ctx context.Context, r *ClientRequest[oidc.RevocationRequest]) (*Response, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.Revocation")
+ defer span.End()
+
+ var subject string
+ doDecrypt := true
+ if r.Data.TokenTypeHint != "access_token" {
+ userID, tokenID, err := s.provider.Storage().GetRefreshTokenInfo(ctx, r.Client.GetID(), r.Data.Token)
+ if err != nil {
+ // An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
+ if !errors.Is(err, ErrInvalidRefreshToken) {
+ return nil, RevocationError(oidc.ErrServerError().WithParent(err))
+ }
+ } else {
+ r.Data.Token = tokenID
+ subject = userID
+ doDecrypt = false
+ }
+ }
+ if doDecrypt {
+ tokenID, userID, ok := getTokenIDAndSubjectForRevocation(ctx, s.provider, r.Data.Token)
+ if ok {
+ r.Data.Token = tokenID
+ subject = userID
+ }
+ }
+ if err := s.provider.Storage().RevokeToken(ctx, r.Data.Token, subject, r.Client.GetID()); err != nil {
+ return nil, RevocationError(err)
+ }
+ return NewResponse(nil), nil
+}
+
+func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessionRequest]) (*Redirect, error) {
+ ctx, span := tracer.Start(ctx, "LegacyServer.EndSession")
+ defer span.End()
+
+ session, err := ValidateEndSessionRequest(ctx, r.Data, s.provider)
+ if err != nil {
+ return nil, err
+ }
+ redirect := session.RedirectURI
+ if fromRequest, ok := s.provider.Storage().(CanTerminateSessionFromRequest); ok {
+ redirect, err = fromRequest.TerminateSessionFromRequest(ctx, session)
+ } else {
+ err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID)
+ }
+ if err != nil {
+ return nil, err
+ }
+ return NewRedirect(redirect), nil
+}
diff --git a/pkg/op/server_test.go b/pkg/op/server_test.go
new file mode 100644
index 0000000..0cad8fd
--- /dev/null
+++ b/pkg/op/server_test.go
@@ -0,0 +1,5 @@
+package op
+
+// implementation check
+var _ Server = &UnimplementedServer{}
+var _ Server = &LegacyServer{}
diff --git a/pkg/op/session.go b/pkg/op/session.go
index c274bf0..ac663c9 100644
--- a/pkg/op/session.go
+++ b/pkg/op/session.go
@@ -2,21 +2,35 @@ package op
import (
"context"
+ "errors"
+ "log/slog"
"net/http"
+ "net/url"
+ "path"
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/rp"
- "github.com/gorilla/schema"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
)
type SessionEnder interface {
- Decoder() *schema.Decoder
+ Decoder() httphelper.Decoder
Storage() Storage
- IDTokenVerifier() rp.Verifier
+ IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
DefaultLogoutRedirectURI() string
+ Logger() *slog.Logger
+}
+
+func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ EndSession(w, r, ender)
+ }
}
func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
+ ctx, span := tracer.Start(r.Context(), "EndSession")
+ defer span.End()
+ r = r.WithContext(ctx)
+
req, err := ParseEndSessionRequest(r, ender.Decoder())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -24,67 +38,95 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
}
session, err := ValidateEndSessionRequest(r.Context(), req, ender)
if err != nil {
- RequestError(w, r, err)
+ RequestError(w, r, err, ender.Logger())
return
}
- var clientID string
- if session.Client != nil {
- clientID = session.Client.GetID()
+ redirect := session.RedirectURI
+ if fromRequest, ok := ender.Storage().(CanTerminateSessionFromRequest); ok {
+ redirect, err = fromRequest.TerminateSessionFromRequest(r.Context(), session)
+ } else {
+ err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID)
}
- err = ender.Storage().TerminateSession(r.Context(), session.UserID, clientID)
if err != nil {
- RequestError(w, r, ErrServerError("error terminating session"))
+ RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger())
return
}
- http.Redirect(w, r, session.RedirectURI, http.StatusFound)
+ http.Redirect(w, r, redirect, http.StatusFound)
}
-func ParseEndSessionRequest(r *http.Request, decoder *schema.Decoder) (*oidc.EndSessionRequest, error) {
+func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.EndSessionRequest, error) {
err := r.ParseForm()
if err != nil {
- return nil, ErrInvalidRequest("error parsing form")
+ return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
}
req := new(oidc.EndSessionRequest)
err = decoder.Decode(req, r.Form)
if err != nil {
- return nil, ErrInvalidRequest("error decoding form")
+ return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
}
return req, nil
}
func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) {
- session := new(EndSessionRequest)
- if req.IdTokenHint == "" {
- return session, nil
+ ctx, span := tracer.Start(ctx, "ValidateEndSessionRequest")
+ defer span.End()
+
+ session := &EndSessionRequest{
+ RedirectURI: ender.DefaultLogoutRedirectURI(),
+ LogoutHint: req.LogoutHint,
+ UILocales: req.UILocales,
}
- claims, err := ender.IDTokenVerifier().Verify(ctx, "", req.IdTokenHint)
- if err != nil {
- return nil, ErrInvalidRequest("id_token_hint invalid")
+ if req.IdTokenHint != "" {
+ claims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx))
+ if err != nil && !errors.As(err, &IDTokenHintExpiredError{}) {
+ return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
+ }
+ session.UserID = claims.GetSubject()
+ session.IDTokenHintClaims = claims
+ if req.ClientID != "" && req.ClientID != claims.GetAuthorizedParty() {
+ return nil, oidc.ErrInvalidRequest().WithDescription("client_id does not match azp of id_token_hint")
+ }
+ req.ClientID = claims.GetAuthorizedParty()
}
- session.UserID = claims.Subject
- session.Client, err = ender.Storage().GetClientByClientID(ctx, claims.AuthorizedParty)
- if err != nil {
- return nil, ErrServerError("")
- }
- if req.PostLogoutRedirectURI == "" {
- session.RedirectURI = ender.DefaultLogoutRedirectURI()
- return session, nil
- }
- for _, uri := range session.Client.PostLogoutRedirectURIs() {
- if uri == req.PostLogoutRedirectURI {
- session.RedirectURI = uri + "?state=" + req.State
- return session, nil
+ if req.ClientID != "" {
+ client, err := ender.Storage().GetClientByClientID(ctx, req.ClientID)
+ if err != nil {
+ return nil, oidc.DefaultToServerError(err, "")
+ }
+ session.ClientID = client.GetID()
+ if req.PostLogoutRedirectURI != "" {
+ if err := ValidateEndSessionPostLogoutRedirectURI(req.PostLogoutRedirectURI, client); err != nil {
+ return nil, err
+ }
+ session.RedirectURI = req.PostLogoutRedirectURI
}
}
- return nil, ErrInvalidRequest("post_logout_redirect_uri invalid")
+ if req.State != "" {
+ redirect, err := url.Parse(session.RedirectURI)
+ if err != nil {
+ return nil, oidc.DefaultToServerError(err, "")
+ }
+ session.RedirectURI = mergeQueryParams(redirect, url.Values{"state": {req.State}})
+ }
+ return session, nil
}
-func NeedsExistingSession(authRequest *oidc.AuthRequest) bool {
- if authRequest == nil {
- return true
+func ValidateEndSessionPostLogoutRedirectURI(postLogoutRedirectURI string, client Client) error {
+ for _, uri := range client.PostLogoutRedirectURIs() {
+ if uri == postLogoutRedirectURI {
+ return nil
+ }
}
- if authRequest.Prompt == oidc.PromptNone {
- return true
+ if globClient, ok := client.(HasRedirectGlobs); ok {
+ for _, uriGlob := range globClient.PostLogoutRedirectURIGlobs() {
+ isMatch, err := path.Match(uriGlob, postLogoutRedirectURI)
+ if err != nil {
+ return oidc.ErrServerError().WithParent(err)
+ }
+ if isMatch {
+ return nil
+ }
+ }
}
- return false
+ return oidc.ErrInvalidRequest().WithDescription("post_logout_redirect_uri invalid")
}
diff --git a/pkg/op/signer.go b/pkg/op/signer.go
index b4f770e..5c3dd6a 100644
--- a/pkg/op/signer.go
+++ b/pkg/op/signer.go
@@ -1,84 +1,36 @@
package op
import (
- "encoding/json"
"errors"
- "golang.org/x/net/context"
- "gopkg.in/square/go-jose.v2"
-
- "github.com/caos/logging"
- "github.com/caos/oidc/pkg/oidc"
+ jose "github.com/go-jose/go-jose/v4"
)
-type Signer interface {
- Health(ctx context.Context) error
- SignIDToken(claims *oidc.IDTokenClaims) (string, error)
- SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
+var ErrSignerCreationFailed = errors.New("signer creation failed")
+
+type SigningKey interface {
SignatureAlgorithm() jose.SignatureAlgorithm
+ Key() any
+ ID() string
}
-type tokenSigner struct {
- signer jose.Signer
- storage AuthStorage
- alg jose.SignatureAlgorithm
-}
-
-func NewDefaultSigner(ctx context.Context, storage AuthStorage, keyCh <-chan jose.SigningKey) Signer {
- s := &tokenSigner{
- storage: storage,
- }
-
- go s.refreshSigningKey(ctx, keyCh)
-
- return s
-}
-
-func (s *tokenSigner) Health(_ context.Context) error {
- if s.signer == nil {
- return errors.New("no signer")
- }
- return nil
-}
-
-func (s *tokenSigner) refreshSigningKey(ctx context.Context, keyCh <-chan jose.SigningKey) {
- for {
- select {
- case <-ctx.Done():
- return
- case key := <-keyCh:
- s.alg = key.Algorithm
- var err error
- s.signer, err = jose.NewSigner(key, &jose.SignerOptions{})
- logging.Log("OP-pf32aw").OnError(err).Error("error creating signer")
- }
- }
-}
-
-func (s *tokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
- payload, err := json.Marshal(claims)
+func SignerFromKey(key SigningKey) (jose.Signer, error) {
+ signer, err := jose.NewSigner(jose.SigningKey{
+ Algorithm: key.SignatureAlgorithm(),
+ Key: &jose.JSONWebKey{
+ Key: key.Key(),
+ KeyID: key.ID(),
+ },
+ }, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
- return "", err
+ return nil, ErrSignerCreationFailed // TODO: log / wrap error?
}
- return s.Sign(payload)
+ return signer, nil
}
-func (s *tokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
- payload, err := json.Marshal(claims)
- if err != nil {
- return "", err
- }
- return s.Sign(payload)
-}
-
-func (s *tokenSigner) Sign(payload []byte) (string, error) {
- result, err := s.signer.Sign(payload)
- if err != nil {
- return "", err
- }
- return result.CompactSerialize()
-}
-
-func (s *tokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
- return s.alg
+type Key interface {
+ ID() string
+ Algorithm() jose.SignatureAlgorithm
+ Use() string
+ Key() any
}
diff --git a/pkg/op/signer_test.go b/pkg/op/signer_test.go
deleted file mode 100644
index 75e184b..0000000
--- a/pkg/op/signer_test.go
+++ /dev/null
@@ -1,95 +0,0 @@
-package op
-
-import (
- "testing"
-
- "github.com/stretchr/testify/require"
- "gopkg.in/square/go-jose.v2"
-)
-
-// func TestNewDefaultSigner(t *testing.T) {
-// type args struct {
-// storage Storage
-// }
-// tests := []struct {
-// name string
-// args args
-// want Signer
-// wantErr bool
-// }{
-// {
-// "err initialize storage fails",
-// args{mock.NewMockStorageSigningKeyError(t)},
-// nil,
-// true,
-// },
-// {
-// "err initialize storage fails",
-// args{mock.NewMockStorageSigningKeyInvalid(t)},
-// nil,
-// true,
-// },
-// {
-// "initialize ok",
-// args{mock.NewMockStorageSigningKey(t)},
-// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)},
-// false,
-// },
-// }
-// for _, tt := range tests {
-// t.Run(tt.name, func(t *testing.T) {
-// got, err := op.NewDefaultSigner(tt.args.storage)
-// if (err != nil) != tt.wantErr {
-// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
-// return
-// }
-// if !reflect.DeepEqual(got, tt.want) {
-// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
-// }
-// })
-// }
-// }
-
-func Test_idTokenSigner_Sign(t *testing.T) {
- signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{})
- require.NoError(t, err)
-
- type fields struct {
- signer jose.Signer
- storage Storage
- }
- type args struct {
- payload []byte
- }
- tests := []struct {
- name string
- fields fields
- args args
- want string
- wantErr bool
- }{
- {
- "ok",
- fields{signer, nil},
- args{[]byte("test")},
- "eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww",
- false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- s := &tokenSigner{
- signer: tt.fields.signer,
- storage: tt.fields.storage,
- }
- got, err := s.Sign(tt.args.payload)
- if (err != nil) != tt.wantErr {
- t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want)
- }
- })
- }
-}
diff --git a/pkg/op/storage.go b/pkg/op/storage.go
index e3ef5ff..2dbd124 100644
--- a/pkg/op/storage.go
+++ b/pkg/op/storage.go
@@ -2,11 +2,13 @@ package op
import (
"context"
+ "errors"
"time"
- "gopkg.in/square/go-jose.v2"
+ jose "github.com/go-jose/go-jose/v4"
+ "golang.org/x/text/language"
- "github.com/caos/oidc/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
)
type AuthStorage interface {
@@ -16,22 +18,144 @@ type AuthStorage interface {
SaveAuthCode(context.Context, string, string) error
DeleteAuthRequest(context.Context, string) error
- CreateToken(context.Context, AuthRequest) (string, time.Time, error)
+ // The TokenRequest parameter of CreateAccessToken can be any of:
+ //
+ // * TokenRequest as returned by ClientCredentialsStorage.ClientCredentialsTokenRequest,
+ //
+ // * AuthRequest as returned by AuthRequestByID or AuthRequestByCode (above)
+ //
+ // * *oidc.JWTTokenRequest from a JWT that is the assertion value of a JWT Profile
+ // Grant: https://datatracker.ietf.org/doc/html/rfc7523#section-2.1
+ //
+ // * TokenExchangeRequest as returned by ValidateTokenExchangeRequest
+ CreateAccessToken(context.Context, TokenRequest) (accessTokenID string, expiration time.Time, err error)
- TerminateSession(context.Context, string, string) error
+ // The TokenRequest parameter of CreateAccessAndRefreshTokens can be any of:
+ //
+ // * TokenRequest as returned by ClientCredentialsStorage.ClientCredentialsTokenRequest
+ //
+ // * RefreshTokenRequest as returned by AuthStorage.TokenRequestByRefreshToken
+ //
+ // * AuthRequest as by returned by the AuthRequestByID or AuthRequestByCode (above).
+ // Used for the authorization code flow which requested offline_access scope and
+ // registered the refresh_token grant type in advance
+ //
+ // * TokenExchangeRequest as returned by ValidateTokenExchangeRequest
+ CreateAccessAndRefreshTokens(ctx context.Context, request TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshTokenID string, expiration time.Time, err error)
+ TokenRequestByRefreshToken(ctx context.Context, refreshTokenID string) (RefreshTokenRequest, error)
- GetSigningKey(context.Context, chan<- jose.SigningKey, chan<- error, <-chan time.Time)
- GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
- SaveNewKeyPair(context.Context) error
+ TerminateSession(ctx context.Context, userID string, clientID string) error
+
+ // RevokeToken should revoke a token. In the situation that the original request was to
+ // revoke an access token, then tokenOrTokenID will be a tokenID and userID will be set
+ // but if the original request was for a refresh token, then userID will be empty and
+ // tokenOrTokenID will be the refresh token, not its ID. RevokeToken depends upon GetRefreshTokenInfo
+ // to get information from refresh tokens that are not either ":" strings
+ // nor JWTs.
+ RevokeToken(ctx context.Context, tokenOrTokenID string, userID string, clientID string) *oidc.Error
+
+ // GetRefreshTokenInfo must return ErrInvalidRefreshToken when presented
+ // with a token that is not a refresh token.
+ GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error)
+
+ SigningKey(context.Context) (SigningKey, error)
+ SignatureAlgorithms(context.Context) ([]jose.SignatureAlgorithm, error)
+ KeySet(context.Context) ([]Key, error)
}
+// CanTerminateSessionFromRequest is an optional additional interface that may be implemented by
+// implementors of Storage as an alternative to TerminateSession of the AuthStorage.
+// It passes the complete parsed EndSessionRequest to the implementation, which allows access to additional data.
+// It also allows to modify the uri, which will be used for redirection, (e.g. a UI where the user can consent to the logout)
+type CanTerminateSessionFromRequest interface {
+ TerminateSessionFromRequest(ctx context.Context, endSessionRequest *EndSessionRequest) (string, error)
+}
+
+type ClientCredentialsStorage interface {
+ ClientCredentials(ctx context.Context, clientID, clientSecret string) (Client, error)
+ ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (TokenRequest, error)
+}
+
+type TokenExchangeStorage interface {
+ // ValidateTokenExchangeRequest will be called to validate parsed (including tokens) Token Exchange Grant request.
+ //
+ // Important validations can include:
+ // - permissions
+ // - set requested token type to some default value if it is empty (rfc 8693 allows it) using SetRequestedTokenType method.
+ // Depending on RequestedTokenType - the following tokens will be issued:
+ // - RefreshTokenType - both access and refresh tokens
+ // - AccessTokenType - only access token
+ // - IDTokenType - only id token
+ // - validation of subject's token type on possibility to be exchanged to the requested token type (according to your requirements)
+ // - scopes (and update them using SetCurrentScopes method)
+ // - set new subject if it differs from exchange subject (impersonation flow)
+ //
+ // Request will include subject's and/or actor's token claims if correspinding tokens are access/id_token issued by op
+ // or third party tokens parsed by TokenExchangeTokensVerifierStorage interface methods.
+ ValidateTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) error
+
+ // CreateTokenExchangeRequest will be called after parsing and validating token exchange request.
+ // Stored request is not accessed later by op - so it is up to implementer to decide
+ // should this method actually store the request or not (common use case - store for it for audit purposes)
+ CreateTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) error
+
+ // GetPrivateClaimsFromTokenExchangeRequest will be called during access token creation.
+ // Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc.
+ GetPrivateClaimsFromTokenExchangeRequest(ctx context.Context, request TokenExchangeRequest) (claims map[string]any, err error)
+
+ // SetUserinfoFromTokenExchangeRequest will be called during id token creation.
+ // Claims evaluation can be based on all validated request data available, including: scopes, resource, audience, etc.
+ SetUserinfoFromTokenExchangeRequest(ctx context.Context, userinfo *oidc.UserInfo, request TokenExchangeRequest) error
+}
+
+// TokenExchangeTokensVerifierStorage is an optional interface used in token exchange process to verify tokens
+// issued by third-party applications. If interface is not implemented - only tokens issued by op will be exchanged.
+type TokenExchangeTokensVerifierStorage interface {
+ VerifyExchangeSubjectToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, subject string, tokenClaims map[string]any, err error)
+ VerifyExchangeActorToken(ctx context.Context, token string, tokenType oidc.TokenType) (tokenIDOrToken string, actor string, tokenClaims map[string]any, err error)
+}
+
+var ErrInvalidRefreshToken = errors.New("invalid_refresh_token")
+
type OPStorage interface {
- GetClientByClientID(context.Context, string) (Client, error)
- AuthorizeClientIDSecret(context.Context, string, string) error
- GetUserinfoFromScopes(context.Context, string, []string) (*oidc.Userinfo, error)
- GetUserinfoFromToken(context.Context, string) (*oidc.Userinfo, error)
+ // GetClientByClientID loads a Client. The returned Client is never cached and is only used to
+ // handle the current request.
+ GetClientByClientID(ctx context.Context, clientID string) (Client, error)
+ AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error
+ // SetUserinfoFromScopes is deprecated and should have an empty implementation for now.
+ // Implement SetUserinfoFromRequest instead.
+ SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error
+ SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error
+ SetIntrospectionFromToken(ctx context.Context, userinfo *oidc.IntrospectionResponse, tokenID, subject, clientID string) error
+ GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]any, error)
+ GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error)
+ ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error)
}
+// JWTProfileTokenStorage is an additional, optional storage to implement
+// implementing it, allows specifying the [AccessTokenType] of the access_token returned form the JWT Profile TokenRequest
+type JWTProfileTokenStorage interface {
+ JWTProfileTokenType(ctx context.Context, request TokenRequest) (AccessTokenType, error)
+}
+
+// CanSetUserinfoFromRequest is an optional additional interface that may be implemented by
+// implementors of Storage. It allows additional data to be set in id_tokens based on the
+// request.
+type CanSetUserinfoFromRequest interface {
+ SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request IDTokenRequest, scopes []string) error
+}
+
+// CanGetPrivateClaimsFromRequest is an optional additional interface that may be implemented by
+// implementors of Storage. It allows setting the jwt token claims based on the request.
+type CanGetPrivateClaimsFromRequest interface {
+ GetPrivateClaimsFromRequest(ctx context.Context, request TokenRequest, restrictedScopes []string) (map[string]any, error)
+}
+
+// Storage is a required parameter for NewOpenIDProvider(). In addition to the
+// embedded interfaces below, if the passed Storage implements ClientCredentialsStorage
+// then the grant type "client_credentials" will be supported. In that case, the access
+// token returned by CreateAccessToken should be a JWT.
+// See https://datatracker.ietf.org/doc/html/rfc6749#section-1.3.4 for context.
type Storage interface {
AuthStorage
OPStorage
@@ -42,25 +166,37 @@ type StorageNotFoundError interface {
IsNotFound()
}
-type AuthRequest interface {
- GetID() string
- GetACR() string
- GetAMR() []string
- GetAudience() []string
- GetAuthTime() time.Time
- GetClientID() string
- GetCodeChallenge() *oidc.CodeChallenge
- GetNonce() string
- GetRedirectURI() string
- GetResponseType() oidc.ResponseType
- GetScopes() []string
- GetState() string
- GetSubject() string
- Done() bool
+type EndSessionRequest struct {
+ UserID string
+ ClientID string
+ IDTokenHintClaims *oidc.IDTokenClaims
+ RedirectURI string
+ LogoutHint string
+ UILocales []language.Tag
}
-type EndSessionRequest struct {
- UserID string
- Client Client
- RedirectURI string
+var ErrDuplicateUserCode = errors.New("user code already exists")
+
+type DeviceAuthorizationStorage interface {
+ // StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
+ // User code will be used by the user to complete the login flow and must be unique.
+ // ErrDuplicateUserCode signals the caller should try again with a new code.
+ //
+ // Note that user codes are low entropy keys and when many exist in the
+ // database, the change for collisions increases. Therefore implementers
+ // of this interface must make sure that user codes of expired authentication flows are purged,
+ // after some time.
+ StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error
+
+ // GetDeviceAuthorizatonState returns the current state of the device authorization flow in the database.
+ // The method is polled untill the the authorization is eighter Completed, Expired or Denied.
+ GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (*DeviceAuthorizationState, error)
+}
+
+func assertDeviceStorage(s Storage) (DeviceAuthorizationStorage, error) {
+ storage, ok := s.(DeviceAuthorizationStorage)
+ if !ok {
+ return nil, oidc.ErrUnsupportedGrantType().WithDescription("device_code grant not supported")
+ }
+ return storage, nil
}
diff --git a/pkg/op/token.go b/pkg/op/token.go
index 9d37788..2e25d05 100644
--- a/pkg/op/token.go
+++ b/pkg/op/token.go
@@ -2,111 +2,266 @@ package op
import (
"context"
+ "slices"
"time"
- "github.com/caos/oidc/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/crypto"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
)
type TokenCreator interface {
- Issuer() string
- Signer() Signer
Storage() Storage
Crypto() Crypto
}
-func CreateTokenResponse(ctx context.Context, authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) {
- var accessToken string
+type TokenRequest interface {
+ GetSubject() string
+ GetAudience() []string
+ GetScopes() []string
+}
+
+type AccessTokenClient interface {
+ GetID() string
+ ClockSkew() time.Duration
+ RestrictAdditionalAccessTokenScopes() func(scopes []string) []string
+ GrantTypes() []oidc.GrantType
+}
+
+func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Client, creator TokenCreator, createAccessToken bool, code, refreshToken string) (*oidc.AccessTokenResponse, error) {
+ ctx, span := tracer.Start(ctx, "CreateTokenResponse")
+ defer span.End()
+
+ var accessToken, newRefreshToken string
var validity time.Duration
if createAccessToken {
var err error
- accessToken, validity, err = CreateAccessToken(ctx, authReq, client, creator)
+ accessToken, newRefreshToken, validity, err = CreateAccessToken(ctx, request, client.AccessTokenType(), creator, client, refreshToken)
if err != nil {
return nil, err
}
}
- idToken, err := CreateIDToken(ctx, creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Storage(), creator.Signer())
+ idToken, err := CreateIDToken(ctx, IssuerFromContext(ctx), request, client.IDTokenLifetime(), accessToken, code, creator.Storage(), client)
if err != nil {
return nil, err
}
- err = creator.Storage().DeleteAuthRequest(ctx, authReq.GetID())
- if err != nil {
- return nil, err
+ var state string
+ if authRequest, ok := request.(AuthRequest); ok {
+ err = creator.Storage().DeleteAuthRequest(ctx, authRequest.GetID())
+ if err != nil {
+ return nil, err
+ }
+ // only implicit flow requires state to be returned.
+ if code == "" {
+ state = authRequest.GetState()
+ }
}
exp := uint64(validity.Seconds())
return &oidc.AccessTokenResponse{
- AccessToken: accessToken,
- IDToken: idToken,
- TokenType: oidc.BearerToken,
- ExpiresIn: exp,
+ AccessToken: accessToken,
+ IDToken: idToken,
+ RefreshToken: newRefreshToken,
+ TokenType: oidc.BearerToken,
+ ExpiresIn: exp,
+ State: state,
+ Scope: request.GetScopes(),
}, nil
}
-func CreateAccessToken(ctx context.Context, authReq AuthRequest, client Client, creator TokenCreator) (token string, validity time.Duration, err error) {
- id, exp, err := creator.Storage().CreateToken(ctx, authReq)
- if err != nil {
- return "", 0, err
+func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storage, refreshToken string, client AccessTokenClient) (id, newRefreshToken string, exp time.Time, err error) {
+ ctx, span := tracer.Start(ctx, "createTokens")
+ defer span.End()
+
+ if needsRefreshToken(tokenRequest, client) {
+ return storage.CreateAccessAndRefreshTokens(ctx, tokenRequest, refreshToken)
}
- validity = exp.Sub(time.Now().UTC())
- if client.AccessTokenType() == AccessTokenTypeJWT {
- token, err = CreateJWT(creator.Issuer(), authReq, exp, id, creator.Signer())
- return
- }
- token, err = CreateBearerToken(id, creator.Crypto())
+ id, exp, err = storage.CreateAccessToken(ctx, tokenRequest)
return
}
-func CreateBearerToken(id string, crypto Crypto) (string, error) {
- return crypto.Encrypt(id)
-}
-
-func CreateJWT(issuer string, authReq AuthRequest, exp time.Time, id string, signer Signer) (string, error) {
- now := time.Now().UTC()
- nbf := now
- claims := &oidc.AccessTokenClaims{
- Issuer: issuer,
- Subject: authReq.GetSubject(),
- Audiences: authReq.GetAudience(),
- Expiration: exp,
- IssuedAt: now,
- NotBefore: nbf,
- JWTID: id,
+func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool {
+ switch req := tokenRequest.(type) {
+ case AuthRequest:
+ return slices.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
+ case TokenExchangeRequest:
+ return req.GetRequestedTokenType() == oidc.RefreshTokenType
+ case RefreshTokenRequest:
+ return true
+ case *DeviceAuthorizationState:
+ return slices.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
+ default:
+ return false
}
- return signer.SignAccessToken(claims)
}
-func CreateIDToken(ctx context.Context, issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, storage Storage, signer Signer) (string, error) {
- var err error
- exp := time.Now().UTC().Add(validity)
- userinfo, err := storage.GetUserinfoFromScopes(ctx, authReq.GetSubject(), authReq.GetScopes())
+func CreateAccessToken(ctx context.Context, tokenRequest TokenRequest, accessTokenType AccessTokenType, creator TokenCreator, client AccessTokenClient, refreshToken string) (accessToken, newRefreshToken string, validity time.Duration, err error) {
+ ctx, span := tracer.Start(ctx, "CreateAccessToken")
+ defer span.End()
+
+ id, newRefreshToken, exp, err := createTokens(ctx, tokenRequest, creator.Storage(), refreshToken, client)
if err != nil {
-
+ return "", "", 0, err
}
- claims := &oidc.IDTokenClaims{
- Issuer: issuer,
- Audiences: authReq.GetAudience(),
- Expiration: exp,
- IssuedAt: time.Now().UTC(),
- AuthTime: authReq.GetAuthTime(),
- Nonce: authReq.GetNonce(),
- AuthenticationContextClassReference: authReq.GetACR(),
- AuthenticationMethodsReferences: authReq.GetAMR(),
- AuthorizedParty: authReq.GetClientID(),
- Userinfo: *userinfo,
+ var clockSkew time.Duration
+ if client != nil {
+ clockSkew = client.ClockSkew()
+ }
+ validity = exp.Add(clockSkew).Sub(time.Now().UTC())
+ if accessTokenType == AccessTokenTypeJWT {
+ accessToken, err = CreateJWT(ctx, IssuerFromContext(ctx), tokenRequest, exp, id, client, creator.Storage())
+ return
+ }
+ _, span = tracer.Start(ctx, "CreateBearerToken")
+ accessToken, err = CreateBearerToken(id, tokenRequest.GetSubject(), creator.Crypto())
+ span.End()
+ return
+}
+
+func CreateBearerToken(tokenID, subject string, crypto Crypto) (string, error) {
+ return crypto.Encrypt(tokenID + ":" + subject)
+}
+
+type TokenActorRequest interface {
+ GetActor() *oidc.ActorClaims
+}
+
+func CreateJWT(ctx context.Context, issuer string, tokenRequest TokenRequest, exp time.Time, id string, client AccessTokenClient, storage Storage) (string, error) {
+ ctx, span := tracer.Start(ctx, "CreateJWT")
+ defer span.End()
+
+ claims := oidc.NewAccessTokenClaims(issuer, tokenRequest.GetSubject(), tokenRequest.GetAudience(), exp, id, client.GetID(), client.ClockSkew())
+ if client != nil {
+ restrictedScopes := client.RestrictAdditionalAccessTokenScopes()(tokenRequest.GetScopes())
+
+ var (
+ privateClaims map[string]any
+ err error
+ )
+
+ tokenExchangeRequest, okReq := tokenRequest.(TokenExchangeRequest)
+ teStorage, okStorage := storage.(TokenExchangeStorage)
+ if okReq && okStorage {
+ privateClaims, err = teStorage.GetPrivateClaimsFromTokenExchangeRequest(
+ ctx,
+ tokenExchangeRequest,
+ )
+ } else {
+ if fromRequest, ok := storage.(CanGetPrivateClaimsFromRequest); ok {
+ privateClaims, err = fromRequest.GetPrivateClaimsFromRequest(ctx, tokenRequest, removeUserinfoScopes(restrictedScopes))
+ } else {
+ privateClaims, err = storage.GetPrivateClaimsFromScopes(ctx, tokenRequest.GetSubject(), client.GetID(), removeUserinfoScopes(restrictedScopes))
+ }
+ }
+
+ if err != nil {
+ return "", err
+ }
+ claims.Claims = privateClaims
+ }
+ if actorReq, ok := tokenRequest.(TokenActorRequest); ok {
+ claims.Actor = actorReq.GetActor()
+ }
+ signingKey, err := storage.SigningKey(ctx)
+ if err != nil {
+ return "", err
+ }
+ signer, err := SignerFromKey(signingKey)
+ if err != nil {
+ return "", err
+ }
+ return crypto.Sign(claims, signer)
+}
+
+type IDTokenRequest interface {
+ GetAMR() []string
+ GetAudience() []string
+ GetAuthTime() time.Time
+ GetClientID() string
+ GetScopes() []string
+ GetSubject() string
+}
+
+func CreateIDToken(ctx context.Context, issuer string, request IDTokenRequest, validity time.Duration, accessToken, code string, storage Storage, client Client) (string, error) {
+ ctx, span := tracer.Start(ctx, "CreateIDToken")
+ defer span.End()
+
+ exp := time.Now().UTC().Add(client.ClockSkew()).Add(validity)
+ var acr, nonce string
+ if authRequest, ok := request.(AuthRequest); ok {
+ acr = authRequest.GetACR()
+ nonce = authRequest.GetNonce()
+ }
+ claims := oidc.NewIDTokenClaims(issuer, request.GetSubject(), request.GetAudience(), exp, request.GetAuthTime(), nonce, acr, request.GetAMR(), request.GetClientID(), client.ClockSkew())
+ if actorReq, ok := request.(TokenActorRequest); ok {
+ claims.Actor = actorReq.GetActor()
+ }
+
+ scopes := client.RestrictAdditionalIdTokenScopes()(request.GetScopes())
+ signingKey, err := storage.SigningKey(ctx)
+ if err != nil {
+ return "", err
}
if accessToken != "" {
- claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
+ atHash, err := oidc.ClaimHash(accessToken, signingKey.SignatureAlgorithm())
if err != nil {
return "", err
}
- }
- if code != "" {
- claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
- if err != nil {
- return "", err
+ claims.AccessTokenHash = atHash
+ if !client.IDTokenUserinfoClaimsAssertion() {
+ scopes = removeUserinfoScopes(scopes)
}
}
- return signer.SignIDToken(claims)
+ tokenExchangeRequest, okReq := request.(TokenExchangeRequest)
+ teStorage, okStorage := storage.(TokenExchangeStorage)
+ if okReq && okStorage {
+ userInfo := new(oidc.UserInfo)
+ err := teStorage.SetUserinfoFromTokenExchangeRequest(ctx, userInfo, tokenExchangeRequest)
+ if err != nil {
+ return "", err
+ }
+ claims.SetUserInfo(userInfo)
+ } else if len(scopes) > 0 {
+ userInfo := new(oidc.UserInfo)
+ err := storage.SetUserinfoFromScopes(ctx, userInfo, request.GetSubject(), request.GetClientID(), scopes)
+ if err != nil {
+ return "", err
+ }
+ if fromRequest, ok := storage.(CanSetUserinfoFromRequest); ok {
+ err := fromRequest.SetUserinfoFromRequest(ctx, userInfo, request, scopes)
+ if err != nil {
+ return "", err
+ }
+ }
+ claims.SetUserInfo(userInfo)
+ }
+ if code != "" {
+ codeHash, err := oidc.ClaimHash(code, signingKey.SignatureAlgorithm())
+ if err != nil {
+ return "", err
+ }
+ claims.CodeHash = codeHash
+ }
+ signer, err := SignerFromKey(signingKey)
+ if err != nil {
+ return "", err
+ }
+ return crypto.Sign(claims, signer)
+}
+
+func removeUserinfoScopes(scopes []string) []string {
+ newScopeList := make([]string, 0, len(scopes))
+ for _, scope := range scopes {
+ switch scope {
+ case oidc.ScopeProfile,
+ oidc.ScopeEmail,
+ oidc.ScopeAddress,
+ oidc.ScopePhone:
+ continue
+ default:
+ newScopeList = append(newScopeList, scope)
+ }
+ }
+ return newScopeList
}
diff --git a/pkg/op/token_client_credentials.go b/pkg/op/token_client_credentials.go
new file mode 100644
index 0000000..ddb2fbf
--- /dev/null
+++ b/pkg/op/token_client_credentials.go
@@ -0,0 +1,125 @@
+package op
+
+import (
+ "context"
+ "net/http"
+ "net/url"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+// ClientCredentialsExchange handles the OAuth 2.0 client_credentials grant, including
+// parsing, validating, authorizing the client and finally returning a token
+func ClientCredentialsExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
+ ctx, span := tracer.Start(r.Context(), "ClientCredentialsExchange")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ request, err := ParseClientCredentialsRequest(r, exchanger.Decoder())
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ }
+
+ validatedRequest, client, err := ValidateClientCredentialsRequest(r.Context(), request, exchanger)
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+
+ resp, err := CreateClientCredentialsTokenResponse(r.Context(), validatedRequest, exchanger, client)
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+
+ httphelper.MarshalJSON(w, resp)
+}
+
+// ParseClientCredentialsRequest parsed the http request into a oidc.ClientCredentialsRequest
+func ParseClientCredentialsRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.ClientCredentialsRequest, error) {
+ err := r.ParseForm()
+ if err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
+ }
+
+ request := new(oidc.ClientCredentialsRequest)
+ err = decoder.Decode(request, r.Form)
+ if err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
+ }
+
+ if clientID, clientSecret, ok := r.BasicAuth(); ok {
+ clientID, err = url.QueryUnescape(clientID)
+ if err != nil {
+ return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+
+ clientSecret, err = url.QueryUnescape(clientSecret)
+ if err != nil {
+ return nil, oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+
+ request.ClientID = clientID
+ request.ClientSecret = clientSecret
+ }
+
+ return request, nil
+}
+
+// ValidateClientCredentialsRequest validates the client_credentials request parameters including authorization check of the client
+// and returns a TokenRequest and Client implementation to be used in the client_credentials response, resp. creation of the corresponding access_token.
+func ValidateClientCredentialsRequest(ctx context.Context, request *oidc.ClientCredentialsRequest, exchanger Exchanger) (TokenRequest, Client, error) {
+ ctx, span := tracer.Start(ctx, "ValidateClientCredentialsRequest")
+ defer span.End()
+
+ storage, ok := exchanger.Storage().(ClientCredentialsStorage)
+ if !ok {
+ return nil, nil, oidc.ErrUnsupportedGrantType().WithDescription("client_credentials grant not supported")
+ }
+
+ client, err := AuthorizeClientCredentialsClient(ctx, request, storage)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ tokenRequest, err := storage.ClientCredentialsTokenRequest(ctx, request.ClientID, request.Scope)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return tokenRequest, client, nil
+}
+
+func AuthorizeClientCredentialsClient(ctx context.Context, request *oidc.ClientCredentialsRequest, storage ClientCredentialsStorage) (Client, error) {
+ ctx, span := tracer.Start(ctx, "AuthorizeClientCredentialsClient")
+ defer span.End()
+
+ client, err := storage.ClientCredentials(ctx, request.ClientID, request.ClientSecret)
+ if err != nil {
+ return nil, oidc.ErrInvalidClient().WithParent(err)
+ }
+
+ if !ValidateGrantType(client, oidc.GrantTypeClientCredentials) {
+ return nil, oidc.ErrUnauthorizedClient()
+ }
+
+ return client, nil
+}
+
+func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) {
+ ctx, span := tracer.Start(ctx, "CreateClientCredentialsTokenResponse")
+ defer span.End()
+
+ accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "")
+ if err != nil {
+ return nil, err
+ }
+
+ return &oidc.AccessTokenResponse{
+ AccessToken: accessToken,
+ TokenType: oidc.BearerToken,
+ ExpiresIn: uint64(validity.Seconds()),
+ Scope: tokenRequest.GetScopes(),
+ }, nil
+}
diff --git a/pkg/op/token_code.go b/pkg/op/token_code.go
new file mode 100644
index 0000000..155aa43
--- /dev/null
+++ b/pkg/op/token_code.go
@@ -0,0 +1,125 @@
+package op
+
+import (
+ "context"
+ "net/http"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+// CodeExchange handles the OAuth 2.0 authorization_code grant, including
+// parsing, validating, authorizing the client and finally exchanging the code for tokens
+func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
+ ctx, span := tracer.Start(r.Context(), "CodeExchange")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ }
+ if tokenReq.Code == "" {
+ RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("code missing"), exchanger.Logger())
+ return
+ }
+ authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+ resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code, "")
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+ httphelper.MarshalJSON(w, resp)
+}
+
+// ParseAccessTokenRequest parsed the http request into a oidc.AccessTokenRequest
+func ParseAccessTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.AccessTokenRequest, error) {
+ request := new(oidc.AccessTokenRequest)
+ err := ParseAuthenticatedTokenRequest(r, decoder, request)
+ if err != nil {
+ return nil, err
+ }
+ return request, nil
+}
+
+// ValidateAccessTokenRequest validates the token request parameters including authorization check of the client
+// and returns the previous created auth request corresponding to the auth code
+func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
+ ctx, span := tracer.Start(ctx, "ValidateAccessTokenRequest")
+ defer span.End()
+
+ authReq, client, err := AuthorizeCodeClient(ctx, tokenReq, exchanger)
+ if err != nil {
+ return nil, nil, err
+ }
+ if client.GetID() != authReq.GetClientID() {
+ return nil, nil, oidc.ErrInvalidGrant()
+ }
+ if !ValidateGrantType(client, oidc.GrantTypeCode) {
+ return nil, nil, oidc.ErrUnauthorizedClient().WithDescription("client missing grant type " + string(oidc.GrantTypeCode))
+ }
+ if tokenReq.RedirectURI != authReq.GetRedirectURI() {
+ return nil, nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
+ }
+ return authReq, client, nil
+}
+
+// AuthorizeCodeClient checks the authorization of the client and that the used method was the one previously registered.
+// It than returns the auth request corresponding to the auth code
+func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (request AuthRequest, client Client, err error) {
+ ctx, span := tracer.Start(ctx, "AuthorizeCodeClient")
+ defer span.End()
+
+ if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
+ jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger)
+ if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() {
+ return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
+ }
+ client, err = AuthorizePrivateJWTKey(ctx, tokenReq.ClientAssertion, jwtExchanger)
+ if err != nil {
+ return nil, nil, err
+ }
+ request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
+ return request, client, err
+ }
+ client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
+ if err != nil {
+ return nil, nil, oidc.ErrInvalidClient().WithParent(err)
+ }
+ if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT {
+ return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
+ }
+ if client.AuthMethod() == oidc.AuthMethodNone {
+ request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
+ if err != nil {
+ return nil, nil, err
+ }
+ err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge())
+ return request, client, err
+ }
+ if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
+ return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
+ }
+ err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
+ if err != nil {
+ return nil, nil, err
+ }
+ request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
+ return request, client, err
+}
+
+// AuthRequestByCode returns the AuthRequest previously created from Storage corresponding to the auth code or an error
+func AuthRequestByCode(ctx context.Context, storage Storage, code string) (AuthRequest, error) {
+ ctx, span := tracer.Start(ctx, "AuthRequestByCode")
+ defer span.End()
+
+ authReq, err := storage.AuthRequestByCode(ctx, code)
+ if err != nil {
+ return nil, oidc.ErrInvalidGrant().WithDescription("invalid code").WithParent(err)
+ }
+ return authReq, nil
+}
diff --git a/pkg/op/token_exchange.go b/pkg/op/token_exchange.go
new file mode 100644
index 0000000..00af485
--- /dev/null
+++ b/pkg/op/token_exchange.go
@@ -0,0 +1,432 @@
+package op
+
+import (
+ "context"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type TokenExchangeRequest interface {
+ GetAMR() []string
+ GetAudience() []string
+ GetResourses() []string
+ GetAuthTime() time.Time
+ GetClientID() string
+ GetScopes() []string
+ GetSubject() string
+ GetRequestedTokenType() oidc.TokenType
+
+ GetExchangeSubject() string
+ GetExchangeSubjectTokenType() oidc.TokenType
+ GetExchangeSubjectTokenIDOrToken() string
+ GetExchangeSubjectTokenClaims() map[string]any
+
+ GetExchangeActor() string
+ GetExchangeActorTokenType() oidc.TokenType
+ GetExchangeActorTokenIDOrToken() string
+ GetExchangeActorTokenClaims() map[string]any
+
+ SetCurrentScopes(scopes []string)
+ SetRequestedTokenType(tt oidc.TokenType)
+ SetSubject(subject string)
+}
+
+type tokenExchangeRequest struct {
+ exchangeSubjectTokenIDOrToken string
+ exchangeSubjectTokenType oidc.TokenType
+ exchangeSubject string
+ exchangeSubjectTokenClaims map[string]any
+
+ exchangeActorTokenIDOrToken string
+ exchangeActorTokenType oidc.TokenType
+ exchangeActor string
+ exchangeActorTokenClaims map[string]any
+
+ resource []string
+ audience oidc.Audience
+ scopes oidc.SpaceDelimitedArray
+ requestedTokenType oidc.TokenType
+ clientID string
+ authTime time.Time
+ subject string
+}
+
+func (r *tokenExchangeRequest) GetAMR() []string {
+ return []string{}
+}
+
+func (r *tokenExchangeRequest) GetAudience() []string {
+ return r.audience
+}
+
+func (r *tokenExchangeRequest) GetResourses() []string {
+ return r.resource
+}
+
+func (r *tokenExchangeRequest) GetAuthTime() time.Time {
+ return r.authTime
+}
+
+func (r *tokenExchangeRequest) GetClientID() string {
+ return r.clientID
+}
+
+func (r *tokenExchangeRequest) GetScopes() []string {
+ return r.scopes
+}
+
+func (r *tokenExchangeRequest) GetRequestedTokenType() oidc.TokenType {
+ return r.requestedTokenType
+}
+
+func (r *tokenExchangeRequest) GetExchangeSubject() string {
+ return r.exchangeSubject
+}
+
+func (r *tokenExchangeRequest) GetExchangeSubjectTokenType() oidc.TokenType {
+ return r.exchangeSubjectTokenType
+}
+
+func (r *tokenExchangeRequest) GetExchangeSubjectTokenIDOrToken() string {
+ return r.exchangeSubjectTokenIDOrToken
+}
+
+func (r *tokenExchangeRequest) GetExchangeSubjectTokenClaims() map[string]any {
+ return r.exchangeSubjectTokenClaims
+}
+
+func (r *tokenExchangeRequest) GetExchangeActor() string {
+ return r.exchangeActor
+}
+
+func (r *tokenExchangeRequest) GetExchangeActorTokenType() oidc.TokenType {
+ return r.exchangeActorTokenType
+}
+
+func (r *tokenExchangeRequest) GetExchangeActorTokenIDOrToken() string {
+ return r.exchangeActorTokenIDOrToken
+}
+
+func (r *tokenExchangeRequest) GetExchangeActorTokenClaims() map[string]any {
+ return r.exchangeActorTokenClaims
+}
+
+func (r *tokenExchangeRequest) GetSubject() string {
+ return r.subject
+}
+
+func (r *tokenExchangeRequest) SetCurrentScopes(scopes []string) {
+ r.scopes = scopes
+}
+
+func (r *tokenExchangeRequest) SetRequestedTokenType(tt oidc.TokenType) {
+ r.requestedTokenType = tt
+}
+
+func (r *tokenExchangeRequest) SetSubject(subject string) {
+ r.subject = subject
+}
+
+// TokenExchange handles the OAuth 2.0 token exchange grant ("urn:ietf:params:oauth:grant-type:token-exchange")
+func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
+ ctx, span := tracer.Start(r.Context(), "TokenExchange")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ tokenExchangeReq, clientID, clientSecret, err := ParseTokenExchangeRequest(r, exchanger.Decoder())
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ }
+
+ tokenExchangeRequest, client, err := ValidateTokenExchangeRequest(r.Context(), tokenExchangeReq, clientID, clientSecret, exchanger)
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+ resp, err := CreateTokenExchangeResponse(r.Context(), tokenExchangeRequest, client, exchanger)
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+ httphelper.MarshalJSON(w, resp)
+}
+
+// ParseTokenExchangeRequest parses the http request into oidc.TokenExchangeRequest
+func ParseTokenExchangeRequest(r *http.Request, decoder httphelper.Decoder) (_ *oidc.TokenExchangeRequest, clientID, clientSecret string, err error) {
+ err = r.ParseForm()
+ if err != nil {
+ return nil, "", "", oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
+ }
+
+ request := new(oidc.TokenExchangeRequest)
+ err = decoder.Decode(request, r.Form)
+ if err != nil {
+ return nil, "", "", oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
+ }
+
+ var ok bool
+ if clientID, clientSecret, ok = r.BasicAuth(); ok {
+ clientID, err = url.QueryUnescape(clientID)
+ if err != nil {
+ return nil, "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+
+ clientSecret, err = url.QueryUnescape(clientSecret)
+ if err != nil {
+ return nil, "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+ }
+
+ return request, clientID, clientSecret, nil
+}
+
+// ValidateTokenExchangeRequest validates the token exchange request parameters including authorization check of the client,
+// subject_token and actor_token
+func ValidateTokenExchangeRequest(
+ ctx context.Context,
+ oidcTokenExchangeRequest *oidc.TokenExchangeRequest,
+ clientID, clientSecret string,
+ exchanger Exchanger,
+) (TokenExchangeRequest, Client, error) {
+ ctx, span := tracer.Start(ctx, "ValidateTokenExchangeRequest")
+ defer span.End()
+
+ if oidcTokenExchangeRequest.SubjectToken == "" {
+ return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token missing")
+ }
+
+ if oidcTokenExchangeRequest.SubjectTokenType == "" {
+ return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type missing")
+ }
+
+ client, err := AuthorizeTokenExchangeClient(ctx, clientID, clientSecret, exchanger)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if oidcTokenExchangeRequest.RequestedTokenType != "" && !oidcTokenExchangeRequest.RequestedTokenType.IsSupported() {
+ return nil, nil, oidc.ErrInvalidRequest().WithDescription("requested_token_type is not supported")
+ }
+
+ if !oidcTokenExchangeRequest.SubjectTokenType.IsSupported() {
+ return nil, nil, oidc.ErrInvalidRequest().WithDescription("subject_token_type is not supported")
+ }
+
+ if oidcTokenExchangeRequest.ActorTokenType != "" && !oidcTokenExchangeRequest.ActorTokenType.IsSupported() {
+ return nil, nil, oidc.ErrInvalidRequest().WithDescription("actor_token_type is not supported")
+ }
+
+ req, err := CreateTokenExchangeRequest(ctx, oidcTokenExchangeRequest, client, exchanger)
+ if err != nil {
+ return nil, nil, err
+ }
+ return req, client, nil
+}
+
+func CreateTokenExchangeRequest(
+ ctx context.Context,
+ oidcTokenExchangeRequest *oidc.TokenExchangeRequest,
+ client Client,
+ exchanger Exchanger,
+) (TokenExchangeRequest, error) {
+ ctx, span := tracer.Start(ctx, "CreateTokenExchangeRequest")
+ defer span.End()
+
+ teStorage, ok := exchanger.Storage().(TokenExchangeStorage)
+ if !ok {
+ return nil, unimplementedGrantError(oidc.GrantTypeTokenExchange)
+ }
+
+ exchangeSubjectTokenIDOrToken, exchangeSubject, exchangeSubjectTokenClaims, ok := GetTokenIDAndSubjectFromToken(ctx, exchanger,
+ oidcTokenExchangeRequest.SubjectToken, oidcTokenExchangeRequest.SubjectTokenType, false)
+ if !ok {
+ return nil, oidc.ErrInvalidRequest().WithDescription("subject_token is invalid")
+ }
+
+ var (
+ exchangeActorTokenIDOrToken, exchangeActor string
+ exchangeActorTokenClaims map[string]any
+ )
+ if oidcTokenExchangeRequest.ActorToken != "" {
+ exchangeActorTokenIDOrToken, exchangeActor, exchangeActorTokenClaims, ok = GetTokenIDAndSubjectFromToken(ctx, exchanger,
+ oidcTokenExchangeRequest.ActorToken, oidcTokenExchangeRequest.ActorTokenType, true)
+ if !ok {
+ return nil, oidc.ErrInvalidRequest().WithDescription("actor_token is invalid")
+ }
+ }
+
+ req := &tokenExchangeRequest{
+ exchangeSubjectTokenIDOrToken: exchangeSubjectTokenIDOrToken,
+ exchangeSubjectTokenType: oidcTokenExchangeRequest.SubjectTokenType,
+ exchangeSubject: exchangeSubject,
+ exchangeSubjectTokenClaims: exchangeSubjectTokenClaims,
+
+ exchangeActorTokenIDOrToken: exchangeActorTokenIDOrToken,
+ exchangeActorTokenType: oidcTokenExchangeRequest.ActorTokenType,
+ exchangeActor: exchangeActor,
+ exchangeActorTokenClaims: exchangeActorTokenClaims,
+
+ subject: exchangeSubject,
+ resource: oidcTokenExchangeRequest.Resource,
+ audience: oidcTokenExchangeRequest.Audience,
+ scopes: oidcTokenExchangeRequest.Scopes,
+ requestedTokenType: oidcTokenExchangeRequest.RequestedTokenType,
+ clientID: client.GetID(),
+ authTime: time.Now(),
+ }
+
+ err := teStorage.ValidateTokenExchangeRequest(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ err = teStorage.CreateTokenExchangeRequest(ctx, req)
+ if err != nil {
+ return nil, err
+ }
+
+ return req, nil
+}
+
+func GetTokenIDAndSubjectFromToken(
+ ctx context.Context,
+ exchanger Exchanger,
+ token string,
+ tokenType oidc.TokenType,
+ isActor bool,
+) (tokenIDOrToken, subject string, claims map[string]any, ok bool) {
+ ctx, span := tracer.Start(ctx, "GetTokenIDAndSubjectFromToken")
+ defer span.End()
+
+ switch tokenType {
+ case oidc.AccessTokenType:
+ var accessTokenClaims *oidc.AccessTokenClaims
+ tokenIDOrToken, subject, accessTokenClaims, ok = getTokenIDAndClaims(ctx, exchanger, token)
+ if !ok {
+ break
+ }
+ claims = accessTokenClaims.Claims
+ case oidc.RefreshTokenType:
+ refreshTokenRequest, err := exchanger.Storage().TokenRequestByRefreshToken(ctx, token)
+ if err != nil {
+ break
+ }
+
+ tokenIDOrToken, subject, ok = token, refreshTokenRequest.GetSubject(), true
+ case oidc.IDTokenType:
+ idTokenClaims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, token, exchanger.IDTokenHintVerifier(ctx))
+ if err != nil {
+ break
+ }
+
+ tokenIDOrToken, subject, claims, ok = token, idTokenClaims.Subject, idTokenClaims.Claims, true
+ }
+
+ if !ok {
+ if verifier, ok := exchanger.Storage().(TokenExchangeTokensVerifierStorage); ok {
+ var err error
+ if isActor {
+ tokenIDOrToken, subject, claims, err = verifier.VerifyExchangeActorToken(ctx, token, tokenType)
+ } else {
+ tokenIDOrToken, subject, claims, err = verifier.VerifyExchangeSubjectToken(ctx, token, tokenType)
+ }
+ if err != nil {
+ return "", "", nil, false
+ }
+
+ return tokenIDOrToken, subject, claims, true
+ }
+
+ return "", "", nil, false
+ }
+
+ return tokenIDOrToken, subject, claims, true
+}
+
+// AuthorizeTokenExchangeClient authorizes a client by validating the client_id and client_secret
+func AuthorizeTokenExchangeClient(ctx context.Context, clientID, clientSecret string, exchanger Exchanger) (client Client, err error) {
+ ctx, span := tracer.Start(ctx, "AuthorizeTokenExchangeClient")
+ defer span.End()
+
+ if err := AuthorizeClientIDSecret(ctx, clientID, clientSecret, exchanger.Storage()); err != nil {
+ return nil, err
+ }
+
+ client, err = exchanger.Storage().GetClientByClientID(ctx, clientID)
+ if err != nil {
+ return nil, oidc.ErrInvalidClient().WithParent(err)
+ }
+
+ return client, nil
+}
+
+func CreateTokenExchangeResponse(
+ ctx context.Context,
+ tokenExchangeRequest TokenExchangeRequest,
+ client Client,
+ creator TokenCreator,
+) (_ *oidc.TokenExchangeResponse, err error) {
+ ctx, span := tracer.Start(ctx, "CreateTokenExchangeResponse")
+ defer span.End()
+
+ var (
+ token, refreshToken, tokenType string
+ validity time.Duration
+ )
+
+ switch tokenExchangeRequest.GetRequestedTokenType() {
+ case oidc.AccessTokenType, oidc.RefreshTokenType:
+ token, refreshToken, validity, err = CreateAccessToken(ctx, tokenExchangeRequest, client.AccessTokenType(), creator, client, "")
+ if err != nil {
+ return nil, err
+ }
+
+ tokenType = oidc.BearerToken
+ case oidc.IDTokenType:
+ token, err = CreateIDToken(ctx, IssuerFromContext(ctx), tokenExchangeRequest, client.IDTokenLifetime(), "", "", creator.Storage(), client)
+ if err != nil {
+ return nil, err
+ }
+
+ // not applicable (see https://datatracker.ietf.org/doc/html/rfc8693#section-2-2-1-2-6)
+ tokenType = "N_A"
+ default:
+ // oidc.JWTTokenType and other custom token types are not supported for issuing.
+ // In the future it can be considered to have custom tokens generation logic injected via op configuration
+ // or via expanding Storage interface
+ oidc.ErrInvalidRequest().WithDescription("requested_token_type is invalid")
+ }
+
+ exp := uint64(validity.Seconds())
+ return &oidc.TokenExchangeResponse{
+ AccessToken: token,
+ IssuedTokenType: tokenExchangeRequest.GetRequestedTokenType(),
+ TokenType: tokenType,
+ ExpiresIn: exp,
+ RefreshToken: refreshToken,
+ Scopes: tokenExchangeRequest.GetScopes(),
+ }, nil
+}
+
+func getTokenIDAndClaims(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, *oidc.AccessTokenClaims, bool) {
+ tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
+ if err == nil {
+ splitToken := strings.Split(tokenIDSubject, ":")
+ if len(splitToken) != 2 {
+ return "", "", nil, false
+ }
+
+ return splitToken[0], splitToken[1], nil, true
+ }
+ accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
+ if err != nil {
+ return "", "", nil, false
+ }
+
+ return accessTokenClaims.JWTID, accessTokenClaims.Subject, accessTokenClaims, true
+}
diff --git a/pkg/op/token_intospection.go b/pkg/op/token_intospection.go
new file mode 100644
index 0000000..bb6a5a0
--- /dev/null
+++ b/pkg/op/token_intospection.go
@@ -0,0 +1,76 @@
+package op
+
+import (
+ "context"
+ "errors"
+ "net/http"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type Introspector interface {
+ Decoder() httphelper.Decoder
+ Crypto() Crypto
+ Storage() Storage
+ AccessTokenVerifier(context.Context) *AccessTokenVerifier
+}
+
+type IntrospectorJWTProfile interface {
+ Introspector
+ JWTProfileVerifier(context.Context) JWTProfileVerifier
+}
+
+func introspectionHandler(introspector Introspector) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ Introspect(w, r, introspector)
+ }
+}
+
+func Introspect(w http.ResponseWriter, r *http.Request, introspector Introspector) {
+ ctx, span := tracer.Start(r.Context(), "Introspect")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ response := new(oidc.IntrospectionResponse)
+ token, clientID, err := ParseTokenIntrospectionRequest(r, introspector)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusUnauthorized)
+ return
+ }
+ tokenID, subject, ok := getTokenIDAndSubject(r.Context(), introspector, token)
+ if !ok {
+ httphelper.MarshalJSON(w, response)
+ return
+ }
+ err = introspector.Storage().SetIntrospectionFromToken(r.Context(), response, tokenID, subject, clientID)
+ if err != nil {
+ httphelper.MarshalJSON(w, response)
+ return
+ }
+ response.Active = true
+ httphelper.MarshalJSON(w, response)
+}
+
+func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector) (token, clientID string, err error) {
+ clientID, authenticated, err := ClientIDFromRequest(r, introspector)
+ if err != nil {
+ return "", "", err
+ }
+ if !authenticated {
+ return "", "", oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials)
+ }
+
+ req := new(oidc.IntrospectionRequest)
+ err = introspector.Decoder().Decode(req, r.Form)
+ if err != nil {
+ return "", "", errors.New("unable to parse request")
+ }
+
+ return req.Token, clientID, nil
+}
+
+type IntrospectionRequest struct {
+ *ClientCredentials
+ *oidc.IntrospectionRequest
+}
diff --git a/pkg/op/token_jwt_profile.go b/pkg/op/token_jwt_profile.go
new file mode 100644
index 0000000..defb937
--- /dev/null
+++ b/pkg/op/token_jwt_profile.go
@@ -0,0 +1,125 @@
+package op
+
+import (
+ "context"
+ "net/http"
+ "time"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type JWTAuthorizationGrantExchanger interface {
+ Exchanger
+ JWTProfileVerifier(context.Context) *JWTProfileVerifier
+}
+
+// JWTProfile handles the OAuth 2.0 JWT Profile Authorization Grant https://tools.ietf.org/html/rfc7523#section-2.1
+func JWTProfile(w http.ResponseWriter, r *http.Request, exchanger JWTAuthorizationGrantExchanger) {
+ ctx, span := tracer.Start(r.Context(), "JWTProfile")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ profileRequest, err := ParseJWTProfileGrantRequest(r, exchanger.Decoder())
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ }
+
+ tokenRequest, err := VerifyJWTAssertion(r.Context(), profileRequest.Assertion, exchanger.JWTProfileVerifier(r.Context()))
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+
+ tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(r.Context(), tokenRequest.Issuer, profileRequest.Scope)
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+ resp, err := CreateJWTTokenResponse(r.Context(), tokenRequest, exchanger)
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+ httphelper.MarshalJSON(w, resp)
+}
+
+func ParseJWTProfileGrantRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) {
+ err := r.ParseForm()
+ if err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
+ }
+ tokenReq := new(oidc.JWTProfileGrantRequest)
+ err = decoder.Decode(tokenReq, r.Form)
+ if err != nil {
+ return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
+ }
+ return tokenReq, nil
+}
+
+// CreateJWTTokenResponse creates an access_token response for a JWT Profile Grant request
+// by default the access_token is an opaque string, but can be specified by implementing the JWTProfileTokenStorage interface
+func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator) (*oidc.AccessTokenResponse, error) {
+ ctx, span := tracer.Start(ctx, "CreateJWTTokenResponse")
+ defer span.End()
+
+ // return an opaque token as default to not break current implementations
+ tokenType := AccessTokenTypeBearer
+
+ // the current CreateAccessToken function, esp. CreateJWT requires an implementation of an AccessTokenClient
+ client := &jwtProfileClient{
+ id: tokenRequest.GetSubject(),
+ }
+
+ // by implementing the JWTProfileTokenStorage the storage can specify the AccessTokenType to be returned
+ tokenStorage, ok := creator.Storage().(JWTProfileTokenStorage)
+ if ok {
+ var err error
+ tokenType, err = tokenStorage.JWTProfileTokenType(ctx, tokenRequest)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ accessToken, _, validity, err := CreateAccessToken(ctx, tokenRequest, tokenType, creator, client, "")
+ if err != nil {
+ return nil, err
+ }
+ return &oidc.AccessTokenResponse{
+ AccessToken: accessToken,
+ TokenType: oidc.BearerToken,
+ ExpiresIn: uint64(validity.Seconds()),
+ Scope: tokenRequest.GetScopes(),
+ }, nil
+}
+
+// ParseJWTProfileRequest has been renamed to ParseJWTProfileGrantRequest
+//
+// deprecated: use ParseJWTProfileGrantRequest
+func ParseJWTProfileRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.JWTProfileGrantRequest, error) {
+ return ParseJWTProfileGrantRequest(r, decoder)
+}
+
+type jwtProfileClient struct {
+ id string
+}
+
+func (j *jwtProfileClient) GetID() string {
+ return j.id
+}
+
+func (j *jwtProfileClient) ClockSkew() time.Duration {
+ return 0
+}
+
+func (j *jwtProfileClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
+ return func(scopes []string) []string {
+ return scopes
+ }
+}
+
+func (j *jwtProfileClient) GrantTypes() []oidc.GrantType {
+ return []oidc.GrantType{
+ oidc.GrantTypeBearer,
+ }
+}
diff --git a/pkg/op/token_refresh.go b/pkg/op/token_refresh.go
new file mode 100644
index 0000000..a87e883
--- /dev/null
+++ b/pkg/op/token_refresh.go
@@ -0,0 +1,152 @@
+package op
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "slices"
+ "time"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type RefreshTokenRequest interface {
+ GetAMR() []string
+ GetAudience() []string
+ GetAuthTime() time.Time
+ GetClientID() string
+ GetScopes() []string
+ GetSubject() string
+ SetCurrentScopes(scopes []string)
+}
+
+// RefreshTokenExchange handles the OAuth 2.0 refresh_token grant, including
+// parsing, validating, authorizing the client and finally exchanging the refresh_token for new tokens
+func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
+ ctx, span := tracer.Start(r.Context(), "RefreshTokenExchange")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder())
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ }
+ validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger)
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+ resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken)
+ if err != nil {
+ RequestError(w, r, err, exchanger.Logger())
+ return
+ }
+ httphelper.MarshalJSON(w, resp)
+}
+
+// ParseRefreshTokenRequest parsed the http request into a oidc.RefreshTokenRequest
+func ParseRefreshTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.RefreshTokenRequest, error) {
+ request := new(oidc.RefreshTokenRequest)
+ err := ParseAuthenticatedTokenRequest(r, decoder, request)
+ if err != nil {
+ return nil, err
+ }
+ return request, nil
+}
+
+// ValidateRefreshTokenRequest validates the refresh_token request parameters including authorization check of the client
+// and returns the data representing the original auth request corresponding to the refresh_token
+func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (RefreshTokenRequest, Client, error) {
+ ctx, span := tracer.Start(ctx, "ValidateRefreshTokenRequest")
+ defer span.End()
+
+ if tokenReq.RefreshToken == "" {
+ return nil, nil, oidc.ErrInvalidRequest().WithDescription("refresh_token missing")
+ }
+ request, client, err := AuthorizeRefreshClient(ctx, tokenReq, exchanger)
+ if err != nil {
+ return nil, nil, err
+ }
+ if client.GetID() != request.GetClientID() {
+ return nil, nil, oidc.ErrInvalidGrant()
+ }
+ if err = ValidateRefreshTokenScopes(tokenReq.Scopes, request); err != nil {
+ return nil, nil, err
+ }
+ return request, client, nil
+}
+
+// ValidateRefreshTokenScopes validates that the requested scope is a subset of the original auth request scope
+// it will set the requested scopes as current scopes onto RefreshTokenRequest
+// if empty the original scopes will be used
+func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTokenRequest) error {
+ if len(requestedScopes) == 0 {
+ return nil
+ }
+ for _, scope := range requestedScopes {
+ if !slices.Contains(authRequest.GetScopes(), scope) {
+ return oidc.ErrInvalidScope()
+ }
+ }
+ authRequest.SetCurrentScopes(requestedScopes)
+ return nil
+}
+
+// AuthorizeRefreshClient checks the authorization of the client and that the used method was the one previously registered.
+// It than returns the data representing the original auth request corresponding to the refresh_token
+func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (request RefreshTokenRequest, client Client, err error) {
+ ctx, span := tracer.Start(ctx, "AuthorizeRefreshClient")
+ defer span.End()
+
+ if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
+ jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger)
+ if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() {
+ return nil, nil, errors.New("auth_method private_key_jwt not supported")
+ }
+ client, err = AuthorizePrivateJWTKey(ctx, tokenReq.ClientAssertion, jwtExchanger)
+ if err != nil {
+ return nil, nil, err
+ }
+ if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
+ return nil, nil, oidc.ErrUnauthorizedClient()
+ }
+ request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
+ return request, client, err
+ }
+ client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
+ if err != nil {
+ return nil, nil, err
+ }
+ if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
+ return nil, nil, oidc.ErrUnauthorizedClient()
+ }
+ if client.AuthMethod() == oidc.AuthMethodPrivateKeyJWT {
+ return nil, nil, oidc.ErrInvalidClient()
+ }
+ if client.AuthMethod() == oidc.AuthMethodNone {
+ request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
+ return request, client, err
+ }
+ if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
+ return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
+ }
+ if err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()); err != nil {
+ return nil, nil, err
+ }
+ request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
+ return request, client, err
+}
+
+// RefreshTokenRequestByRefreshToken returns the RefreshTokenRequest (data representing the original auth request)
+// corresponding to the refresh_token from Storage or an error
+func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) {
+ ctx, span := tracer.Start(ctx, "RefreshTokenRequestByRefreshToken")
+ defer span.End()
+
+ request, err := storage.TokenRequestByRefreshToken(ctx, refreshToken)
+ if err != nil {
+ return nil, oidc.ErrInvalidGrant().WithParent(err)
+ }
+ return request, nil
+}
diff --git a/pkg/op/token_request.go b/pkg/op/token_request.go
new file mode 100644
index 0000000..3f5af7a
--- /dev/null
+++ b/pkg/op/token_request.go
@@ -0,0 +1,183 @@
+package op
+
+import (
+ "context"
+ "log/slog"
+ "net/http"
+ "net/url"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type Exchanger interface {
+ Storage() Storage
+ Decoder() httphelper.Decoder
+ Crypto() Crypto
+ AuthMethodPostSupported() bool
+ AuthMethodPrivateKeyJWTSupported() bool
+ GrantTypeRefreshTokenSupported() bool
+ GrantTypeTokenExchangeSupported() bool
+ GrantTypeJWTAuthorizationSupported() bool
+ GrantTypeClientCredentialsSupported() bool
+ GrantTypeDeviceCodeSupported() bool
+ AccessTokenVerifier(context.Context) *AccessTokenVerifier
+ IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
+ Logger() *slog.Logger
+}
+
+func tokenHandler(exchanger Exchanger) func(w http.ResponseWriter, r *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ ctx, span := tracer.Start(r.Context(), "tokenHandler")
+ defer span.End()
+
+ Exchange(w, r.WithContext(ctx), exchanger)
+ }
+}
+
+// Exchange performs a token exchange appropriate for the grant type
+func Exchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
+ ctx, span := tracer.Start(r.Context(), "Exchange")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ grantType := r.FormValue("grant_type")
+ switch grantType {
+ case string(oidc.GrantTypeCode):
+ CodeExchange(w, r, exchanger)
+ return
+ case string(oidc.GrantTypeRefreshToken):
+ if exchanger.GrantTypeRefreshTokenSupported() {
+ RefreshTokenExchange(w, r, exchanger)
+ return
+ }
+ case string(oidc.GrantTypeBearer):
+ if ex, ok := exchanger.(JWTAuthorizationGrantExchanger); ok && exchanger.GrantTypeJWTAuthorizationSupported() {
+ JWTProfile(w, r, ex)
+ return
+ }
+ case string(oidc.GrantTypeTokenExchange):
+ if exchanger.GrantTypeTokenExchangeSupported() {
+ TokenExchange(w, r, exchanger)
+ return
+ }
+ case string(oidc.GrantTypeClientCredentials):
+ if exchanger.GrantTypeClientCredentialsSupported() {
+ ClientCredentialsExchange(w, r, exchanger)
+ return
+ }
+ case string(oidc.GrantTypeDeviceCode):
+ if exchanger.GrantTypeDeviceCodeSupported() {
+ DeviceAccessToken(w, r, exchanger)
+ return
+ }
+ case "":
+ RequestError(w, r, oidc.ErrInvalidRequest().WithDescription("grant_type missing"), exchanger.Logger())
+ return
+ }
+ RequestError(w, r, oidc.ErrUnsupportedGrantType().WithDescription("%s not supported", grantType), exchanger.Logger())
+}
+
+// AuthenticatedTokenRequest is a helper interface for ParseAuthenticatedTokenRequest
+// it is implemented by oidc.AuthRequest and oidc.RefreshTokenRequest
+type AuthenticatedTokenRequest interface {
+ SetClientID(string)
+ SetClientSecret(string)
+}
+
+// ParseAuthenticatedTokenRequest parses the client_id and client_secret from the HTTP request from either
+// HTTP Basic Auth header or form body and sets them into the provided authenticatedTokenRequest interface
+func ParseAuthenticatedTokenRequest(r *http.Request, decoder httphelper.Decoder, request AuthenticatedTokenRequest) error {
+ ctx, span := tracer.Start(r.Context(), "ParseAuthenticatedTokenRequest")
+ defer span.End()
+ r = r.WithContext(ctx)
+
+ err := r.ParseForm()
+ if err != nil {
+ return oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
+ }
+ err = decoder.Decode(request, r.Form)
+ if err != nil {
+ return oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
+ }
+ clientID, clientSecret, ok := r.BasicAuth()
+ if !ok {
+ return nil
+ }
+ clientID, err = url.QueryUnescape(clientID)
+ if err != nil {
+ return oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+ clientSecret, err = url.QueryUnescape(clientSecret)
+ if err != nil {
+ return oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+ request.SetClientID(clientID)
+ request.SetClientSecret(clientSecret)
+ return nil
+}
+
+// AuthorizeClientIDSecret authorizes a client by validating the client_id and client_secret (Basic Auth and POST)
+func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage Storage) error {
+ ctx, span := tracer.Start(ctx, "AuthorizeClientIDSecret")
+ defer span.End()
+
+ err := storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
+ if err != nil {
+ return oidc.ErrInvalidClient().WithDescription("invalid client_id / client_secret").WithParent(err)
+ }
+ return nil
+}
+
+// AuthorizeCodeChallenge authorizes a client by validating the code_verifier against the previously sent
+// code_challenge of the auth request (PKCE)
+func AuthorizeCodeChallenge(codeVerifier string, challenge *oidc.CodeChallenge) error {
+ if challenge == nil {
+ if codeVerifier != "" {
+ return oidc.ErrInvalidRequest().WithDescription("code_verifier unexpectedly provided")
+ }
+
+ return nil
+ }
+
+ if codeVerifier == "" {
+ return oidc.ErrInvalidRequest().WithDescription("code_verifier required")
+ }
+ if !oidc.VerifyCodeChallenge(challenge, codeVerifier) {
+ return oidc.ErrInvalidGrant().WithDescription("invalid code_verifier")
+ }
+ return nil
+}
+
+// AuthorizePrivateJWTKey authorizes a client by validating the client_assertion's signature with a previously
+// registered public key (JWT Profile)
+func AuthorizePrivateJWTKey(ctx context.Context, clientAssertion string, exchanger JWTAuthorizationGrantExchanger) (Client, error) {
+ ctx, span := tracer.Start(ctx, "AuthorizePrivateJWTKey")
+ defer span.End()
+
+ jwtReq, err := VerifyJWTAssertion(ctx, clientAssertion, exchanger.JWTProfileVerifier(ctx))
+ if err != nil {
+ return nil, err
+ }
+ client, err := exchanger.Storage().GetClientByClientID(ctx, jwtReq.Issuer)
+ if err != nil {
+ return nil, err
+ }
+ if client.AuthMethod() != oidc.AuthMethodPrivateKeyJWT {
+ return nil, oidc.ErrInvalidClient()
+ }
+ return client, nil
+}
+
+// ValidateGrantType ensures that the requested grant_type is allowed by the client
+func ValidateGrantType(client interface{ GrantTypes() []oidc.GrantType }, grantType oidc.GrantType) bool {
+ if client == nil {
+ return false
+ }
+ for _, grant := range client.GrantTypes() {
+ if grantType == grant {
+ return true
+ }
+ }
+ return false
+}
diff --git a/pkg/op/token_request_test.go b/pkg/op/token_request_test.go
new file mode 100644
index 0000000..d226af6
--- /dev/null
+++ b/pkg/op/token_request_test.go
@@ -0,0 +1,75 @@
+package op_test
+
+import (
+ "testing"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestAuthorizeCodeChallenge(t *testing.T) {
+ tests := []struct {
+ name string
+ codeVerifier string
+ codeChallenge *oidc.CodeChallenge
+ want func(t *testing.T, err error)
+ }{
+ {
+ name: "missing both code_verifier and code_challenge",
+ codeVerifier: "",
+ codeChallenge: nil,
+ want: func(t *testing.T, err error) {
+ assert.Nil(t, err)
+ },
+ },
+ {
+ name: "valid code_verifier",
+ codeVerifier: "Hello World!",
+ codeChallenge: &oidc.CodeChallenge{
+ Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk",
+ Method: oidc.CodeChallengeMethodS256,
+ },
+ want: func(t *testing.T, err error) {
+ assert.Nil(t, err)
+ },
+ },
+ {
+ name: "invalid code_verifier",
+ codeVerifier: "Hi World!",
+ codeChallenge: &oidc.CodeChallenge{
+ Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk",
+ Method: oidc.CodeChallengeMethodS256,
+ },
+ want: func(t *testing.T, err error) {
+ assert.ErrorContains(t, err, "invalid code_verifier")
+ },
+ },
+ {
+ name: "code_verifier provided without code_challenge",
+ codeVerifier: "code_verifier",
+ codeChallenge: nil,
+ want: func(t *testing.T, err error) {
+ assert.ErrorContains(t, err, "code_verifier unexpectedly provided")
+ },
+ },
+ {
+ name: "empty code_verifier",
+ codeVerifier: "",
+ codeChallenge: &oidc.CodeChallenge{
+ Challenge: "f4OxZX_x_FO5LcGBSKHWXfwtSx-j1ncoSt3SABJtkGk",
+ Method: oidc.CodeChallengeMethodS256,
+ },
+ want: func(t *testing.T, err error) {
+ assert.ErrorContains(t, err, "code_verifier required")
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := op.AuthorizeCodeChallenge(tt.codeVerifier, tt.codeChallenge)
+
+ tt.want(t, err)
+ })
+ }
+}
diff --git a/pkg/op/token_revocation.go b/pkg/op/token_revocation.go
new file mode 100644
index 0000000..049ee15
--- /dev/null
+++ b/pkg/op/token_revocation.go
@@ -0,0 +1,175 @@
+package op
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/url"
+ "strings"
+
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type Revoker interface {
+ Decoder() httphelper.Decoder
+ Crypto() Crypto
+ Storage() Storage
+ AccessTokenVerifier(context.Context) *AccessTokenVerifier
+ AuthMethodPrivateKeyJWTSupported() bool
+ AuthMethodPostSupported() bool
+}
+
+type RevokerJWTProfile interface {
+ Revoker
+ JWTProfileVerifier(context.Context) *JWTProfileVerifier
+}
+
+func revocationHandler(revoker Revoker) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ Revoke(w, r, revoker)
+ }
+}
+
+func Revoke(w http.ResponseWriter, r *http.Request, revoker Revoker) {
+ ctx, span := tracer.Start(r.Context(), "Revoke")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ token, tokenTypeHint, clientID, err := ParseTokenRevocationRequest(r, revoker)
+ if err != nil {
+ RevocationRequestError(w, r, err)
+ return
+ }
+ var subject string
+ doDecrypt := true
+ if tokenTypeHint != "access_token" {
+ userID, tokenID, err := revoker.Storage().GetRefreshTokenInfo(r.Context(), clientID, token)
+ if err != nil {
+ // An invalid refresh token means that we'll try other things (leaving doDecrypt==true)
+ if !errors.Is(err, ErrInvalidRefreshToken) {
+ RevocationRequestError(w, r, oidc.ErrServerError().WithParent(err))
+ return
+ }
+ } else {
+ token = tokenID
+ subject = userID
+ doDecrypt = false
+ }
+ }
+ if doDecrypt {
+ tokenID, userID, ok := getTokenIDAndSubjectForRevocation(r.Context(), revoker, token)
+ if ok {
+ token = tokenID
+ subject = userID
+ }
+ }
+ if err := revoker.Storage().RevokeToken(r.Context(), token, subject, clientID); err != nil {
+ RevocationRequestError(w, r, err)
+ return
+ }
+ httphelper.MarshalJSON(w, nil)
+}
+
+func ParseTokenRevocationRequest(r *http.Request, revoker Revoker) (token, tokenTypeHint, clientID string, err error) {
+ ctx, span := tracer.Start(r.Context(), "ParseTokenRevocationRequest")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ err = r.ParseForm()
+ if err != nil {
+ return "", "", "", oidc.ErrInvalidRequest().WithDescription("unable to parse request").WithParent(err)
+ }
+ req := new(struct {
+ oidc.RevocationRequest
+ oidc.ClientAssertionParams // for auth_method private_key_jwt
+ ClientID string `schema:"client_id"` // for auth_method none and post
+ ClientSecret string `schema:"client_secret"` // for auth_method post
+ })
+ err = revoker.Decoder().Decode(req, r.Form)
+ if err != nil {
+ return "", "", "", oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(err)
+ }
+ if req.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
+ revokerJWTProfile, ok := revoker.(RevokerJWTProfile)
+ if !ok || !revoker.AuthMethodPrivateKeyJWTSupported() {
+ return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method private_key_jwt not supported")
+ }
+ profile, err := VerifyJWTAssertion(r.Context(), req.ClientAssertion, revokerJWTProfile.JWTProfileVerifier(r.Context()))
+ if err == nil {
+ return req.Token, req.TokenTypeHint, profile.Issuer, nil
+ }
+ return "", "", "", err
+ }
+ clientID, clientSecret, ok := r.BasicAuth()
+ if ok {
+ clientID, err = url.QueryUnescape(clientID)
+ if err != nil {
+ return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+ clientSecret, err = url.QueryUnescape(clientSecret)
+ if err != nil {
+ return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid basic auth header").WithParent(err)
+ }
+ if err = AuthorizeClientIDSecret(r.Context(), clientID, clientSecret, revoker.Storage()); err != nil {
+ return "", "", "", err
+ }
+ return req.Token, req.TokenTypeHint, clientID, nil
+ }
+ if req.ClientID == "" {
+ return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid authorization")
+ }
+ client, err := revoker.Storage().GetClientByClientID(r.Context(), req.ClientID)
+ if err != nil {
+ return "", "", "", oidc.ErrInvalidClient().WithParent(err)
+ }
+ if req.ClientSecret == "" {
+ if client.AuthMethod() != oidc.AuthMethodNone {
+ return "", "", "", oidc.ErrInvalidClient().WithDescription("invalid authorization")
+ }
+ return req.Token, req.TokenTypeHint, req.ClientID, nil
+ }
+ if client.AuthMethod() == oidc.AuthMethodPost && !revoker.AuthMethodPostSupported() {
+ return "", "", "", oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
+ }
+ if err = AuthorizeClientIDSecret(r.Context(), req.ClientID, req.ClientSecret, revoker.Storage()); err != nil {
+ return "", "", "", err
+ }
+ return req.Token, req.TokenTypeHint, req.ClientID, nil
+}
+
+func RevocationRequestError(w http.ResponseWriter, r *http.Request, err error) {
+ statusErr := RevocationError(err)
+ httphelper.MarshalJSONWithStatus(w, statusErr.parent, statusErr.statusCode)
+}
+
+func RevocationError(err error) StatusError {
+ e := oidc.DefaultToServerError(err, err.Error())
+ status := http.StatusBadRequest
+ switch e.ErrorType {
+ case oidc.InvalidClient:
+ status = 401
+ case oidc.ServerError:
+ status = 500
+ }
+ return NewStatusError(e, status)
+}
+
+func getTokenIDAndSubjectForRevocation(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
+ ctx, span := tracer.Start(ctx, "getTokenIDAndSubjectForRevocation")
+ defer span.End()
+
+ tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
+ if err == nil {
+ splitToken := strings.Split(tokenIDSubject, ":")
+ if len(splitToken) != 2 {
+ return "", "", false
+ }
+ return splitToken[0], splitToken[1], true
+ }
+ accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
+ if err != nil {
+ return "", "", false
+ }
+ return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
+}
diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go
deleted file mode 100644
index 5ef4b22..0000000
--- a/pkg/op/tokenrequest.go
+++ /dev/null
@@ -1,138 +0,0 @@
-package op
-
-import (
- "context"
- "errors"
- "net/http"
-
- "github.com/gorilla/schema"
-
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/utils"
-)
-
-type Exchanger interface {
- Issuer() string
- Storage() Storage
- Decoder() *schema.Decoder
- Signer() Signer
- Crypto() Crypto
- AuthMethodPostSupported() bool
-}
-
-func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
- tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
- if err != nil {
- RequestError(w, r, err)
- }
- if tokenReq.Code == "" {
- RequestError(w, r, ErrInvalidRequest("code missing"))
- return
- }
- authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
- if err != nil {
- RequestError(w, r, err)
- return
- }
- resp, err := CreateTokenResponse(r.Context(), authReq, client, exchanger, true, tokenReq.Code)
- if err != nil {
- RequestError(w, r, err)
- return
- }
- utils.MarshalJSON(w, resp)
-}
-
-func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) {
- err := r.ParseForm()
- if err != nil {
- return nil, ErrInvalidRequest("error parsing form")
- }
- tokenReq := new(oidc.AccessTokenRequest)
- err = decoder.Decode(tokenReq, r.Form)
- if err != nil {
- return nil, ErrInvalidRequest("error decoding form")
- }
- clientID, clientSecret, ok := r.BasicAuth()
- if ok {
- tokenReq.ClientID = clientID
- tokenReq.ClientSecret = clientSecret
-
- }
- return tokenReq, nil
-}
-
-func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
- authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
- if err != nil {
- return nil, nil, err
- }
- if client.GetID() != authReq.GetClientID() {
- return nil, nil, ErrInvalidRequest("invalid auth code")
- }
- if tokenReq.RedirectURI != authReq.GetRedirectURI() {
- return nil, nil, ErrInvalidRequest("redirect_uri does no correspond")
- }
- return authReq, client, nil
-}
-
-func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
- client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
- if err != nil {
- return nil, nil, err
- }
- if client.GetAuthMethod() == AuthMethodNone {
- authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger)
- return authReq, client, err
- }
- if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
- return nil, nil, errors.New("basic not supported")
- }
- err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
- if err != nil {
- return nil, nil, err
- }
- authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
- if err != nil {
- return nil, nil, ErrInvalidRequest("invalid code")
- }
- return authReq, client, nil
-}
-
-func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
- return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
-}
-
-func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
- if tokenReq.CodeVerifier == "" {
- return nil, ErrInvalidRequest("code_challenge required")
- }
- authReq, err := exchanger.Storage().AuthRequestByCode(ctx, tokenReq.Code)
- if err != nil {
- return nil, ErrInvalidRequest("invalid code")
- }
- if !oidc.VerifyCodeChallenge(authReq.GetCodeChallenge(), tokenReq.CodeVerifier) {
- return nil, ErrInvalidRequest("code_challenge invalid")
- }
- return authReq, nil
-}
-
-func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
- tokenRequest, err := ParseTokenExchangeRequest(w, r)
- if err != nil {
- RequestError(w, r, err)
- return
- }
- err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage())
- if err != nil {
- RequestError(w, r, err)
- return
- }
-}
-
-func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
- return nil, errors.New("Unimplemented") //TODO: impl
-}
-
-func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
- return errors.New("Unimplemented") //TODO: impl
-}
diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go
index 69746c7..ff75e72 100644
--- a/pkg/op/userinfo.go
+++ b/pkg/op/userinfo.go
@@ -1,50 +1,62 @@
package op
import (
+ "context"
"errors"
"net/http"
"strings"
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/utils"
- "github.com/gorilla/schema"
+ httphelper "git.christmann.info/LARA/zitadel-oidc/v3/pkg/http"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
)
type UserinfoProvider interface {
- Decoder() *schema.Decoder
+ Decoder() httphelper.Decoder
Crypto() Crypto
Storage() Storage
+ AccessTokenVerifier(context.Context) *AccessTokenVerifier
+}
+
+func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ Userinfo(w, r, userinfoProvider)
+ }
}
func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoProvider) {
- accessToken, err := getAccessToken(r, userinfoProvider.Decoder())
+ ctx, span := tracer.Start(r.Context(), "Userinfo")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ accessToken, err := ParseUserinfoRequest(r, userinfoProvider.Decoder())
if err != nil {
http.Error(w, "access token missing", http.StatusUnauthorized)
return
}
- tokenID, err := userinfoProvider.Crypto().Decrypt(accessToken)
- if err != nil {
- http.Error(w, "access token missing", http.StatusUnauthorized)
+ tokenID, subject, ok := getTokenIDAndSubject(r.Context(), userinfoProvider, accessToken)
+ if !ok {
+ http.Error(w, "access token invalid", http.StatusUnauthorized)
return
}
- info, err := userinfoProvider.Storage().GetUserinfoFromToken(r.Context(), tokenID)
+ info := new(oidc.UserInfo)
+ err = userinfoProvider.Storage().SetUserinfoFromToken(r.Context(), info, tokenID, subject, r.Header.Get("origin"))
if err != nil {
- utils.MarshalJSON(w, err)
+ httphelper.MarshalJSONWithStatus(w, err, http.StatusForbidden)
return
}
- utils.MarshalJSON(w, info)
+ httphelper.MarshalJSON(w, info)
}
-func getAccessToken(r *http.Request, decoder *schema.Decoder) (string, error) {
- authHeader := r.Header.Get("authorization")
- if authHeader != "" {
- parts := strings.Split(authHeader, "Bearer ")
- if len(parts) != 2 {
- return "", errors.New("invalid auth header")
- }
- return parts[1], nil
+func ParseUserinfoRequest(r *http.Request, decoder httphelper.Decoder) (string, error) {
+ ctx, span := tracer.Start(r.Context(), "ParseUserinfoRequest")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ accessToken, err := getAccessToken(r)
+ if err == nil {
+ return accessToken, nil
}
- err := r.ParseForm()
+ err = r.ParseForm()
if err != nil {
return "", errors.New("unable to parse request")
}
@@ -55,3 +67,38 @@ func getAccessToken(r *http.Request, decoder *schema.Decoder) (string, error) {
}
return req.AccessToken, nil
}
+
+func getAccessToken(r *http.Request) (string, error) {
+ ctx, span := tracer.Start(r.Context(), "getAccessToken")
+ r = r.WithContext(ctx)
+ defer span.End()
+
+ authHeader := r.Header.Get("authorization")
+ if authHeader == "" {
+ return "", errors.New("no auth header")
+ }
+ parts := strings.Split(authHeader, "Bearer ")
+ if len(parts) != 2 {
+ return "", errors.New("invalid auth header")
+ }
+ return parts[1], nil
+}
+
+func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
+ ctx, span := tracer.Start(ctx, "getTokenIDAndSubject")
+ defer span.End()
+
+ tokenIDSubject, err := userinfoProvider.Crypto().Decrypt(accessToken)
+ if err == nil {
+ splitToken := strings.Split(tokenIDSubject, ":")
+ if len(splitToken) != 2 {
+ return "", "", false
+ }
+ return splitToken[0], splitToken[1], true
+ }
+ accessTokenClaims, err := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
+ if err != nil {
+ return "", "", false
+ }
+ return accessTokenClaims.JWTID, accessTokenClaims.Subject, true
+}
diff --git a/pkg/op/verifier_access_token.go b/pkg/op/verifier_access_token.go
new file mode 100644
index 0000000..585ca54
--- /dev/null
+++ b/pkg/op/verifier_access_token.go
@@ -0,0 +1,60 @@
+package op
+
+import (
+ "context"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type AccessTokenVerifier oidc.Verifier
+
+type AccessTokenVerifierOpt func(*AccessTokenVerifier)
+
+func WithSupportedAccessTokenSigningAlgorithms(algs ...string) AccessTokenVerifierOpt {
+ return func(verifier *AccessTokenVerifier) {
+ verifier.SupportedSignAlgs = algs
+ }
+}
+
+// NewAccessTokenVerifier returns a AccessTokenVerifier suitable for access token verification.
+func NewAccessTokenVerifier(issuer string, keySet oidc.KeySet, opts ...AccessTokenVerifierOpt) *AccessTokenVerifier {
+ verifier := &AccessTokenVerifier{
+ Issuer: issuer,
+ KeySet: keySet,
+ }
+ for _, opt := range opts {
+ opt(verifier)
+ }
+ return verifier
+}
+
+// VerifyAccessToken validates the access token (issuer, signature and expiration).
+func VerifyAccessToken[C oidc.Claims](ctx context.Context, token string, v *AccessTokenVerifier) (claims C, err error) {
+ ctx, span := tracer.Start(ctx, "VerifyAccessToken")
+ defer span.End()
+
+ var nilClaims C
+
+ decrypted, err := oidc.DecryptToken(token)
+ if err != nil {
+ return nilClaims, err
+ }
+ payload, err := oidc.ParseToken(decrypted, &claims)
+ if err != nil {
+ return nilClaims, err
+ }
+
+ if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
+ return nilClaims, err
+ }
+
+ return claims, nil
+}
diff --git a/pkg/op/verifier_access_token_example_test.go b/pkg/op/verifier_access_token_example_test.go
new file mode 100644
index 0000000..b97a7fd
--- /dev/null
+++ b/pkg/op/verifier_access_token_example_test.go
@@ -0,0 +1,70 @@
+package op_test
+
+import (
+ "context"
+ "fmt"
+
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+)
+
+// MyCustomClaims extends the TokenClaims base,
+// so it implements the oidc.Claims interface.
+// Instead of carrying a map, we add needed fields// to the struct for type safe access.
+type MyCustomClaims struct {
+ oidc.TokenClaims
+ NotBefore oidc.Time `json:"nbf,omitempty"`
+ CodeHash string `json:"c_hash,omitempty"`
+ SessionID string `json:"sid,omitempty"`
+ Scopes []string `json:"scope,omitempty"`
+ AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
+ Foo string `json:"foo,omitempty"`
+ Bar *Nested `json:"bar,omitempty"`
+}
+
+// Nested struct types are also possible.
+type Nested struct {
+ Count int `json:"count,omitempty"`
+ Tags []string `json:"tags,omitempty"`
+}
+
+/*
+accessToken carries the following claims. foo and bar are custom claims
+
+ {
+ "aud": [
+ "unit",
+ "test"
+ ],
+ "bar": {
+ "count": 22,
+ "tags": [
+ "some",
+ "tags"
+ ]
+ },
+ "exp": 4802234675,
+ "foo": "Hello, World!",
+ "iat": 1678097014,
+ "iss": "local.com",
+ "jti": "9876",
+ "nbf": 1678097014,
+ "sub": "tim@local.com"
+ }
+*/
+const accessToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IjEifQ.eyJhdWQiOlsidW5pdCIsInRlc3QiXSwiYmFyIjp7ImNvdW50IjoyMiwidGFncyI6WyJzb21lIiwidGFncyJdfSwiZXhwIjo0ODAyMjM0Njc1LCJmb28iOiJIZWxsbywgV29ybGQhIiwiaWF0IjoxNjc4MDk3MDE0LCJpc3MiOiJsb2NhbC5jb20iLCJqdGkiOiI5ODc2IiwibmJmIjoxNjc4MDk3MDE0LCJzdWIiOiJ0aW1AbG9jYWwuY29tIn0.OUgk-B7OXjYlYFj-nogqSDJiQE19tPrbzqUHEAjcEiJkaWo6-IpGVfDiGKm-TxjXQsNScxpaY0Pg3XIh1xK6TgtfYtoLQm-5RYw_mXgb9xqZB2VgPs6nNEYFUDM513MOU0EBc0QMyqAEGzW-HiSPAb4ugCvkLtM1yo11Xyy6vksAdZNs_mJDT4X3vFXnr0jk0ugnAW6fTN3_voC0F_9HQUAkmd750OIxkAHxAMvEPQcpbLHenVvX_Q0QMrzClVrxehn5TVMfmkYYg7ocr876Bq9xQGPNHAcrwvVIJqdg5uMUA38L3HC2BEueG6furZGvc7-qDWAT1VR9liM5ieKpPg`
+
+func ExampleVerifyAccessToken_customClaims() {
+ v := op.NewAccessTokenVerifier("local.com", tu.KeySet{})
+
+ // VerifyAccessToken can be called with the *MyCustomClaims.
+ claims, err := op.VerifyAccessToken[*MyCustomClaims](context.TODO(), accessToken, v)
+ if err != nil {
+ panic(err)
+ }
+
+ // Here we have typesafe access to the custom claims
+ fmt.Println(claims.Foo, claims.Bar.Count, claims.Bar.Tags)
+ // Output: Hello, World! 22 [some tags]
+}
diff --git a/pkg/op/verifier_access_token_test.go b/pkg/op/verifier_access_token_test.go
new file mode 100644
index 0000000..5845f9f
--- /dev/null
+++ b/pkg/op/verifier_access_token_test.go
@@ -0,0 +1,126 @@
+package op
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewAccessTokenVerifier(t *testing.T) {
+ type args struct {
+ issuer string
+ keySet oidc.KeySet
+ opts []AccessTokenVerifierOpt
+ }
+ tests := []struct {
+ name string
+ args args
+ want *AccessTokenVerifier
+ }{
+ {
+ name: "simple",
+ args: args{
+ issuer: tu.ValidIssuer,
+ keySet: tu.KeySet{},
+ },
+ want: &AccessTokenVerifier{
+ Issuer: tu.ValidIssuer,
+ KeySet: tu.KeySet{},
+ },
+ },
+ {
+ name: "with signature algorithm",
+ args: args{
+ issuer: tu.ValidIssuer,
+ keySet: tu.KeySet{},
+ opts: []AccessTokenVerifierOpt{
+ WithSupportedAccessTokenSigningAlgorithms("ABC", "DEF"),
+ },
+ },
+ want: &AccessTokenVerifier{
+ Issuer: tu.ValidIssuer,
+ KeySet: tu.KeySet{},
+ SupportedSignAlgs: []string{"ABC", "DEF"},
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := NewAccessTokenVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestVerifyAccessToken(t *testing.T) {
+ verifier := &AccessTokenVerifier{
+ Issuer: tu.ValidIssuer,
+ MaxAgeIAT: 2 * time.Minute,
+ Offset: time.Second,
+ SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
+ KeySet: tu.KeySet{},
+ }
+
+ tests := []struct {
+ name string
+ tokenClaims func() (string, *oidc.AccessTokenClaims)
+ wantErr bool
+ }{
+ {
+ name: "success",
+ tokenClaims: tu.ValidAccessToken,
+ },
+ {
+ name: "parse err",
+ tokenClaims: func() (string, *oidc.AccessTokenClaims) { return "~~~~", nil },
+ wantErr: true,
+ },
+ {
+ name: "invalid signature",
+ tokenClaims: func() (string, *oidc.AccessTokenClaims) { return tu.InvalidSignatureToken, nil },
+ wantErr: true,
+ },
+ {
+ name: "wrong issuer",
+ tokenClaims: func() (string, *oidc.AccessTokenClaims) {
+ return tu.NewAccessToken(
+ "foo", tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidJWTID, tu.ValidClientID,
+ tu.ValidSkew,
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "expired",
+ tokenClaims: func() (string, *oidc.AccessTokenClaims) {
+ return tu.NewAccessToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration.Add(-time.Hour), tu.ValidJWTID, tu.ValidClientID,
+ tu.ValidSkew,
+ )
+ },
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ token, want := tt.tokenClaims()
+
+ got, err := VerifyAccessToken[*oidc.AccessTokenClaims](context.Background(), token, verifier)
+ if tt.wantErr {
+ assert.Error(t, err)
+ assert.Nil(t, got)
+ return
+ }
+ require.NoError(t, err)
+ require.NotNil(t, got)
+ assert.Equal(t, got, want)
+ })
+ }
+}
diff --git a/pkg/op/verifier_id_token_hint.go b/pkg/op/verifier_id_token_hint.go
new file mode 100644
index 0000000..02610aa
--- /dev/null
+++ b/pkg/op/verifier_id_token_hint.go
@@ -0,0 +1,87 @@
+package op
+
+import (
+ "context"
+ "errors"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+type IDTokenHintVerifier oidc.Verifier
+
+type IDTokenHintVerifierOpt func(*IDTokenHintVerifier)
+
+func WithSupportedIDTokenHintSigningAlgorithms(algs ...string) IDTokenHintVerifierOpt {
+ return func(verifier *IDTokenHintVerifier) {
+ verifier.SupportedSignAlgs = algs
+ }
+}
+
+func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHintVerifierOpt) *IDTokenHintVerifier {
+ verifier := &IDTokenHintVerifier{
+ Issuer: issuer,
+ KeySet: keySet,
+ }
+ for _, opt := range opts {
+ opt(verifier)
+ }
+ return verifier
+}
+
+type IDTokenHintExpiredError struct {
+ error
+}
+
+func (e IDTokenHintExpiredError) Unwrap() error {
+ return e.error
+}
+
+func (e IDTokenHintExpiredError) Is(err error) bool {
+ return errors.Is(err, e.error)
+}
+
+// VerifyIDTokenHint validates the id token according to
+// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.
+// In case of an expired token both the Claims and first encountered expiry related error
+// is returned of type [IDTokenHintExpiredError]. In that case the caller can choose to still
+// trust the token for cases like logout, as signature and other verifications succeeded.
+func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) {
+ ctx, span := tracer.Start(ctx, "VerifyIDTokenHint")
+ defer span.End()
+
+ var nilClaims C
+
+ decrypted, err := oidc.DecryptToken(token)
+ if err != nil {
+ return nilClaims, err
+ }
+ payload, err := oidc.ParseToken(decrypted, &claims)
+ if err != nil {
+ return nilClaims, err
+ }
+
+ if err := oidc.CheckIssuer(claims, v.Issuer); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckSignature(ctx, decrypted, payload, claims, v.SupportedSignAlgs, v.KeySet); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
+ return nilClaims, err
+ }
+
+ if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
+ return claims, IDTokenHintExpiredError{err}
+ }
+
+ if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
+ return claims, IDTokenHintExpiredError{err}
+ }
+
+ if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
+ return claims, IDTokenHintExpiredError{err}
+ }
+ return claims, nil
+}
diff --git a/pkg/op/verifier_id_token_hint_test.go b/pkg/op/verifier_id_token_hint_test.go
new file mode 100644
index 0000000..347e33c
--- /dev/null
+++ b/pkg/op/verifier_id_token_hint_test.go
@@ -0,0 +1,172 @@
+package op
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewIDTokenHintVerifier(t *testing.T) {
+ type args struct {
+ issuer string
+ keySet oidc.KeySet
+ opts []IDTokenHintVerifierOpt
+ }
+ tests := []struct {
+ name string
+ args args
+ want *IDTokenHintVerifier
+ }{
+ {
+ name: "simple",
+ args: args{
+ issuer: tu.ValidIssuer,
+ keySet: tu.KeySet{},
+ },
+ want: &IDTokenHintVerifier{
+ Issuer: tu.ValidIssuer,
+ KeySet: tu.KeySet{},
+ },
+ },
+ {
+ name: "with signature algorithm",
+ args: args{
+ issuer: tu.ValidIssuer,
+ keySet: tu.KeySet{},
+ opts: []IDTokenHintVerifierOpt{
+ WithSupportedIDTokenHintSigningAlgorithms("ABC", "DEF"),
+ },
+ },
+ want: &IDTokenHintVerifier{
+ Issuer: tu.ValidIssuer,
+ KeySet: tu.KeySet{},
+ SupportedSignAlgs: []string{"ABC", "DEF"},
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := NewIDTokenHintVerifier(tt.args.issuer, tt.args.keySet, tt.args.opts...)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func Test_IDTokenHintExpiredError(t *testing.T) {
+ var err error = IDTokenHintExpiredError{oidc.ErrExpired}
+ assert.True(t, errors.Unwrap(err) == oidc.ErrExpired)
+ assert.ErrorIs(t, err, oidc.ErrExpired)
+ assert.ErrorAs(t, err, &IDTokenHintExpiredError{})
+}
+
+func TestVerifyIDTokenHint(t *testing.T) {
+ verifier := &IDTokenHintVerifier{
+ Issuer: tu.ValidIssuer,
+ MaxAgeIAT: 2 * time.Minute,
+ Offset: time.Second,
+ SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)},
+ MaxAge: 2 * time.Minute,
+ ACR: tu.ACRVerify,
+ KeySet: tu.KeySet{},
+ }
+
+ tests := []struct {
+ name string
+ tokenClaims func() (string, *oidc.IDTokenClaims)
+ wantClaims bool
+ wantErr error
+ }{
+ {
+ name: "success",
+ tokenClaims: tu.ValidIDToken,
+ wantClaims: true,
+ },
+ {
+ name: "parse err",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
+ wantErr: oidc.ErrParse,
+ },
+ {
+ name: "invalid signature",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
+ wantErr: oidc.ErrSignatureUnsupportedAlg,
+ },
+ {
+ name: "wrong issuer",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ "foo", tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantErr: oidc.ErrIssuerInvalid,
+ },
+ {
+ name: "wrong acr",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ "else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantErr: oidc.ErrAcrInvalid,
+ },
+ {
+ name: "expired",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantClaims: true,
+ wantErr: IDTokenHintExpiredError{oidc.ErrExpired},
+ },
+ {
+ name: "IAT too old",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, time.Hour, "",
+ )
+ },
+ wantClaims: true,
+ wantErr: IDTokenHintExpiredError{oidc.ErrIatToOld},
+ },
+ {
+ name: "expired auth",
+ tokenClaims: func() (string, *oidc.IDTokenClaims) {
+ return tu.NewIDToken(
+ tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
+ tu.ValidExpiration, tu.ValidAuthTime.Add(-time.Hour), tu.ValidNonce,
+ tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
+ )
+ },
+ wantClaims: true,
+ wantErr: IDTokenHintExpiredError{oidc.ErrAuthTimeToOld},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ token, want := tt.tokenClaims()
+
+ got, err := VerifyIDTokenHint[*oidc.IDTokenClaims](context.Background(), token, verifier)
+ require.ErrorIs(t, err, tt.wantErr)
+ if tt.wantClaims {
+ assert.Equal(t, got, want, "claims")
+ return
+ }
+ assert.Nil(t, got, "claims")
+ })
+ }
+}
diff --git a/pkg/op/verifier_jwt_profile.go b/pkg/op/verifier_jwt_profile.go
new file mode 100644
index 0000000..85bfb14
--- /dev/null
+++ b/pkg/op/verifier_jwt_profile.go
@@ -0,0 +1,130 @@
+package op
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ jose "github.com/go-jose/go-jose/v4"
+
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+)
+
+// JWTProfileVerfiier extends oidc.Verifier with
+// a jwtProfileKeyStorage and a function to check
+// the subject in a token.
+type JWTProfileVerifier struct {
+ oidc.Verifier
+ Storage JWTProfileKeyStorage
+ keySet oidc.KeySet
+ CheckSubject func(request *oidc.JWTTokenRequest) error
+}
+
+// NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication)
+func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier {
+ return newJWTProfileVerifier(storage, nil, issuer, maxAgeIAT, offset, opts...)
+}
+
+// NewJWTProfileVerifierKeySet creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication)
+func NewJWTProfileVerifierKeySet(keySet oidc.KeySet, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier {
+ return newJWTProfileVerifier(nil, keySet, issuer, maxAgeIAT, offset, opts...)
+}
+
+func newJWTProfileVerifier(storage JWTProfileKeyStorage, keySet oidc.KeySet, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier {
+ j := &JWTProfileVerifier{
+ Verifier: oidc.Verifier{
+ Issuer: issuer,
+ MaxAgeIAT: maxAgeIAT,
+ Offset: offset,
+ },
+ Storage: storage,
+ keySet: keySet,
+ CheckSubject: SubjectIsIssuer,
+ }
+
+ for _, opt := range opts {
+ opt(j)
+ }
+
+ return j
+}
+
+type JWTProfileVerifierOption func(*JWTProfileVerifier)
+
+// SubjectCheck sets a custom function to check the subject.
+// Defaults to SubjectIsIssuer()
+func SubjectCheck(check func(request *oidc.JWTTokenRequest) error) JWTProfileVerifierOption {
+ return func(verifier *JWTProfileVerifier) {
+ verifier.CheckSubject = check
+ }
+}
+
+// VerifyJWTAssertion verifies the assertion string from JWT Profile (authorization grant and client authentication)
+//
+// checks audience, exp, iat, signature and that issuer and sub are the same
+func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVerifier) (*oidc.JWTTokenRequest, error) {
+ ctx, span := tracer.Start(ctx, "VerifyJWTAssertion")
+ defer span.End()
+
+ request := new(oidc.JWTTokenRequest)
+ payload, err := oidc.ParseToken(assertion, request)
+ if err != nil {
+ return nil, err
+ }
+
+ if err = oidc.CheckAudience(request, v.Issuer); err != nil {
+ return nil, err
+ }
+
+ if err = oidc.CheckExpiration(request, v.Offset); err != nil {
+ return nil, err
+ }
+
+ if err = oidc.CheckIssuedAt(request, v.MaxAgeIAT, v.Offset); err != nil {
+ return nil, err
+ }
+
+ if err = v.CheckSubject(request); err != nil {
+ return nil, err
+ }
+
+ keySet := v.keySet
+ if keySet == nil {
+ keySet = &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer}
+ }
+ if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
+ return nil, err
+ }
+ return request, nil
+}
+
+type JWTProfileKeyStorage interface {
+ GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error)
+}
+
+// SubjectIsIssuer
+func SubjectIsIssuer(request *oidc.JWTTokenRequest) error {
+ if request.Issuer != request.Subject {
+ return errors.New("delegation not allowed, issuer and sub must be identical")
+ }
+ return nil
+}
+
+type jwtProfileKeySet struct {
+ storage JWTProfileKeyStorage
+ clientID string
+}
+
+// VerifySignature implements oidc.KeySet by getting the public key from Storage implementation
+func (k *jwtProfileKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) {
+ ctx, span := tracer.Start(ctx, "VerifySignature")
+ defer span.End()
+
+ keyID, _ := oidc.GetKeyIDAndAlg(jws)
+ key, err := k.storage.GetKeyByIDAndClientID(ctx, keyID, k.clientID)
+ if err != nil {
+ return nil, fmt.Errorf("error fetching keys: %w", err)
+ }
+ return jws.Verify(key)
+}
diff --git a/pkg/op/verifier_jwt_profile_test.go b/pkg/op/verifier_jwt_profile_test.go
new file mode 100644
index 0000000..2068678
--- /dev/null
+++ b/pkg/op/verifier_jwt_profile_test.go
@@ -0,0 +1,117 @@
+package op_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ tu "git.christmann.info/LARA/zitadel-oidc/v3/internal/testutil"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/oidc"
+ "git.christmann.info/LARA/zitadel-oidc/v3/pkg/op"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewJWTProfileVerifier(t *testing.T) {
+ want := &op.JWTProfileVerifier{
+ Verifier: oidc.Verifier{
+ Issuer: tu.ValidIssuer,
+ MaxAgeIAT: time.Minute,
+ Offset: time.Second,
+ },
+ Storage: tu.JWTProfileKeyStorage{},
+ }
+ got := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, time.Second, op.SubjectCheck(func(request *oidc.JWTTokenRequest) error {
+ return oidc.ErrSubjectMissing
+ }))
+ assert.Equal(t, want.Verifier, got.Verifier)
+ assert.Equal(t, want.Storage, got.Storage)
+ assert.ErrorIs(t, got.CheckSubject(nil), oidc.ErrSubjectMissing)
+}
+
+func TestVerifyJWTAssertion(t *testing.T) {
+ errCtx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ verifier := op.NewJWTProfileVerifier(tu.JWTProfileKeyStorage{}, tu.ValidIssuer, time.Minute, 0)
+ tests := []struct {
+ name string
+ ctx context.Context
+ newToken func() (string, *oidc.JWTTokenRequest)
+ wantErr bool
+ }{
+ {
+ name: "parse error",
+ ctx: context.Background(),
+ newToken: func() (string, *oidc.JWTTokenRequest) { return "!", nil },
+ wantErr: true,
+ },
+ {
+ name: "wrong audience",
+ ctx: context.Background(),
+ newToken: func() (string, *oidc.JWTTokenRequest) {
+ return tu.NewJWTProfileAssertion(
+ tu.ValidClientID, tu.ValidClientID, []string{"wrong"},
+ time.Now(), tu.ValidExpiration,
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "expired",
+ ctx: context.Background(),
+ newToken: func() (string, *oidc.JWTTokenRequest) {
+ return tu.NewJWTProfileAssertion(
+ tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
+ time.Now(), time.Now().Add(-time.Hour),
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid iat",
+ ctx: context.Background(),
+ newToken: func() (string, *oidc.JWTTokenRequest) {
+ return tu.NewJWTProfileAssertion(
+ tu.ValidClientID, tu.ValidClientID, []string{tu.ValidIssuer},
+ time.Now().Add(time.Hour), tu.ValidExpiration,
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "invalid subject",
+ ctx: context.Background(),
+ newToken: func() (string, *oidc.JWTTokenRequest) {
+ return tu.NewJWTProfileAssertion(
+ tu.ValidClientID, "wrong", []string{tu.ValidIssuer},
+ time.Now(), tu.ValidExpiration,
+ )
+ },
+ wantErr: true,
+ },
+ {
+ name: "check signature fail",
+ ctx: errCtx,
+ newToken: tu.ValidJWTProfileAssertion,
+ wantErr: true,
+ },
+ {
+ name: "ok",
+ ctx: context.Background(),
+ newToken: tu.ValidJWTProfileAssertion,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assertion, want := tt.newToken()
+ got, err := op.VerifyJWTAssertion(tt.ctx, assertion, verifier)
+ if tt.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, want, got)
+ })
+ }
+}
diff --git a/pkg/rp/default_rp.go b/pkg/rp/default_rp.go
deleted file mode 100644
index 6c9208d..0000000
--- a/pkg/rp/default_rp.go
+++ /dev/null
@@ -1,311 +0,0 @@
-package rp
-
-import (
- "context"
- "encoding/base64"
- "net/http"
- "strings"
-
- "github.com/caos/oidc/pkg/oidc/grants"
-
- "golang.org/x/oauth2"
-
- "github.com/caos/oidc/pkg/oidc"
- grants_tx "github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
- "github.com/caos/oidc/pkg/utils"
-)
-
-const (
- idTokenKey = "id_token"
- stateParam = "state"
- pkceCode = "pkce"
-)
-
-var (
- DefaultErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
- http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
- }
-)
-
-//DefaultRP impements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface
-type DefaultRP struct {
- endpoints Endpoints
-
- oauthConfig oauth2.Config
- config *Config
- pkce bool
-
- httpClient *http.Client
- cookieHandler *utils.CookieHandler
-
- errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
-
- verifier Verifier
- onlyOAuth2 bool
-}
-
-//NewDefaultRP creates `DefaultRP` with the given
-//Config and possible configOptions
-//it will run discovery on the provided issuer
-//if no verifier is provided using the options the `DefaultVerifier` is set
-func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExchangeRP, error) {
- foundOpenID := false
- for _, scope := range rpConfig.Scopes {
- if scope == "openid" {
- foundOpenID = true
- }
- }
-
- p := &DefaultRP{
- config: rpConfig,
- httpClient: utils.DefaultHTTPClient,
- onlyOAuth2: !foundOpenID,
- }
-
- for _, optFunc := range rpOpts {
- optFunc(p)
- }
-
- if rpConfig.Endpoints.TokenURL != "" && rpConfig.Endpoints.AuthURL != "" {
- p.oauthConfig = p.getOAuthConfig(rpConfig.Endpoints)
- } else {
- if err := p.discover(); err != nil {
- return nil, err
- }
- }
-
- if p.errorHandler == nil {
- p.errorHandler = DefaultErrorHandler
- }
-
- if p.verifier == nil {
- p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL))
- }
-
- return p, nil
-}
-
-//DefaultRPOpts is the type for providing dynamic options to the DefaultRP
-type DefaultRPOpts func(p *DefaultRP)
-
-//WithCookieHandler set a `CookieHandler` for securing the various redirects
-func WithCookieHandler(cookieHandler *utils.CookieHandler) DefaultRPOpts {
- return func(p *DefaultRP) {
- p.cookieHandler = cookieHandler
- }
-}
-
-//WithPKCE sets the RP to use PKCE (oauth2 code challenge)
-//it also sets a `CookieHandler` for securing the various redirects
-//and exchanging the code challenge
-func WithPKCE(cookieHandler *utils.CookieHandler) DefaultRPOpts {
- return func(p *DefaultRP) {
- p.pkce = true
- p.cookieHandler = cookieHandler
- }
-}
-
-//WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier
-func WithHTTPClient(client *http.Client) DefaultRPOpts {
- return func(p *DefaultRP) {
- p.httpClient = client
- }
-}
-
-//AuthURL is the `RelayingParty` interface implementation
-//wrapping the oauth2 `AuthCodeURL`
-//returning the url of the auth request
-func (p *DefaultRP) AuthURL(state string, opts ...AuthURLOpt) string {
- authOpts := make([]oauth2.AuthCodeOption, 0)
- for _, opt := range opts {
- authOpts = append(authOpts, opt()...)
- }
- return p.oauthConfig.AuthCodeURL(state, authOpts...)
-}
-
-//AuthURL is the `RelayingParty` interface implementation
-//extending the `AuthURL` method with a http redirect handler
-func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- opts := make([]AuthURLOpt, 0)
- if err := p.trySetStateCookie(w, state); err != nil {
- http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
- return
- }
- if p.pkce {
- codeChallenge, err := p.generateAndStoreCodeChallenge(w)
- if err != nil {
- http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
- return
- }
- opts = append(opts, WithCodeChallenge(codeChallenge))
- }
- http.Redirect(w, r, p.AuthURL(state, opts...), http.StatusFound)
- }
-}
-
-func (p *DefaultRP) generateAndStoreCodeChallenge(w http.ResponseWriter) (string, error) {
- var codeVerifier string
- codeVerifier = "s"
- if err := p.cookieHandler.SetCookie(w, pkceCode, codeVerifier); err != nil {
- return "", err
- }
- return oidc.NewSHACodeChallenge(codeVerifier), nil
-}
-
-//AuthURL is the `RelayingParty` interface implementation
-//handling the oauth2 code exchange, extracting and validating the id_token
-//returning it paresed together with the oauth2 tokens (access, refresh)
-func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
- ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
- codeOpts := make([]oauth2.AuthCodeOption, 0)
- for _, opt := range opts {
- codeOpts = append(codeOpts, opt()...)
- }
-
- token, err := p.oauthConfig.Exchange(ctx, code, codeOpts...)
- if err != nil {
- return nil, err //TODO: our error
- }
- idTokenString, ok := token.Extra(idTokenKey).(string)
- if !ok {
- //TODO: implement
- }
-
- idToken := new(oidc.IDTokenClaims)
- if !p.onlyOAuth2 {
- idToken, err = p.verifier.Verify(ctx, token.AccessToken, idTokenString)
- if err != nil {
- return nil, err //TODO: err
- }
- }
-
- return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
-}
-
-//AuthURL is the `RelayingParty` interface implementation
-//extending the `CodeExchange` method with callback function
-func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string)) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- state, err := p.tryReadStateCookie(w, r)
- if err != nil {
- http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized)
- return
- }
- params := r.URL.Query()
- if params.Get("error") != "" {
- p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state)
- return
- }
- codeOpts := make([]CodeExchangeOpt, 0)
- if p.pkce {
- codeVerifier, err := p.cookieHandler.CheckCookie(r, pkceCode)
- if err != nil {
- http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
- return
- }
- codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
- }
- tokens, err := p.CodeExchange(r.Context(), params.Get("code"), codeOpts...)
- if err != nil {
- http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
- return
- }
- callback(w, r, tokens, state)
- }
-}
-
-// func (p *DefaultRP) Introspect(ctx context.Context, accessToken string) (oidc.TokenIntrospectResponse, error) {
-// // req := &http.Request{}
-// // resp, err := p.httpClient.Do(req)
-// // if err != nil {
-
-// // }
-// // p.endpoints.IntrospectURL
-// return nil, nil
-// }
-
-func (p *DefaultRP) Userinfo() {}
-
-//ClientCredentials is the `RelayingParty` interface implementation
-//handling the oauth2 client credentials grant
-func (p *DefaultRP) ClientCredentials(ctx context.Context, scopes ...string) (newToken *oauth2.Token, err error) {
- return p.callTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...))
-}
-
-//TokenExchange is the `TokenExchangeRP` interface implementation
-//handling the oauth2 token exchange (draft)
-func (p *DefaultRP) TokenExchange(ctx context.Context, request *grants_tx.TokenExchangeRequest) (newToken *oauth2.Token, err error) {
- return p.callTokenEndpoint(request)
-}
-
-//DelegationTokenExchange is the `TokenExchangeRP` interface implementation
-//handling the oauth2 token exchange for a delegation token (draft)
-func (p *DefaultRP) DelegationTokenExchange(ctx context.Context, subjectToken string, reqOpts ...grants_tx.TokenExchangeOption) (newToken *oauth2.Token, err error) {
- return p.TokenExchange(ctx, DelegationTokenRequest(subjectToken, reqOpts...))
-}
-
-func (p *DefaultRP) discover() error {
- wellKnown := strings.TrimSuffix(p.config.Issuer, "/") + oidc.DiscoveryEndpoint
- req, err := http.NewRequest("GET", wellKnown, nil)
- if err != nil {
- return err
- }
- discoveryConfig := new(oidc.DiscoveryConfiguration)
- err = utils.HttpRequest(p.httpClient, req, &discoveryConfig)
- if err != nil {
- return err
- }
- p.endpoints = GetEndpoints(discoveryConfig)
- p.oauthConfig = p.getOAuthConfig(p.endpoints.Endpoint)
- return nil
-}
-
-func (p *DefaultRP) getOAuthConfig(endpoint oauth2.Endpoint) oauth2.Config {
- return oauth2.Config{
- ClientID: p.config.ClientID,
- ClientSecret: p.config.ClientSecret,
- Endpoint: endpoint,
- RedirectURL: p.config.CallbackURL,
- Scopes: p.config.Scopes,
- }
-}
-
-func (p *DefaultRP) callTokenEndpoint(request interface{}) (newToken *oauth2.Token, err error) {
- req, err := utils.FormRequest(p.endpoints.TokenURL, request)
- if err != nil {
- return nil, err
- }
- auth := base64.StdEncoding.EncodeToString([]byte(p.config.ClientID + ":" + p.config.ClientSecret))
- req.Header.Set("Authorization", "Basic "+auth)
- token := new(oauth2.Token)
- if err := utils.HttpRequest(p.httpClient, req, token); err != nil {
- return nil, err
- }
- return token, nil
-}
-
-func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error {
- if p.cookieHandler != nil {
- if err := p.cookieHandler.SetCookie(w, stateParam, state); err != nil {
- return err
- }
- }
- return nil
-}
-
-func (p *DefaultRP) tryReadStateCookie(w http.ResponseWriter, r *http.Request) (state string, err error) {
- if p.cookieHandler == nil {
- return r.FormValue(stateParam), nil
- }
- state, err = p.cookieHandler.CheckQueryCookie(r, stateParam)
- if err != nil {
- return "", err
- }
- p.cookieHandler.DeleteCookie(w, stateParam)
- return state, nil
-}
-
-func (p *DefaultRP) Client(ctx context.Context, token *oauth2.Token) *http.Client {
- return p.oauthConfig.Client(ctx, token)
-}
diff --git a/pkg/rp/default_verifier.go b/pkg/rp/default_verifier.go
deleted file mode 100644
index db599e3..0000000
--- a/pkg/rp/default_verifier.go
+++ /dev/null
@@ -1,387 +0,0 @@
-package rp
-
-import (
- "bytes"
- "context"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "strings"
- "time"
-
- "gopkg.in/square/go-jose.v2"
-
- "github.com/caos/oidc/pkg/oidc"
- "github.com/caos/oidc/pkg/utils"
-)
-
-//DefaultVerifier implements the `Verifier` interface
-type DefaultVerifier struct {
- config *verifierConfig
- keySet oidc.KeySet
-}
-
-//ConfFunc is the type for providing dynamic options to the DefaultVerfifier
-type ConfFunc func(*verifierConfig)
-
-//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
-type ACRVerifier func(string) error
-
-//NewDefaultVerifier creates `DefaultVerifier` with the given
-//issuer, clientID, keyset and possible configOptions
-func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ...ConfFunc) Verifier {
- conf := &verifierConfig{
- issuer: issuer,
- clientID: clientID,
- iat: &iatConfig{
- // offset: time.Duration(500 * time.Millisecond),
- },
- }
-
- for _, opt := range confOpts {
- if opt != nil {
- opt(conf)
- }
- }
- return &DefaultVerifier{config: conf, keySet: keySet}
-}
-
-//WithIgnoreAudience will turn off validation for audience claim (should only be used for id_token_hints)
-func WithIgnoreAudience() func(*verifierConfig) {
- return func(conf *verifierConfig) {
- conf.ignoreAudience = true
- }
-}
-
-//WithIgnoreExpiration will turn off validation for expiration claim (should only be used for id_token_hints)
-func WithIgnoreExpiration() func(*verifierConfig) {
- return func(conf *verifierConfig) {
- conf.ignoreExpiration = true
- }
-}
-
-//WithIgnoreIssuedAt will turn off iat claim verification
-func WithIgnoreIssuedAt() func(*verifierConfig) {
- return func(conf *verifierConfig) {
- conf.iat.ignore = true
- }
-}
-
-//WithIssuedAtOffset mitigates the risk of iat to be in the future
-//because of clock skews with the ability to add an offset to the current time
-func WithIssuedAtOffset(offset time.Duration) func(*verifierConfig) {
- return func(conf *verifierConfig) {
- conf.iat.offset = offset
- }
-}
-
-//WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
-func WithIssuedAtMaxAge(maxAge time.Duration) func(*verifierConfig) {
- return func(conf *verifierConfig) {
- conf.iat.maxAge = maxAge
- }
-}
-
-//WithNonce TODO: ?
-func WithNonce(nonce string) func(*verifierConfig) {
- return func(conf *verifierConfig) {
- conf.nonce = nonce
- }
-}
-
-//WithACRVerifier sets the verifier for the acr claim
-func WithACRVerifier(verifier ACRVerifier) func(*verifierConfig) {
- return func(conf *verifierConfig) {
- conf.acr = verifier
- }
-}
-
-//WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
-func WithAuthTimeMaxAge(maxAge time.Duration) func(*verifierConfig) {
- return func(conf *verifierConfig) {
- conf.maxAge = maxAge
- }
-}
-
-//WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
-func WithSupportedSigningAlgorithms(algs ...string) func(*verifierConfig) {
- return func(conf *verifierConfig) {
- conf.supportedSignAlgs = algs
- }
-}
-
-type verifierConfig struct {
- issuer string
- clientID string
- nonce string
- ignoreAudience bool
- ignoreExpiration bool
- iat *iatConfig
- acr ACRVerifier
- maxAge time.Duration
- supportedSignAlgs []string
-
- // httpClient *http.Client
-
- now time.Time
-}
-
-type iatConfig struct {
- ignore bool
- offset time.Duration
- maxAge time.Duration
-}
-
-//DefaultACRVerifier implements `ACRVerifier` returning an error
-//if non of the provided values matches the acr claim
-func DefaultACRVerifier(possibleValues []string) ACRVerifier {
- return func(acr string) error {
- if !utils.Contains(possibleValues, acr) {
- return ErrAcrInvalid(possibleValues, acr)
- }
- return nil
- }
-}
-
-//Verify implements the `Verify` method of the `Verifier` interface
-//according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
-//and https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
-func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) {
- v.config.now = time.Now().UTC()
- idToken, err := v.VerifyIDToken(ctx, idTokenString)
- if err != nil {
- return nil, err
- }
- if err := v.verifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { //TODO: sig from token
- return nil, err
- }
- return idToken, nil
-}
-
-func (v *DefaultVerifier) now() time.Time {
- if v.config.now.IsZero() {
- v.config.now = time.Now().UTC().Round(time.Second)
- }
- return v.config.now
-}
-
-//VerifyIDToken: https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
-func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) {
- //1. if encrypted --> decrypt
- decrypted, err := v.decryptToken(idTokenString)
- if err != nil {
- return nil, err
- }
- claims, payload, err := v.parseToken(decrypted)
- if err != nil {
- return nil, err
- }
- // token, err := jwt.ParseWithClaims(decrypted, claims, func(token *jwt.Token) (interface{}, error) {
- //2, check issuer (exact match)
- if err := v.checkIssuer(claims.Issuer); err != nil {
- return nil, err
- }
-
- //3. check aud (aud must contain client_id, all aud strings must be allowed)
- if err = v.checkAudience(claims.Audiences); err != nil {
- return nil, err
- }
-
- if err = v.checkAuthorizedParty(claims.Audiences, claims.AuthorizedParty); err != nil {
- return nil, err
- }
-
- //6. check signature by keys
- //7. check alg default is rs256
- //8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
- claims.Signature, err = v.checkSignature(ctx, decrypted, payload)
- if err != nil {
- return nil, err
- }
-
- //9. check exp before now
- if err = v.checkExpiration(claims.Expiration); err != nil {
- return nil, err
- }
-
- //10. check iat duration is optional (can be checked)
- if err = v.checkIssuedAt(claims.IssuedAt); err != nil {
- return nil, err
- }
-
- //11. check nonce (check if optional possible) id_token.nonce == sentNonce
- if err = v.checkNonce(claims.Nonce); err != nil {
- return nil, err
- }
-
- //12. if acr requested check acr
- if err = v.checkAuthorizationContextClassReference(claims.AuthenticationContextClassReference); err != nil {
- return nil, err
- }
-
- //13. if auth_time requested check if auth_time is less than max age
- if err = v.checkAuthTime(claims.AuthTime); err != nil {
- return nil, err
- }
-
- return claims, nil
-}
-
-func (v *DefaultVerifier) parseToken(tokenString string) (*oidc.IDTokenClaims, []byte, error) {
- parts := strings.Split(tokenString, ".")
- if len(parts) != 3 {
- return nil, nil, ValidationError("token contains an invalid number of segments") //TODO: err NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
- }
- payload, err := base64.RawURLEncoding.DecodeString(parts[1])
- if err != nil {
- return nil, nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
- }
- idToken := new(oidc.IDTokenClaims)
- err = json.Unmarshal(payload, idToken)
- return idToken, payload, err
-}
-
-func (v *DefaultVerifier) checkIssuer(issuer string) error {
- if v.config.issuer != issuer {
- return ErrIssuerInvalid(v.config.issuer, issuer)
- }
- return nil
-}
-
-func (v *DefaultVerifier) checkAudience(audiences []string) error {
- if v.config.ignoreAudience {
- return nil
- }
- if !utils.Contains(audiences, v.config.clientID) {
- return ErrAudienceMissingClientID(v.config.clientID)
- }
-
- //TODO: check aud trusted
- return nil
-}
-
-//4. if multiple aud strings --> check if azp
-//5. if azp --> check azp == client_id
-func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error {
- if v.config.ignoreAudience {
- return nil
- }
- if len(audiences) > 1 {
- if authorizedParty == "" {
- return ErrAzpMissing()
- }
- }
- if authorizedParty != "" && authorizedParty != v.config.clientID {
- return ErrAzpInvalid(authorizedParty, v.config.clientID)
- }
- return nil
-}
-
-func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString string, payload []byte) (jose.SignatureAlgorithm, error) {
- jws, err := jose.ParseSigned(idTokenString)
- if err != nil {
- return "", err
- }
- if len(jws.Signatures) == 0 {
- return "", ErrSignatureMissing()
- }
- if len(jws.Signatures) > 1 {
- return "", ErrSignatureMultiple()
- }
- sig := jws.Signatures[0]
- supportedSigAlgs := v.config.supportedSignAlgs
- if len(supportedSigAlgs) == 0 {
- supportedSigAlgs = []string{"RS256"}
- }
- if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) {
- return "", fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
- }
-
- signedPayload, err := v.keySet.VerifySignature(ctx, jws)
- if err != nil {
- return "", err
- }
-
- if !bytes.Equal(signedPayload, payload) {
- return "", ErrSignatureInvalidPayload()
- }
- return jose.SignatureAlgorithm(sig.Header.Algorithm), nil
-}
-
-func (v *DefaultVerifier) checkExpiration(expiration time.Time) error {
- if v.config.ignoreExpiration {
- return nil
- }
- expiration = expiration.Round(time.Second)
- if !v.now().Before(expiration) {
- return ErrExpInvalid(expiration)
- }
- return nil
-}
-
-func (v *DefaultVerifier) checkIssuedAt(issuedAt time.Time) error {
- if v.config.iat.ignore {
- return nil
- }
- issuedAt = issuedAt.Round(time.Second)
- offset := v.now().Add(v.config.iat.offset).Round(time.Second)
- if issuedAt.After(offset) {
- return ErrIatInFuture(issuedAt, offset)
- }
- if v.config.iat.maxAge == 0 {
- return nil
- }
- maxAge := v.now().Add(-v.config.iat.maxAge).Round(time.Second)
- if issuedAt.Before(maxAge) {
- return ErrIatToOld(maxAge, issuedAt)
- }
- return nil
-}
-func (v *DefaultVerifier) checkNonce(nonce string) error {
- if v.config.nonce == "" {
- return nil
- }
- if v.config.nonce != nonce {
- return ErrNonceInvalid(v.config.nonce, nonce)
- }
- return nil
-}
-func (v *DefaultVerifier) checkAuthorizationContextClassReference(acr string) error {
- if v.config.acr != nil {
- return v.config.acr(acr)
- }
- return nil
-}
-func (v *DefaultVerifier) checkAuthTime(authTime time.Time) error {
- if v.config.maxAge == 0 {
- return nil
- }
- if authTime.IsZero() {
- return ErrAuthTimeNotPresent()
- }
- authTime = authTime.Round(time.Second)
- maxAge := v.now().Add(-v.config.maxAge).Round(time.Second)
- if authTime.Before(maxAge) {
- return ErrAuthTimeToOld(maxAge, authTime)
- }
- return nil
-}
-
-func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) {
- return tokenString, nil //TODO: impl
-}
-
-func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
- if accessToken == "" {
- return nil
- }
-
- actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
- if err != nil {
- return err
- }
- if actual != atHash {
- return ErrAtHash()
- }
- return nil
-}
diff --git a/pkg/rp/delegation.go b/pkg/rp/delegation.go
deleted file mode 100644
index 3ae6bb6..0000000
--- a/pkg/rp/delegation.go
+++ /dev/null
@@ -1,13 +0,0 @@
-package rp
-
-import (
- "github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
-)
-
-//DelegationTokenRequest is an implementation of TokenExchangeRequest
-//it exchanges a "urn:ietf:params:oauth:token-type:access_token" with an optional
-//"urn:ietf:params:oauth:token-type:access_token" actor token for a
-//"urn:ietf:params:oauth:token-type:access_token" delegation token
-func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
- return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
-}
diff --git a/pkg/rp/error.go b/pkg/rp/error.go
deleted file mode 100644
index fa0ece9..0000000
--- a/pkg/rp/error.go
+++ /dev/null
@@ -1,67 +0,0 @@
-package rp
-
-import (
- "fmt"
- "time"
-)
-
-var (
- ErrIssuerInvalid = func(expected, actual string) *validationError {
- return ValidationError("Issuer does not match. Expected: %s, got: %s", expected, actual)
- }
- ErrAudienceMissingClientID = func(clientID string) *validationError {
- return ValidationError("Audience is not valid. Audience must contain client_id (%s)", clientID)
- }
- ErrAzpMissing = func() *validationError {
- return ValidationError("Authorized Party is not set. If Token is valid for multiple audiences, azp must not be empty")
- }
- ErrAzpInvalid = func(azp, clientID string) *validationError {
- return ValidationError("Authorized Party is not valid. azp (%s) must be equal to client_id (%s)", azp, clientID)
- }
- ErrExpInvalid = func(exp time.Time) *validationError {
- return ValidationError("Token has expired %v", exp)
- }
- ErrIatInFuture = func(exp, now time.Time) *validationError {
- return ValidationError("IssuedAt of token is in the future (%v, now with offset: %v)", exp, now)
- }
- ErrIatToOld = func(maxAge, iat time.Time) *validationError {
- return ValidationError("IssuedAt of token must not be older than %v, but was %v (%v to old)", maxAge, iat, maxAge.Sub(iat))
- }
- ErrNonceInvalid = func(expected, actual string) *validationError {
- return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual)
- }
- ErrAcrInvalid = func(expected []string, actual string) *validationError {
- return ValidationError("acr is invalid. Expected one of: %v, got: %s", expected, actual)
- }
-
- ErrAuthTimeNotPresent = func() *validationError {
- return ValidationError("claim `auth_time` of token is missing")
- }
- ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError {
- return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime))
- }
- ErrSignatureMissing = func() *validationError {
- return ValidationError("id_token does not contain a signature")
- }
- ErrSignatureMultiple = func() *validationError {
- return ValidationError("id_token contains multiple signatures")
- }
- ErrSignatureInvalidPayload = func() *validationError {
- return ValidationError("Signature does not match Payload")
- }
- ErrAtHash = func() *validationError {
- return ValidationError("at_hash does not correspond to access token")
- }
-)
-
-func ValidationError(message string, args ...interface{}) *validationError {
- return &validationError{fmt.Sprintf(message, args...)} //TODO: impl
-}
-
-type validationError struct {
- message string
-}
-
-func (v *validationError) Error() string {
- return v.message
-}
diff --git a/pkg/rp/jwks.go b/pkg/rp/jwks.go
deleted file mode 100644
index 97b1e6f..0000000
--- a/pkg/rp/jwks.go
+++ /dev/null
@@ -1,156 +0,0 @@
-package rp
-
-import (
- "context"
- "errors"
- "fmt"
- "net/http"
- "sync"
-
- "github.com/caos/oidc/pkg/utils"
-
- "gopkg.in/square/go-jose.v2"
-
- "github.com/caos/oidc/pkg/oidc"
-)
-
-func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet {
- return &remoteKeySet{httpClient: client, jwksURL: jwksURL}
-}
-
-type remoteKeySet struct {
- jwksURL string
- httpClient *http.Client
-
- // guard all other fields
- mu sync.Mutex
-
- // inflight suppresses parallel execution of updateKeys and allows
- // multiple goroutines to wait for its result.
- inflight *inflight
-
- // A set of cached keys and their expiry.
- cachedKeys []jose.JSONWebKey
-}
-
-// inflight is used to wait on some in-flight request from multiple goroutines.
-type inflight struct {
- doneCh chan struct{}
-
- keys []jose.JSONWebKey
- err error
-}
-
-func newInflight() *inflight {
- return &inflight{doneCh: make(chan struct{})}
-}
-
-// wait returns a channel that multiple goroutines can receive on. Once it returns
-// a value, the inflight request is done and result() can be inspected.
-func (i *inflight) wait() <-chan struct{} {
- return i.doneCh
-}
-
-// done can only be called by a single goroutine. It records the result of the
-// inflight request and signals other goroutines that the result is safe to
-// inspect.
-func (i *inflight) done(keys []jose.JSONWebKey, err error) {
- i.keys = keys
- i.err = err
- close(i.doneCh)
-}
-
-// result cannot be called until the wait() channel has returned a value.
-func (i *inflight) result() ([]jose.JSONWebKey, error) {
- return i.keys, i.err
-}
-
-func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
- // We don't support JWTs signed with multiple signatures.
- keyID := ""
- for _, sig := range jws.Signatures {
- keyID = sig.Header.KeyID
- break
- }
-
- keys := r.keysFromCache()
- payload, err, ok := CheckKey(keyID, keys, jws)
- if ok {
- return payload, err
- }
-
- keys, err = r.keysFromRemote(ctx)
- if err != nil {
- return nil, fmt.Errorf("fetching keys %v", err)
- }
-
- payload, err, ok = CheckKey(keyID, keys, jws)
- if !ok {
- return nil, errors.New("invalid kid")
- }
- return payload, err
-}
-
-func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
- r.mu.Lock()
- defer r.mu.Unlock()
- return r.cachedKeys
-}
-
-// keysFromRemote syncs the key set from the remote set, records the values in the
-// cache, and returns the key set.
-func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
- // Need to lock to inspect the inflight request field.
- r.mu.Lock()
- // If there's not a current inflight request, create one.
- if r.inflight == nil {
- r.inflight = newInflight()
-
- // This goroutine has exclusive ownership over the current inflight
- // request. It releases the resource by nil'ing the inflight field
- // once the goroutine is done.
- go r.updateKeys(ctx)
- }
- inflight := r.inflight
- r.mu.Unlock()
-
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-inflight.wait():
- return inflight.result()
- }
-}
-
-func (r *remoteKeySet) updateKeys(ctx context.Context) {
- // Sync keys and finish inflight when that's done.
- keys, err := r.fetchRemoteKeys(ctx)
-
- r.inflight.done(keys, err)
-
- // Lock to update the keys and indicate that there is no longer an
- // inflight request.
- r.mu.Lock()
- defer r.mu.Unlock()
-
- if err == nil {
- r.cachedKeys = keys
- }
-
- // Free inflight so a different request can run.
- r.inflight = nil
-}
-
-func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
- req, err := http.NewRequest("GET", r.jwksURL, nil)
- if err != nil {
- return nil, fmt.Errorf("oidc: can't create request: %v", err)
- }
-
- keySet := new(jose.JSONWebKeySet)
- if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil {
- return nil, fmt.Errorf("oidc: failed to get keys: %v", err)
- }
-
- return keySet.Keys, nil
-}
diff --git a/pkg/rp/jws.go b/pkg/rp/jws.go
deleted file mode 100644
index 20ab896..0000000
--- a/pkg/rp/jws.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package rp
-
-import (
- "gopkg.in/square/go-jose.v2"
-)
-
-func CheckKey(keyID string, keys []jose.JSONWebKey, jws *jose.JSONWebSignature) ([]byte, error, bool) {
- for _, key := range keys {
- if keyID == "" || key.KeyID == keyID {
- payload, err := jws.Verify(&key)
- return payload, err, true
- }
- }
- return nil, nil, false
-}
diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go
deleted file mode 100644
index 8aba443..0000000
--- a/pkg/rp/relaying_party.go
+++ /dev/null
@@ -1,108 +0,0 @@
-package rp
-
-import (
- "context"
- "net/http"
-
- "github.com/caos/oidc/pkg/oidc"
-
- "golang.org/x/oauth2"
-)
-
-//RelayingParty declares the minimal interface for oidc clients
-type RelayingParty interface {
- //Client return a standard http client where the token can be used
- Client(ctx context.Context, token *oauth2.Token) *http.Client
-
- //AuthURL returns the authorization endpoint with a given state
- AuthURL(state string, opts ...AuthURLOpt) string
-
- //AuthURLHandler should implement the AuthURL func as http.HandlerFunc
- //(redirecting to the auth endpoint)
- AuthURLHandler(state string) http.HandlerFunc
-
- //CodeExchange implements the OIDC Token Request (oauth2 Authorization Code Grant)
- //returning an `Access Token` and `ID Token Claims`
- CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (*oidc.Tokens, error)
-
- //CodeExchangeHandler extends the CodeExchange func,
- //calling the provided callback func on success with additional returned `state`
- CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string)) http.HandlerFunc
-
- //ClientCredentials implements the oauth2 Client Credentials Grant
- //requesting an `Access Token` for the client itself, without user context
- ClientCredentials(ctx context.Context, scopes ...string) (*oauth2.Token, error)
-
- //Introspects calls the Introspect Endpoint
- //for validating an (access) token
- // Introspect(ctx context.Context, token string) (TokenIntrospectResponse, error)
-
- //Userinfo implements the OIDC Userinfo call
- //returning the info of the user for the requested scopes of an access token
- Userinfo()
-}
-
-//PasswortGrantRP extends the `RelayingParty` interface with the oauth2 `Password Grant`
-//
-//This interface is separated from the standard `RelayingParty` interface as the `password grant`
-//is part of the oauth2 and therefore OIDC specification, but should only be used when there's no
-//other possibility, so IMHO never ever. Ever.
-type PasswortGrantRP interface {
- RelayingParty
-
- //PasswordGrant implements the oauth2 `Password Grant`,
- //requesting an access token with the users `username` and `password`
- PasswordGrant(context.Context, string, string) (*oauth2.Token, error)
-}
-
-type Config struct {
- ClientID string
- ClientSecret string
- CallbackURL string
- Issuer string
- Scopes []string
- Endpoints oauth2.Endpoint
-}
-
-type OptionFunc func(RelayingParty)
-
-type Endpoints struct {
- oauth2.Endpoint
- IntrospectURL string
- UserinfoURL string
- JKWsURL string
-}
-
-func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
- return Endpoints{
- Endpoint: oauth2.Endpoint{
- AuthURL: discoveryConfig.AuthorizationEndpoint,
- AuthStyle: oauth2.AuthStyleAutoDetect,
- TokenURL: discoveryConfig.TokenEndpoint,
- },
- IntrospectURL: discoveryConfig.IntrospectionEndpoint,
- UserinfoURL: discoveryConfig.UserinfoEndpoint,
- JKWsURL: discoveryConfig.JwksURI,
- }
-}
-
-type AuthURLOpt func() []oauth2.AuthCodeOption
-
-//WithCodeChallenge sets the `code_challenge` params in the auth request
-func WithCodeChallenge(codeChallenge string) AuthURLOpt {
- return func() []oauth2.AuthCodeOption {
- return []oauth2.AuthCodeOption{
- oauth2.SetAuthURLParam("code_challenge", codeChallenge),
- oauth2.SetAuthURLParam("code_challenge_method", "S256"),
- }
- }
-}
-
-type CodeExchangeOpt func() []oauth2.AuthCodeOption
-
-//WithCodeVerifier sets the `code_verifier` param in the token request
-func WithCodeVerifier(codeVerifier string) CodeExchangeOpt {
- return func() []oauth2.AuthCodeOption {
- return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("code_verifier", codeVerifier)}
- }
-}
diff --git a/pkg/rp/tockenexchange.go b/pkg/rp/tockenexchange.go
deleted file mode 100644
index d84b38e..0000000
--- a/pkg/rp/tockenexchange.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package rp
-
-import (
- "context"
-
- "golang.org/x/oauth2"
-
- "github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
-)
-
-//TokenExchangeRP extends the `RelayingParty` interface for the *draft* oauth2 `Token Exchange`
-type TokenExchangeRP interface {
- RelayingParty
-
- //TokenExchange implement the `Token Echange Grant` exchanging some token for an other
- TokenExchange(context.Context, *tokenexchange.TokenExchangeRequest) (*oauth2.Token, error)
-}
-
-//DelegationTokenExchangeRP extends the `TokenExchangeRP` interface
-//for the specific `delegation token` request
-type DelegationTokenExchangeRP interface {
- TokenExchangeRP
-
- //DelegationTokenExchange implement the `Token Exchange Grant`
- //providing an access token in request for a `delegation` token for a given resource / audience
- DelegationTokenExchange(context.Context, string, ...tokenexchange.TokenExchangeOption) (*oauth2.Token, error)
-}
diff --git a/pkg/rp/verifier.go b/pkg/rp/verifier.go
deleted file mode 100644
index b82e6c2..0000000
--- a/pkg/rp/verifier.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package rp
-
-import (
- "context"
-
- "github.com/caos/oidc/pkg/oidc"
-)
-
-//Verifier implement the Token Response Validation as defined in OIDC specification
-//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
-type Verifier interface {
-
- //Verify checks the access_token and id_token and returns the `id token claims`
- Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error)
-}
diff --git a/pkg/strings/strings.go b/pkg/strings/strings.go
new file mode 100644
index 0000000..b8f43a1
--- /dev/null
+++ b/pkg/strings/strings.go
@@ -0,0 +1,9 @@
+package strings
+
+import "slices"
+
+// Deprecated: Use standard library [slices.Contains] instead.
+func Contains(list []string, needle string) bool {
+ // TODO(v4): remove package.
+ return slices.Contains(list, needle)
+}
diff --git a/pkg/utils/strings_test.go b/pkg/strings/strings_test.go
similarity index 98%
rename from pkg/utils/strings_test.go
rename to pkg/strings/strings_test.go
index 86af2af..78698d4 100644
--- a/pkg/utils/strings_test.go
+++ b/pkg/strings/strings_test.go
@@ -1,4 +1,4 @@
-package utils
+package strings
import "testing"
diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go
deleted file mode 100644
index 78c007f..0000000
--- a/pkg/utils/hash.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package utils
-
-import (
- "crypto/sha256"
- "crypto/sha512"
- "encoding/base64"
- "fmt"
- "hash"
-
- "gopkg.in/square/go-jose.v2"
-)
-
-func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
- switch sigAlgorithm {
- case jose.RS256, jose.ES256, jose.PS256:
- return sha256.New(), nil
- case jose.RS384, jose.ES384, jose.PS384:
- return sha512.New384(), nil
- case jose.RS512, jose.ES512, jose.PS512:
- return sha512.New(), nil
- default:
- return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
- }
-}
-
-func HashString(hash hash.Hash, s string, firstHalf bool) string {
- hash.Write([]byte(s)) // hash documents that Write will never return an error
- size := hash.Size()
- if firstHalf {
- size = size / 2
- }
- sum := hash.Sum(nil)[:size]
- return base64.RawURLEncoding.EncodeToString(sum)
-}
diff --git a/pkg/utils/http.go b/pkg/utils/http.go
deleted file mode 100644
index 6ad7083..0000000
--- a/pkg/utils/http.go
+++ /dev/null
@@ -1,67 +0,0 @@
-package utils
-
-import (
- "encoding/json"
- "fmt"
- "io/ioutil"
- "net/http"
- "net/url"
- "strings"
- "time"
-
- "github.com/gorilla/schema"
-)
-
-var (
- DefaultHTTPClient = &http.Client{
- Timeout: time.Duration(30 * time.Second),
- }
-)
-
-func FormRequest(endpoint string, request interface{}) (*http.Request, error) {
- form := make(map[string][]string)
- encoder := schema.NewEncoder()
- if err := encoder.Encode(request, form); err != nil {
- return nil, err
- }
- body := strings.NewReader(url.Values(form).Encode())
- req, err := http.NewRequest("POST", endpoint, body)
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
- return req, nil
-}
-
-func HttpRequest(client *http.Client, req *http.Request, response interface{}) error {
- resp, err := client.Do(req)
- if err != nil {
- return err
- }
- defer resp.Body.Close()
-
- body, err := ioutil.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("unable to read response body: %v", err)
- }
-
- if resp.StatusCode != http.StatusOK {
- return fmt.Errorf("http status not ok: %s %s", resp.Status, body)
- }
-
- err = json.Unmarshal(body, response)
- if err != nil {
- return fmt.Errorf("failed to unmarshal response: %v %s", err, body)
- }
- return nil
-}
-
-func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) {
- values := make(map[string][]string)
- err := encoder.Encode(resp, values)
- if err != nil {
- return "", err
- }
- v := url.Values(values)
- return v.Encode(), nil
-}
diff --git a/pkg/utils/marshal.go b/pkg/utils/marshal.go
deleted file mode 100644
index e279341..0000000
--- a/pkg/utils/marshal.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package utils
-
-import (
- "encoding/json"
- "net/http"
-
- "github.com/sirupsen/logrus"
-)
-
-func MarshalJSON(w http.ResponseWriter, i interface{}) {
- b, err := json.Marshal(i)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
- w.Header().Set("content-type", "application/json")
- _, err = w.Write(b)
- if err != nil {
- logrus.Error("error writing response")
- }
-}
diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go
deleted file mode 100644
index 5ffcd37..0000000
--- a/pkg/utils/strings.go
+++ /dev/null
@@ -1,10 +0,0 @@
-package utils
-
-func Contains(list []string, needle string) bool {
- for _, item := range list {
- if item == needle {
- return true
- }
- }
- return false
-}