Merge branch 'main' into upgrade-guide
This commit is contained in:
commit
79a1bb9a43
47 changed files with 1464 additions and 410 deletions
10
.github/dependabot.yml
vendored
10
.github/dependabot.yml
vendored
|
@ -9,6 +9,16 @@ updates:
|
|||
commit-message:
|
||||
prefix: chore
|
||||
include: scope
|
||||
- package-ecosystem: gomod
|
||||
target-branch: "2.12.x"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: daily
|
||||
time: '04:00'
|
||||
open-pull-requests-limit: 10
|
||||
commit-message:
|
||||
prefix: chore
|
||||
include: scope
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
|
|
6
.github/workflows/codeql-analysis.yml
vendored
6
.github/workflows/codeql-analysis.yml
vendored
|
@ -29,7 +29,7 @@ jobs:
|
|||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v2
|
||||
uses: github/codeql-action/init@v3
|
||||
# Override language selection by uncommenting this and choosing your languages
|
||||
with:
|
||||
languages: go
|
||||
|
@ -37,7 +37,7 @@ jobs:
|
|||
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
||||
# If this step fails, then you should remove it and run the build manually (see below)
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v2
|
||||
uses: github/codeql-action/autobuild@v3
|
||||
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 https://git.io/JvXDl
|
||||
|
@ -51,4 +51,4 @@ jobs:
|
|||
# make release
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v2
|
||||
uses: github/codeql-action/analyze@v3
|
||||
|
|
29
.github/workflows/issue.yml
vendored
29
.github/workflows/issue.yml
vendored
|
@ -4,15 +4,40 @@ on:
|
|||
issues:
|
||||
types:
|
||||
- opened
|
||||
pull_request_target:
|
||||
types:
|
||||
- opened
|
||||
|
||||
jobs:
|
||||
add-to-project:
|
||||
name: Add issue to project
|
||||
name: Add issue and community pr to project
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/add-to-project@v0.5.0
|
||||
- name: add issue
|
||||
uses: actions/add-to-project@v0.5.0
|
||||
if: ${{ github.event_name == 'issues' }}
|
||||
with:
|
||||
# You can target a repository in a different organization
|
||||
# to the issue
|
||||
project-url: https://github.com/orgs/zitadel/projects/2
|
||||
github-token: ${{ secrets.ADD_TO_PROJECT_PAT }}
|
||||
- uses: tspascoal/get-user-teams-membership@v3
|
||||
id: checkUserMember
|
||||
if: github.actor != 'dependabot[bot]'
|
||||
with:
|
||||
username: ${{ github.actor }}
|
||||
GITHUB_TOKEN: ${{ secrets.ADD_TO_PROJECT_PAT }}
|
||||
- name: add pr
|
||||
uses: actions/add-to-project@v0.5.0
|
||||
if: ${{ github.event_name == 'pull_request_target' && github.actor != 'dependabot[bot]' && !contains(steps.checkUserMember.outputs.teams, 'engineers')}}
|
||||
with:
|
||||
# You can target a repository in a different organization
|
||||
# to the issue
|
||||
project-url: https://github.com/orgs/zitadel/projects/2
|
||||
github-token: ${{ secrets.ADD_TO_PROJECT_PAT }}
|
||||
- uses: actions-ecosystem/action-add-labels@v1.1.3
|
||||
if: ${{ github.event_name == 'pull_request_target' && github.actor != 'dependabot[bot]' && !contains(steps.checkUserMember.outputs.teams, 'staff')}}
|
||||
with:
|
||||
github_token: ${{ secrets.ADD_TO_PROJECT_PAT }}
|
||||
labels: |
|
||||
os-contribution
|
||||
|
|
5
.github/workflows/release.yml
vendored
5
.github/workflows/release.yml
vendored
|
@ -2,6 +2,7 @@ name: Release
|
|||
on:
|
||||
push:
|
||||
branches:
|
||||
- "2.11.x"
|
||||
- main
|
||||
- next
|
||||
tags-ignore:
|
||||
|
@ -22,11 +23,11 @@ jobs:
|
|||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup go
|
||||
uses: actions/setup-go@v4
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- run: go test -race -v -coverprofile=profile.cov -coverpkg=./pkg/... ./pkg/...
|
||||
- uses: codecov/codecov-action@v3.1.4
|
||||
- uses: codecov/codecov-action@v4.0.1
|
||||
with:
|
||||
file: ./profile.cov
|
||||
name: codecov-go
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
module.exports = {
|
||||
branches: [
|
||||
{name: "2.11.x"},
|
||||
{name: "main"},
|
||||
{name: "next", prerelease: true},
|
||||
],
|
||||
|
|
|
@ -72,7 +72,7 @@ CLIENT_ID=web CLIENT_SECRET=secret ISSUER=http://oidc.local:9998/ SCOPES="openid
|
|||
| Code Flow | yes | yes | OpenID Connect Core 1.0, [Section 3.1][1] |
|
||||
| Implicit Flow | no[^1] | yes | OpenID Connect Core 1.0, [Section 3.2][2] |
|
||||
| Hybrid Flow | no | not yet | OpenID Connect Core 1.0, [Section 3.3][3] |
|
||||
| Client Credentials | not yet | yes | OpenID Connect Core 1.0, [Section 9][4] |
|
||||
| Client Credentials | yes | yes | OpenID Connect Core 1.0, [Section 9][4] |
|
||||
| Refresh Token | yes | yes | OpenID Connect Core 1.0, [Section 12][5] |
|
||||
| Discovery | yes | yes | OpenID Connect [Discovery][6] 1.0 |
|
||||
| JWT Profile | yes | yes | [RFC 7523][7] |
|
||||
|
|
45
SECURITY.md
45
SECURITY.md
|
@ -1,6 +1,6 @@
|
|||
# Security Policy
|
||||
|
||||
At ZITADEL we are extremely grateful for security aware people that disclose vulnerabilities to us and the open source community. All reports will be investigated by our team.
|
||||
Please refer to the security policy [on zitadel/zitadel](https://github.com/zitadel/zitadel/blob/main/SECURITY.md) which is applicable for all open source repositories of our organization.
|
||||
|
||||
## Supported Versions
|
||||
|
||||
|
@ -9,43 +9,12 @@ We currently support the following version of the OIDC framework:
|
|||
| Version | Supported | Branch | Details |
|
||||
| -------- | ------------------ | ----------- | ------------------------------------ |
|
||||
| 0.x.x | :x: | | not maintained |
|
||||
| <1.13 | :x: | | not maintained |
|
||||
| 1.13.x | :lock: :warning: | [1.13.x][1] | security only, [community effort][2] |
|
||||
| 2.x.x | :heavy_check_mark: | [main][3] | supported |
|
||||
| 3.0.0-xx | :white_check_mark: | [next][4] | [developement branch][5] |
|
||||
| <2.11 | :x: | | not maintained |
|
||||
| 2.11.x | :lock: :warning: | [2.11.x][1] | security only, [community effort][2] |
|
||||
| 3.x.x | :heavy_check_mark: | [main][3] | supported |
|
||||
| 4.0.0-xx | :white_check_mark: | [next][4] | [development branch] |
|
||||
|
||||
[1]: https://github.com/zitadel/oidc/tree/1.13.x
|
||||
[2]: https://github.com/zitadel/oidc/discussions/378
|
||||
[1]: https://github.com/zitadel/oidc/tree/2.11.x
|
||||
[2]: https://github.com/zitadel/oidc/discussions/458
|
||||
[3]: https://github.com/zitadel/oidc/tree/main
|
||||
[4]: https://github.com/zitadel/oidc/tree/next
|
||||
[5]: https://github.com/zitadel/oidc/milestone/2
|
||||
|
||||
## Reporting a vulnerability
|
||||
|
||||
To file a incident, please disclose by email to security@zitadel.com with the security details.
|
||||
|
||||
At the moment GPG encryption is no yet supported, however you may sign your message at will.
|
||||
|
||||
### When should I report a vulnerability
|
||||
|
||||
* You think you discovered a ...
|
||||
* ... potential security vulnerability in the SDK
|
||||
* ... vulnerability in another project that this SDK bases on
|
||||
* For projects with their own vulnerability reporting and disclosure process, please report it directly there
|
||||
|
||||
### When should I NOT report a vulnerability
|
||||
|
||||
* You need help applying security related updates
|
||||
* Your issue is not security related
|
||||
|
||||
## Security Vulnerability Response
|
||||
|
||||
TBD
|
||||
|
||||
## Public Disclosure
|
||||
|
||||
All accepted and mitigated vulnerabilities will be published on the [Github Security Page](https://github.com/zitadel/oidc/security/advisories)
|
||||
|
||||
### Timing
|
||||
|
||||
We think it is crucial to publish advisories `ASAP` as mitigations are ready. But due to the unknown nature of the disclosures the time frame can range from 7 to 90 days.
|
||||
|
|
|
@ -1,3 +1,37 @@
|
|||
// Command device is an example Oauth2 Device Authorization Grant app.
|
||||
// It creates a new Device Authorization request on the Issuer and then polls for tokens.
|
||||
// The user is then prompted to visit a URL and enter the user code.
|
||||
// Or, the complete URL can be used instead to omit manual entry.
|
||||
// In practice then can be a "magic link" in the form or a QR.
|
||||
//
|
||||
// The following environment variables are used for configuration:
|
||||
//
|
||||
// ISSUER: URL to the OP, required.
|
||||
// CLIENT_ID: ID of the application, required.
|
||||
// CLIENT_SECRET: Secret to authenticate the app using basic auth. Only required if the OP expects this type of authentication.
|
||||
// KEY_PATH: Path to a private key file, used to for JWT authentication of the App. Only required if the OP expects this type of authentication.
|
||||
// SCOPES: Scopes of the Authentication Request. Optional.
|
||||
//
|
||||
// Basic usage:
|
||||
//
|
||||
// cd example/client/device
|
||||
// export ISSUER="http://localhost:9000" CLIENT_ID="246048465824634593@demo"
|
||||
//
|
||||
// Get an Access Token:
|
||||
//
|
||||
// SCOPES="email profile" go run .
|
||||
//
|
||||
// Get an Access Token and ID Token:
|
||||
//
|
||||
// SCOPES="email profile openid" go run .
|
||||
//
|
||||
// Get an Access Token and Refresh Token
|
||||
//
|
||||
// SCOPES="email profile offline_access" go run .
|
||||
//
|
||||
// Get Access, Refresh and ID Tokens:
|
||||
//
|
||||
// SCOPES="email profile offline_access openid" go run .
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@ -57,5 +91,5 @@ func main() {
|
|||
if err != nil {
|
||||
logrus.Fatal(err)
|
||||
}
|
||||
logrus.Infof("successfully obtained token: %v", token)
|
||||
logrus.Infof("successfully obtained token: %#v", token)
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ func SetupServer(issuer string, storage Storage, logger *slog.Logger, wrapServer
|
|||
|
||||
handler := http.Handler(provider)
|
||||
if wrapServer {
|
||||
handler = op.NewLegacyServer(provider, *op.DefaultEndpoints)
|
||||
handler = op.RegisterLegacyServer(op.NewLegacyServer(provider, *op.DefaultEndpoints))
|
||||
}
|
||||
|
||||
// we register the http handler of the OP on the root, so that the discovery endpoint (/.well-known/openid-configuration)
|
||||
|
|
|
@ -90,6 +90,10 @@ func (s *publicKey) Key() any {
|
|||
}
|
||||
|
||||
func NewStorage(userStore UserStore) *Storage {
|
||||
return NewStorageWithClients(userStore, clients)
|
||||
}
|
||||
|
||||
func NewStorageWithClients(userStore UserStore, clients map[string]*Client) *Storage {
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
return &Storage{
|
||||
authRequests: make(map[string]*AuthRequest),
|
||||
|
|
29
go.mod
29
go.mod
|
@ -3,38 +3,39 @@ module github.com/zitadel/oidc/v3
|
|||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/go-chi/chi/v5 v5.0.10
|
||||
github.com/go-jose/go-jose/v3 v3.0.0
|
||||
github.com/bmatcuk/doublestar/v4 v4.6.1
|
||||
github.com/go-chi/chi/v5 v5.0.12
|
||||
github.com/go-jose/go-jose/v3 v3.0.2
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-github/v31 v31.0.0
|
||||
github.com/google/uuid v1.3.1
|
||||
github.com/gorilla/securecookie v1.1.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/securecookie v1.1.2
|
||||
github.com/jeremija/gosubmit v0.2.7
|
||||
github.com/muhlemmer/gu v0.3.1
|
||||
github.com/muhlemmer/httpforwarded v0.1.0
|
||||
github.com/rs/cors v1.10.1
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/zitadel/logging v0.4.0
|
||||
github.com/zitadel/logging v0.5.0
|
||||
github.com/zitadel/schema v1.3.0
|
||||
go.opentelemetry.io/otel v1.19.0
|
||||
go.opentelemetry.io/otel/trace v1.19.0
|
||||
go.opentelemetry.io/otel v1.24.0
|
||||
go.opentelemetry.io/otel/trace v1.24.0
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
||||
golang.org/x/oauth2 v0.13.0
|
||||
golang.org/x/text v0.13.0
|
||||
golang.org/x/oauth2 v0.17.0
|
||||
golang.org/x/text v0.14.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-logr/logr v1.2.4 // indirect
|
||||
github.com/go-logr/logr v1.4.1 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.19.0 // indirect
|
||||
golang.org/x/crypto v0.14.0 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.24.0 // indirect
|
||||
golang.org/x/crypto v0.19.0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sys v0.17.0 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/protobuf v1.31.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
|
89
go.sum
89
go.sum
|
@ -1,13 +1,15 @@
|
|||
github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I=
|
||||
github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-chi/chi/v5 v5.0.10 h1:rLz5avzKpjqxrYwXNfmjkrYYXOyLJd37pz53UFHC6vk=
|
||||
github.com/go-chi/chi/v5 v5.0.10/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
|
||||
github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo=
|
||||
github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
|
||||
github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s=
|
||||
github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
|
||||
github.com/go-jose/go-jose/v3 v3.0.2 h1:2Edjn8Nrb44UvTdp84KU0bBPs1cO7noRCybtS3eJEUQ=
|
||||
github.com/go-jose/go-jose/v3 v3.0.2/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
|
||||
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
|
@ -17,19 +19,20 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
|
|||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-github/v31 v31.0.0 h1:JJUxlP9lFK+ziXKimTCprajMApV1ecWD4NB6CCb0plo=
|
||||
github.com/google/go-github/v31 v31.0.0/go.mod h1:NQPZol8/1sMoWYGN2yaALIBytu17gAWfhbweiEed3pM=
|
||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
|
||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc=
|
||||
github.com/jeremija/gosubmit v0.2.7/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI=
|
||||
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
||||
|
@ -45,59 +48,81 @@ github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU
|
|||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/zitadel/logging v0.4.0 h1:lRAIFgaRoJpLNbsL7jtIYHcMDoEJP9QZB4GqMfl4xaA=
|
||||
github.com/zitadel/logging v0.4.0/go.mod h1:6uALRJawpkkuUPCkgzfgcPR3c2N908wqnOnIrRelUFc=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/zitadel/logging v0.5.0 h1:Kunouvqse/efXy4UDvFw5s3vP+Z4AlHo3y8wF7stXHA=
|
||||
github.com/zitadel/logging v0.5.0/go.mod h1:IzP5fzwFhzzyxHkSmfF8dsyqFsQRJLLcQmwhIBzlGsE=
|
||||
github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0=
|
||||
github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc=
|
||||
go.opentelemetry.io/otel v1.19.0 h1:MuS/TNf4/j4IXsZuJegVzI1cwut7Qc00344rgH7p8bs=
|
||||
go.opentelemetry.io/otel v1.19.0/go.mod h1:i0QyjOq3UPoTzff0PJB2N66fb4S0+rSbSB15/oyH9fY=
|
||||
go.opentelemetry.io/otel/metric v1.19.0 h1:aTzpGtV0ar9wlV4Sna9sdJyII5jTVJEvKETPiOKwvpE=
|
||||
go.opentelemetry.io/otel/metric v1.19.0/go.mod h1:L5rUsV9kM1IxCj1MmSdS+JQAcVm319EUrDVLrt7jqt8=
|
||||
go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1Dzxpg=
|
||||
go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo=
|
||||
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
||||
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
||||
go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI=
|
||||
go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
|
||||
go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
|
||||
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ=
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.13.0 h1:jDDenyj+WgFtmV3zYVoi8aE2BwtXFLWOA67ZfNWftiY=
|
||||
golang.org/x/oauth2 v0.13.0/go.mod h1:/JMhi4ZRXAf4HG9LiNmxvk+45+96RUlVThiH8FzNBn0=
|
||||
golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ=
|
||||
golang.org/x/oauth2 v0.17.0/go.mod h1:OzPDGQiuQMguemayvdylqddI7qcD9lnSDb+1FiwQ5HA=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/oidc/v3/example/server/exampleop"
|
||||
"github.com/zitadel/oidc/v3/example/server/storage"
|
||||
|
@ -217,6 +218,7 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
|||
targetURL,
|
||||
[]string{"openid", "email", "profile", "offline_access"},
|
||||
rp.WithPKCE(cookieHandler),
|
||||
rp.WithAuthStyle(oauth2.AuthStyleInHeader),
|
||||
rp.WithVerifierOpts(
|
||||
rp.WithIssuedAtOffset(5*time.Second),
|
||||
rp.WithSupportedSigningAlgorithms("RS256", "RS384", "RS512", "ES256", "ES384", "ES512"),
|
||||
|
@ -323,6 +325,31 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID,
|
|||
return provider, tokens
|
||||
}
|
||||
|
||||
func TestClientCredentials(t *testing.T) {
|
||||
targetURL := "http://local-site"
|
||||
exampleStorage := storage.NewStorage(storage.NewUserStore(targetURL))
|
||||
var dh deferredHandler
|
||||
opServer := httptest.NewServer(&dh)
|
||||
defer opServer.Close()
|
||||
t.Logf("auth server at %s", opServer.URL)
|
||||
dh.Handler = exampleop.SetupServer(opServer.URL, exampleStorage, Logger, true)
|
||||
|
||||
provider, err := rp.NewRelyingPartyOIDC(
|
||||
CTX,
|
||||
opServer.URL,
|
||||
"sid1",
|
||||
"verysecret",
|
||||
targetURL,
|
||||
[]string{"openid"},
|
||||
)
|
||||
require.NoError(t, err, "new rp")
|
||||
|
||||
token, err := rp.ClientCredentials(CTX, provider, nil)
|
||||
require.NoError(t, err, "ClientCredentials call")
|
||||
require.NotNil(t, token)
|
||||
assert.NotEmpty(t, token.AccessToken)
|
||||
}
|
||||
|
||||
func TestErrorFromPromptNone(t *testing.T) {
|
||||
jar, err := cookiejar.New(nil)
|
||||
require.NoError(t, err, "create cookie jar")
|
||||
|
|
5
pkg/client/rp/errors.go
Normal file
5
pkg/client/rp/errors.go
Normal file
|
@ -0,0 +1,5 @@
|
|||
package rp
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrRelyingPartyNotSupportRevokeCaller = errors.New("RelyingParty does not support RevokeCaller")
|
|
@ -4,16 +4,16 @@ import (
|
|||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/exp/slog"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
|
@ -66,19 +66,28 @@ type RelyingParty interface {
|
|||
|
||||
// IDTokenVerifier returns the verifier used for oidc id_token verification
|
||||
IDTokenVerifier() *IDTokenVerifier
|
||||
// ErrorHandler returns the handler used for callback errors
|
||||
|
||||
// ErrorHandler returns the handler used for callback errors
|
||||
ErrorHandler() func(http.ResponseWriter, *http.Request, string, string, string)
|
||||
|
||||
// Logger from the context, or a fallback if set.
|
||||
Logger(context.Context) (logger *slog.Logger, ok bool)
|
||||
}
|
||||
|
||||
type HasUnauthorizedHandler interface {
|
||||
// UnauthorizedHandler returns the handler used for unauthorized errors
|
||||
UnauthorizedHandler() func(w http.ResponseWriter, r *http.Request, desc string, state string)
|
||||
}
|
||||
|
||||
type ErrorHandler func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string)
|
||||
type UnauthorizedHandler func(w http.ResponseWriter, r *http.Request, desc string, state string)
|
||||
|
||||
var DefaultErrorHandler ErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
|
||||
http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
|
||||
}
|
||||
var DefaultUnauthorizedHandler UnauthorizedHandler = func(w http.ResponseWriter, r *http.Request, desc string, state string) {
|
||||
http.Error(w, desc, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
type relyingParty struct {
|
||||
issuer string
|
||||
|
@ -91,11 +100,14 @@ type relyingParty struct {
|
|||
httpClient *http.Client
|
||||
cookieHandler *httphelper.CookieHandler
|
||||
|
||||
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
|
||||
idTokenVerifier *IDTokenVerifier
|
||||
verifierOpts []VerifierOption
|
||||
signer jose.Signer
|
||||
logger *slog.Logger
|
||||
oauthAuthStyle oauth2.AuthStyle
|
||||
|
||||
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
|
||||
unauthorizedHandler func(http.ResponseWriter, *http.Request, string, string)
|
||||
idTokenVerifier *IDTokenVerifier
|
||||
verifierOpts []VerifierOption
|
||||
signer jose.Signer
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (rp *relyingParty) OAuthConfig() *oauth2.Config {
|
||||
|
@ -156,6 +168,13 @@ func (rp *relyingParty) ErrorHandler() func(http.ResponseWriter, *http.Request,
|
|||
return rp.errorHandler
|
||||
}
|
||||
|
||||
func (rp *relyingParty) UnauthorizedHandler() func(http.ResponseWriter, *http.Request, string, string) {
|
||||
if rp.unauthorizedHandler == nil {
|
||||
rp.unauthorizedHandler = DefaultUnauthorizedHandler
|
||||
}
|
||||
return rp.unauthorizedHandler
|
||||
}
|
||||
|
||||
func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok bool) {
|
||||
logger, ok = logging.FromContext(ctx)
|
||||
if ok {
|
||||
|
@ -169,9 +188,11 @@ func (rp *relyingParty) Logger(ctx context.Context) (logger *slog.Logger, ok boo
|
|||
// it will use the AuthURL and TokenURL set in config
|
||||
func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingParty, error) {
|
||||
rp := &relyingParty{
|
||||
oauthConfig: config,
|
||||
httpClient: httphelper.DefaultHTTPClient,
|
||||
oauth2Only: true,
|
||||
oauthConfig: config,
|
||||
httpClient: httphelper.DefaultHTTPClient,
|
||||
oauth2Only: true,
|
||||
unauthorizedHandler: DefaultUnauthorizedHandler,
|
||||
oauthAuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
}
|
||||
|
||||
for _, optFunc := range options {
|
||||
|
@ -180,9 +201,12 @@ func NewRelyingPartyOAuth(config *oauth2.Config, options ...Option) (RelyingPart
|
|||
}
|
||||
}
|
||||
|
||||
rp.oauthConfig.Endpoint.AuthStyle = rp.oauthAuthStyle
|
||||
|
||||
// avoid races by calling these early
|
||||
_ = rp.IDTokenVerifier() // sets idTokenVerifier
|
||||
_ = rp.ErrorHandler() // sets errorHandler
|
||||
_ = rp.IDTokenVerifier() // sets idTokenVerifier
|
||||
_ = rp.ErrorHandler() // sets errorHandler
|
||||
_ = rp.UnauthorizedHandler() // sets unauthorizedHandler
|
||||
|
||||
return rp, nil
|
||||
}
|
||||
|
@ -199,8 +223,9 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re
|
|||
RedirectURL: redirectURI,
|
||||
Scopes: scopes,
|
||||
},
|
||||
httpClient: httphelper.DefaultHTTPClient,
|
||||
oauth2Only: false,
|
||||
httpClient: httphelper.DefaultHTTPClient,
|
||||
oauth2Only: false,
|
||||
oauthAuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
}
|
||||
|
||||
for _, optFunc := range options {
|
||||
|
@ -217,9 +242,13 @@ func NewRelyingPartyOIDC(ctx context.Context, issuer, clientID, clientSecret, re
|
|||
rp.oauthConfig.Endpoint = endpoints.Endpoint
|
||||
rp.endpoints = endpoints
|
||||
|
||||
rp.oauthConfig.Endpoint.AuthStyle = rp.oauthAuthStyle
|
||||
rp.endpoints.Endpoint.AuthStyle = rp.oauthAuthStyle
|
||||
|
||||
// avoid races by calling these early
|
||||
_ = rp.IDTokenVerifier() // sets idTokenVerifier
|
||||
_ = rp.ErrorHandler() // sets errorHandler
|
||||
_ = rp.IDTokenVerifier() // sets idTokenVerifier
|
||||
_ = rp.ErrorHandler() // sets errorHandler
|
||||
_ = rp.UnauthorizedHandler() // sets unauthorizedHandler
|
||||
|
||||
return rp, nil
|
||||
}
|
||||
|
@ -268,6 +297,20 @@ func WithErrorHandler(errorHandler ErrorHandler) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithUnauthorizedHandler(unauthorizedHandler UnauthorizedHandler) Option {
|
||||
return func(rp *relyingParty) error {
|
||||
rp.unauthorizedHandler = unauthorizedHandler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithAuthStyle(oauthAuthStyle oauth2.AuthStyle) Option {
|
||||
return func(rp *relyingParty) error {
|
||||
rp.oauthAuthStyle = oauthAuthStyle
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithVerifierOpts(opts ...VerifierOption) Option {
|
||||
return func(rp *relyingParty) error {
|
||||
rp.verifierOpts = opts
|
||||
|
@ -355,13 +398,13 @@ func AuthURLHandler(stateFn func() string, rp RelyingParty, urlParam ...URLParam
|
|||
|
||||
state := stateFn()
|
||||
if err := trySetStateCookie(w, state, rp); err != nil {
|
||||
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
||||
unauthorizedError(w, r, "failed to create state cookie: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
if rp.IsPKCE() {
|
||||
codeChallenge, err := GenerateAndStoreCodeChallenge(w, rp)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
|
||||
unauthorizedError(w, r, "failed to create code challenge: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
opts = append(opts, WithCodeChallenge(codeChallenge))
|
||||
|
@ -416,17 +459,39 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP
|
|||
return verifyTokenResponse[C](ctx, token, rp)
|
||||
}
|
||||
|
||||
// ClientCredentials requests an access token using the `client_credentials` grant,
|
||||
// as defined in [RFC 6749, section 4.4].
|
||||
//
|
||||
// As there is no user associated to the request an ID Token can never be returned.
|
||||
// Client Credentials are undefined in OpenID Connect and is a pure OAuth2 grant.
|
||||
// Furthermore the server SHOULD NOT return a refresh token.
|
||||
//
|
||||
// [RFC 6749, section 4.4]: https://datatracker.ietf.org/doc/html/rfc6749#section-4.4
|
||||
func ClientCredentials(ctx context.Context, rp RelyingParty, endpointParams url.Values) (token *oauth2.Token, err error) {
|
||||
ctx = logCtxWithRPData(ctx, rp, "function", "ClientCredentials")
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, rp.HttpClient())
|
||||
config := clientcredentials.Config{
|
||||
ClientID: rp.OAuthConfig().ClientID,
|
||||
ClientSecret: rp.OAuthConfig().ClientSecret,
|
||||
TokenURL: rp.OAuthConfig().Endpoint.TokenURL,
|
||||
Scopes: rp.OAuthConfig().Scopes,
|
||||
EndpointParams: endpointParams,
|
||||
AuthStyle: rp.OAuthConfig().Endpoint.AuthStyle,
|
||||
}
|
||||
return config.Token(ctx)
|
||||
}
|
||||
|
||||
type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty)
|
||||
|
||||
// CodeExchangeHandler extends the `CodeExchange` method with a http handler
|
||||
// including cookie handling for secure `state` transfer
|
||||
// and optional PKCE code verifier checking.
|
||||
// Custom paramaters can optionally be set to the token URL.
|
||||
// Custom parameters can optionally be set to the token URL.
|
||||
func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp RelyingParty, urlParam ...URLParamOpt) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := tryReadStateCookie(w, r, rp)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized)
|
||||
unauthorizedError(w, r, "failed to get state: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
params := r.URL.Query()
|
||||
|
@ -442,7 +507,7 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
|
|||
if rp.IsPKCE() {
|
||||
codeVerifier, err := rp.CookieHandler().CheckCookie(r, pkceCode)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
|
||||
unauthorizedError(w, r, "failed to get code verifier: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
|
||||
|
@ -451,14 +516,14 @@ func CodeExchangeHandler[C oidc.IDClaims](callback CodeExchangeCallback[C], rp R
|
|||
if rp.Signer() != nil {
|
||||
assertion, err := client.SignedJWTProfileAssertion(rp.OAuthConfig().ClientID, []string{rp.Issuer()}, time.Hour, rp.Signer())
|
||||
if err != nil {
|
||||
http.Error(w, "failed to build assertion: "+err.Error(), http.StatusUnauthorized)
|
||||
unauthorizedError(w, r, "failed to build assertion: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
codeOpts = append(codeOpts, WithClientAssertionJWT(assertion))
|
||||
}
|
||||
tokens, err := CodeExchange[C](r.Context(), params.Get("code"), rp, codeOpts...)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
unauthorizedError(w, r, "failed to exchange token: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
callback(w, r, tokens, state, rp)
|
||||
|
@ -478,7 +543,7 @@ func UserinfoCallback[C oidc.IDClaims, U SubjectGetter](f CodeExchangeUserinfoCa
|
|||
return func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) {
|
||||
info, err := Userinfo[U](r.Context(), tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.GetSubject(), rp)
|
||||
if err != nil {
|
||||
http.Error(w, "userinfo failed: "+err.Error(), http.StatusUnauthorized)
|
||||
unauthorizedError(w, r, "userinfo failed: "+err.Error(), state, rp)
|
||||
return
|
||||
}
|
||||
f(w, r, tokens, state, rp, info)
|
||||
|
@ -545,9 +610,8 @@ type Endpoints struct {
|
|||
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||
return Endpoints{
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: discoveryConfig.AuthorizationEndpoint,
|
||||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
TokenURL: discoveryConfig.TokenEndpoint,
|
||||
AuthURL: discoveryConfig.AuthorizationEndpoint,
|
||||
TokenURL: discoveryConfig.TokenEndpoint,
|
||||
},
|
||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||
|
@ -703,5 +767,13 @@ func RevokeToken(ctx context.Context, rp RelyingParty, token string, tokenTypeHi
|
|||
if rc, ok := rp.(client.RevokeCaller); ok && rc.GetRevokeEndpoint() != "" {
|
||||
return client.CallRevokeEndpoint(ctx, request, nil, rc)
|
||||
}
|
||||
return fmt.Errorf("RelyingParty does not support RevokeCaller")
|
||||
return ErrRelyingPartyNotSupportRevokeCaller
|
||||
}
|
||||
|
||||
func unauthorizedError(w http.ResponseWriter, r *http.Request, desc string, state string, rp RelyingParty) {
|
||||
if rp, ok := rp.(HasUnauthorizedHandler); ok {
|
||||
rp.UnauthorizedHandler()(w, r, desc, state)
|
||||
return
|
||||
}
|
||||
http.Error(w, desc, http.StatusUnauthorized)
|
||||
}
|
||||
|
|
|
@ -4,14 +4,19 @@ import (
|
|||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(priv)
|
||||
b := block.Bytes
|
||||
key, err := x509.ParsePKCS1PrivateKey(b)
|
||||
func BytesToPrivateKey(b []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(b)
|
||||
if block == nil {
|
||||
return nil, errors.New("PEM decode failed")
|
||||
}
|
||||
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
|
62
pkg/crypto/key_test.go
Normal file
62
pkg/crypto/key_test.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package crypto_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||||
)
|
||||
|
||||
func TestBytesToPrivateKey(tt *testing.T) {
|
||||
tt.Run("PEMDecodeError", func(t *testing.T) {
|
||||
_, err := crypto.BytesToPrivateKey([]byte("The non-PEM sequence"))
|
||||
assert.EqualError(t, err, "PEM decode failed")
|
||||
})
|
||||
|
||||
tt.Run("InvalidKeyFormat", func(t *testing.T) {
|
||||
_, err := crypto.BytesToPrivateKey([]byte(`-----BEGIN PRIVATE KEY-----
|
||||
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCfaDB7pK/fmP/I
|
||||
7IusSK8lTCBnPZghqIbVLt2QHYAMoEF1CaF4F4rxo2vl1Mt8gwsq4T3osQFZMvnL
|
||||
YHb7KNyUoJgTjLxJQADv2u4Q3U38heAzK5Tp4ry4MCnuyJIqAPK1GiruwEq4zQrx
|
||||
+WzVix8otO37SuW9tzklqlNGMiAYBL0TBKHvS5XMbjP1idBMB8erMz29w/TVQnEB
|
||||
Kj0vCdZjrbVPKygptt5kcSrL5f4xCZwU+ufz7cp0GLwpRMJ+shG9YJJFBxb0itPF
|
||||
sy51vAyEtdBC7jgAU96ZVeQ06nryDq1D2EpoVMElqNyL46Jo3lnKbGquGKzXzQYU
|
||||
BN32/scDAgMBAAECggEBAJE/mo3PLgILo2YtQ8ekIxNVHmF0Gl7w9IrjvTdH6hmX
|
||||
HI3MTLjkmtI7GmG9V/0IWvCjdInGX3grnrjWGRQZ04QKIQgPQLFuBGyJjEsJm7nx
|
||||
MqztlS7YTyV1nX/aenSTkJO8WEpcJLnm+4YoxCaAMdAhrIdBY71OamALpv1bRysa
|
||||
FaiCGcemT2yqZn0GqIS8O26Tz5zIqrTN2G1eSmgh7DG+7FoddMz35cute8R10xUG
|
||||
hF5YU+6fcXiRQ/Kh7nlxelPGqdZFPMk7LpVHzkQKwdJ+N0P23lPDIfNsvpG1n0OP
|
||||
3g5km7gHSrSU2yZ3eFl6DB9x1IFNS9BaQQuSxYJtKwECgYEA1C8jjzpXZDLvlYsV
|
||||
2jlMzkrbsIrX2dzblVrNsPs2jRbjYU8mg2DUDO6lOhtxHfqZG6sO+gmWi/zvoy9l
|
||||
yolGbXe1Jqx66p9fznIcecSwar8+ACa356Wk74Nt1PlBOfCMqaJnYLOLaFJa29Vy
|
||||
u5ClZVzKd5AVXl7yFVd4XfLv/WECgYEAwFMMtFoasdF92c0d31rZ1uoPOtFz6xq6
|
||||
uQggdm5zzkhnfwUAGqppS/u1CHcJ7T/74++jLbFTsaohGr4jEzWSGvJpomEUChy3
|
||||
r25YofMclUhJ5pCEStsLtqiCR1Am6LlI8HMdBEP1QDgEC5q8bQW4+UHuew1E1zxz
|
||||
osZOhe09WuMCgYEA0G9aFCnwjUqIFjQiDFP7gi8BLqTFs4uE3Wvs4W11whV42i+B
|
||||
ms90nxuTjchFT3jMDOT1+mOO0wdudLRr3xEI8SIF/u6ydGaJG+j21huEXehtxIJE
|
||||
aDdNFcfbDbqo+3y1ATK7MMBPMvSrsoY0hdJq127WqasNgr3sO1DIuima3SECgYEA
|
||||
nkM5TyhekzlbIOHD1UsDu/D7+2DkzPE/+oePfyXBMl0unb3VqhvVbmuBO6gJiSx/
|
||||
8b//PdiQkMD5YPJaFrKcuoQFHVRZk0CyfzCEyzAts0K7XXpLAvZiGztriZeRjSz7
|
||||
srJnjF0H8oKmAY6hw+1Tm/n/b08p+RyL48TgVSE2vhUCgYA3BWpkD4PlCcn/FZsq
|
||||
OrLFyFXI6jIaxskFtsRW1IxxIlAdZmxfB26P/2gx6VjLdxJI/RRPkJyEN2dP7CbR
|
||||
BDjb565dy1O9D6+UrY70Iuwjz+OcALRBBGTaiF2pLn6IhSzNI2sy/tXX8q8dBlg9
|
||||
OFCrqT/emes3KytTPfa5NZtYeQ==
|
||||
-----END PRIVATE KEY-----`))
|
||||
assert.EqualError(t, err, "x509: failed to parse private key (use ParsePKCS8PrivateKey instead for this key format)")
|
||||
})
|
||||
|
||||
tt.Run("Ok", func(t *testing.T) {
|
||||
key, err := crypto.BytesToPrivateKey([]byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu
|
||||
KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm
|
||||
o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k
|
||||
TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7
|
||||
9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy
|
||||
v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs
|
||||
/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00
|
||||
-----END RSA PRIVATE KEY-----`))
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, key)
|
||||
})
|
||||
}
|
|
@ -3,6 +3,7 @@ package oidc
|
|||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -76,8 +77,23 @@ func (l *Locale) MarshalJSON() ([]byte, error) {
|
|||
return json.Marshal(tag)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
// When [language.ValueError] is encountered, the containing tag will be set
|
||||
// to an empty value (language "und") and no error will be returned.
|
||||
// This state can be checked with the `l.Tag().IsRoot()` method.
|
||||
func (l *Locale) UnmarshalJSON(data []byte) error {
|
||||
return json.Unmarshal(data, &l.tag)
|
||||
err := json.Unmarshal(data, &l.tag)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// catch "well-formed but unknown" errors
|
||||
var target language.ValueError
|
||||
if errors.As(err, &target) {
|
||||
l.tag = language.Tag{}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type Locales []language.Tag
|
||||
|
|
|
@ -208,20 +208,46 @@ func TestLocale_MarshalJSON(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLocale_UnmarshalJSON(t *testing.T) {
|
||||
type a struct {
|
||||
type dst struct {
|
||||
Locale *Locale `json:"locale,omitempty"`
|
||||
}
|
||||
want := a{
|
||||
Locale: NewLocale(language.Afrikaans),
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want dst
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "afrikaans, ok",
|
||||
input: `{"locale": "af"}`,
|
||||
want: dst{
|
||||
Locale: NewLocale(language.Afrikaans),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gb, ignored",
|
||||
input: `{"locale": "gb"}`,
|
||||
want: dst{
|
||||
Locale: &Locale{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad form, error",
|
||||
input: `{"locale": "g!!!!!"}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
const input = `{"locale": "af"}`
|
||||
var got a
|
||||
|
||||
require.NoError(t,
|
||||
json.Unmarshal([]byte(input), &got),
|
||||
)
|
||||
assert.Equal(t, want, got)
|
||||
for _, tt := range tests {
|
||||
var got dst
|
||||
err := json.Unmarshal([]byte(tt.input), &got)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLocales(t *testing.T) {
|
||||
|
|
|
@ -57,7 +57,7 @@ var (
|
|||
ErrNonceInvalid = errors.New("nonce does not match")
|
||||
ErrAcrInvalid = errors.New("acr is invalid")
|
||||
ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing")
|
||||
ErrAuthTimeToOld = errors.New("auth time of token is to old")
|
||||
ErrAuthTimeToOld = errors.New("auth time of token is too old")
|
||||
ErrAtHash = errors.New("at_hash does not correspond to access token")
|
||||
)
|
||||
|
||||
|
|
|
@ -7,10 +7,10 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bmatcuk/doublestar/v4"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
str "github.com/zitadel/oidc/v3/pkg/strings"
|
||||
|
@ -138,20 +138,20 @@ func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage
|
|||
}
|
||||
|
||||
if requestObject.ClientID != "" && requestObject.ClientID != authReq.ClientID {
|
||||
return oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest().WithDescription("missing or wrong client id in request")
|
||||
}
|
||||
if requestObject.ResponseType != "" && requestObject.ResponseType != authReq.ResponseType {
|
||||
return oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest().WithDescription("missing or wrong response type in request")
|
||||
}
|
||||
if requestObject.Issuer != requestObject.ClientID {
|
||||
return oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest().WithDescription("missing or wrong issuer in request")
|
||||
}
|
||||
if !str.Contains(requestObject.Audience, issuer) {
|
||||
return oidc.ErrInvalidRequest()
|
||||
return oidc.ErrInvalidRequest().WithDescription("issuer missing in audience")
|
||||
}
|
||||
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
|
||||
if err = oidc.CheckSignature(ctx, authReq.RequestParam, payload, requestObject, nil, keySet); err != nil {
|
||||
return err
|
||||
return oidc.ErrInvalidRequest().WithParent(err).WithDescription(err.Error())
|
||||
}
|
||||
CopyRequestObjectToAuthRequest(authReq, requestObject)
|
||||
return nil
|
||||
|
@ -283,7 +283,7 @@ func checkURIAgainstRedirects(client Client, uri string) error {
|
|||
}
|
||||
if globClient, ok := client.(HasRedirectGlobs); ok {
|
||||
for _, uriGlob := range globClient.RedirectURIGlobs() {
|
||||
isMatch, err := path.Match(uriGlob, uri)
|
||||
isMatch, err := doublestar.Match(uriGlob, uri)
|
||||
if err != nil {
|
||||
return oidc.ErrServerError().WithParent(err)
|
||||
}
|
||||
|
@ -391,9 +391,9 @@ func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifie
|
|||
return "", nil
|
||||
}
|
||||
claims, err := VerifyIDTokenHint[*oidc.TokenClaims](ctx, idTokenHint, verifier)
|
||||
if err != nil {
|
||||
if err != nil && !errors.As(err, &IDTokenHintExpiredError{}) {
|
||||
return "", oidc.ErrLoginRequired().WithDescription("The id_token_hint is invalid. " +
|
||||
"If you have any questions, you may contact the administrator of the application.")
|
||||
"If you have any questions, you may contact the administrator of the application.").WithParent(err)
|
||||
}
|
||||
return claims.GetSubject(), nil
|
||||
}
|
||||
|
|
|
@ -583,6 +583,60 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
|
|||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"code flow dev mode has redirect globs regular ok",
|
||||
args{
|
||||
"http://registered.com/callback",
|
||||
mock.NewHasRedirectGlobsWithConfig(t, []string{"http://registered.com/*"}, op.ApplicationTypeUserAgent, nil, true),
|
||||
oidc.ResponseTypeCode,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"code flow dev mode has redirect globs wildcard ok",
|
||||
args{
|
||||
"http://registered.com/callback",
|
||||
mock.NewHasRedirectGlobsWithConfig(t, []string{"http://registered.com/*"}, op.ApplicationTypeUserAgent, nil, true),
|
||||
oidc.ResponseTypeCode,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"code flow dev mode has redirect globs double star ok",
|
||||
args{
|
||||
"http://registered.com/callback",
|
||||
mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/*"}, op.ApplicationTypeUserAgent, nil, true),
|
||||
oidc.ResponseTypeCode,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"code flow dev mode has redirect globs double star ok",
|
||||
args{
|
||||
"http://registered.com/callback",
|
||||
mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/*"}, op.ApplicationTypeUserAgent, nil, true),
|
||||
oidc.ResponseTypeCode,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"code flow dev mode has redirect globs IPv6 ok",
|
||||
args{
|
||||
"http://[::1]:80/callback",
|
||||
mock.NewHasRedirectGlobsWithConfig(t, []string{"http://\\[::1\\]:80/*"}, op.ApplicationTypeUserAgent, nil, true),
|
||||
oidc.ResponseTypeCode,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"code flow dev mode has redirect globs bad pattern",
|
||||
args{
|
||||
"http://registered.com/callback",
|
||||
mock.NewHasRedirectGlobsWithConfig(t, []string{"http://**/\\"}, op.ApplicationTypeUserAgent, nil, true),
|
||||
oidc.ResponseTypeCode,
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -63,6 +63,7 @@ type Client interface {
|
|||
// such as DevMode for the client being enabled.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
type HasRedirectGlobs interface {
|
||||
Client
|
||||
RedirectURIGlobs() []string
|
||||
PostLogoutRedirectURIGlobs() []string
|
||||
}
|
||||
|
|
|
@ -54,7 +54,24 @@ type Configuration interface {
|
|||
type IssuerFromRequest func(r *http.Request) string
|
||||
|
||||
func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) {
|
||||
return issuerFromForwardedOrHost(path, false)
|
||||
return issuerFromForwardedOrHost(path, new(issuerConfig))
|
||||
}
|
||||
|
||||
type IssuerFromOption func(c *issuerConfig)
|
||||
|
||||
// WithIssuerFromCustomHeaders can be used to customize the header names used.
|
||||
// The same rules apply where the first successful host is returned.
|
||||
func WithIssuerFromCustomHeaders(headers ...string) IssuerFromOption {
|
||||
return func(c *issuerConfig) {
|
||||
for i, h := range headers {
|
||||
headers[i] = http.CanonicalHeaderKey(h)
|
||||
}
|
||||
c.headers = headers
|
||||
}
|
||||
}
|
||||
|
||||
type issuerConfig struct {
|
||||
headers []string
|
||||
}
|
||||
|
||||
// IssuerFromForwardedOrHost tries to establish the Issuer based
|
||||
|
@ -64,11 +81,18 @@ func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) {
|
|||
// If the Forwarded header is not present, no host field is found,
|
||||
// or there is a parser error the Request Host will be used as a fallback.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Forwarded
|
||||
func IssuerFromForwardedOrHost(path string) func(bool) (IssuerFromRequest, error) {
|
||||
return issuerFromForwardedOrHost(path, true)
|
||||
func IssuerFromForwardedOrHost(path string, opts ...IssuerFromOption) func(bool) (IssuerFromRequest, error) {
|
||||
c := &issuerConfig{
|
||||
headers: []string{http.CanonicalHeaderKey("forwarded")},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
}
|
||||
|
||||
return issuerFromForwardedOrHost(path, c)
|
||||
}
|
||||
|
||||
func issuerFromForwardedOrHost(path string, parseForwarded bool) func(bool) (IssuerFromRequest, error) {
|
||||
func issuerFromForwardedOrHost(path string, c *issuerConfig) func(bool) (IssuerFromRequest, error) {
|
||||
return func(allowInsecure bool) (IssuerFromRequest, error) {
|
||||
issuerPath, err := url.Parse(path)
|
||||
if err != nil {
|
||||
|
@ -78,26 +102,26 @@ func issuerFromForwardedOrHost(path string, parseForwarded bool) func(bool) (Iss
|
|||
return nil, err
|
||||
}
|
||||
return func(r *http.Request) string {
|
||||
if parseForwarded {
|
||||
if host, ok := hostFromForwarded(r); ok {
|
||||
return dynamicIssuer(host, path, allowInsecure)
|
||||
}
|
||||
if host, ok := hostFromForwarded(r, c.headers); ok {
|
||||
return dynamicIssuer(host, path, allowInsecure)
|
||||
}
|
||||
return dynamicIssuer(r.Host, path, allowInsecure)
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func hostFromForwarded(r *http.Request) (host string, ok bool) {
|
||||
fwd, err := httpforwarded.ParseFromRequest(r)
|
||||
if err != nil {
|
||||
log.Printf("Err: issuer from forwarded header: %v", err) // TODO change to slog on next branch
|
||||
return "", false
|
||||
func hostFromForwarded(r *http.Request, headers []string) (host string, ok bool) {
|
||||
for _, header := range headers {
|
||||
hosts, err := httpforwarded.ParseParameter("host", r.Header[header])
|
||||
if err != nil {
|
||||
log.Printf("Err: issuer from forwarded header: %v", err) // TODO change to slog on next branch
|
||||
continue
|
||||
}
|
||||
if len(hosts) > 0 {
|
||||
return hosts[0], true
|
||||
}
|
||||
}
|
||||
if fwd == nil || len(fwd["host"]) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return fwd["host"][0], true
|
||||
return "", false
|
||||
}
|
||||
|
||||
func StaticIssuer(issuer string) func(bool) (IssuerFromRequest, error) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
@ -264,9 +265,10 @@ func TestIssuerFromHost(t *testing.T) {
|
|||
|
||||
func TestIssuerFromForwardedOrHost(t *testing.T) {
|
||||
type args struct {
|
||||
path string
|
||||
target string
|
||||
forwarded []string
|
||||
path string
|
||||
opts []IssuerFromOption
|
||||
target string
|
||||
header map[string][]string
|
||||
}
|
||||
type res struct {
|
||||
issuer string
|
||||
|
@ -279,9 +281,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) {
|
|||
{
|
||||
"header parse error",
|
||||
args{
|
||||
path: "/custom/",
|
||||
target: "https://issuer.com",
|
||||
forwarded: []string{"~~~"},
|
||||
path: "/custom/",
|
||||
target: "https://issuer.com",
|
||||
header: map[string][]string{"Forwarded": {"~~~~"}},
|
||||
},
|
||||
res{
|
||||
issuer: "https://issuer.com/custom/",
|
||||
|
@ -303,9 +305,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) {
|
|||
args{
|
||||
path: "/custom/",
|
||||
target: "https://issuer.com",
|
||||
forwarded: []string{
|
||||
header: map[string][]string{"Forwarded": {
|
||||
`by=identifier;for=identifier;proto=https`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
res{
|
||||
issuer: "https://issuer.com/custom/",
|
||||
|
@ -316,9 +318,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) {
|
|||
args{
|
||||
path: "/custom/",
|
||||
target: "https://issuer.com",
|
||||
forwarded: []string{
|
||||
header: map[string][]string{"Forwarded": {
|
||||
`by=identifier;for=identifier;host=first.com;proto=https`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
res{
|
||||
issuer: "https://first.com/custom/",
|
||||
|
@ -329,9 +331,9 @@ func TestIssuerFromForwardedOrHost(t *testing.T) {
|
|||
args{
|
||||
path: "/custom/",
|
||||
target: "https://issuer.com",
|
||||
forwarded: []string{
|
||||
header: map[string][]string{"Forwarded": {
|
||||
`by=identifier;for=identifier;host=first.com;proto=https,host=second.com`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
res{
|
||||
issuer: "https://first.com/custom/",
|
||||
|
@ -342,23 +344,45 @@ func TestIssuerFromForwardedOrHost(t *testing.T) {
|
|||
args{
|
||||
path: "/custom/",
|
||||
target: "https://issuer.com",
|
||||
forwarded: []string{
|
||||
header: map[string][]string{"Forwarded": {
|
||||
`by=identifier;for=identifier;host=first.com;proto=https,host=second.com`,
|
||||
`by=identifier;for=identifier;host=third.com;proto=https`,
|
||||
},
|
||||
}},
|
||||
},
|
||||
res{
|
||||
issuer: "https://first.com/custom/",
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom header first",
|
||||
args{
|
||||
path: "/custom/",
|
||||
target: "https://issuer.com",
|
||||
header: map[string][]string{
|
||||
"Forwarded": {
|
||||
`by=identifier;for=identifier;host=first.com;proto=https,host=second.com`,
|
||||
`by=identifier;for=identifier;host=third.com;proto=https`,
|
||||
},
|
||||
"X-Custom-Forwarded": {
|
||||
`by=identifier;for=identifier;host=custom.com;proto=https,host=custom2.com`,
|
||||
},
|
||||
},
|
||||
opts: []IssuerFromOption{
|
||||
WithIssuerFromCustomHeaders("x-custom-forwarded"),
|
||||
},
|
||||
},
|
||||
res{
|
||||
issuer: "https://custom.com/custom/",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
issuer, err := IssuerFromForwardedOrHost(tt.args.path)(false)
|
||||
issuer, err := IssuerFromForwardedOrHost(tt.args.path, tt.args.opts...)(false)
|
||||
require.NoError(t, err)
|
||||
req := httptest.NewRequest("", tt.args.target, nil)
|
||||
if tt.args.forwarded != nil {
|
||||
req.Header["Forwarded"] = tt.args.forwarded
|
||||
for k, v := range tt.args.header {
|
||||
req.Header[http.CanonicalHeaderKey(k)] = v
|
||||
}
|
||||
assert.Equal(t, tt.res.issuer, issuer(req))
|
||||
})
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
strs "github.com/zitadel/oidc/v3/pkg/strings"
|
||||
)
|
||||
|
||||
type DeviceAuthorizationConfig struct {
|
||||
|
@ -185,24 +186,6 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
|
|||
return buf.String(), nil
|
||||
}
|
||||
|
||||
type deviceAccessTokenRequest struct {
|
||||
subject string
|
||||
audience []string
|
||||
scopes []string
|
||||
}
|
||||
|
||||
func (r *deviceAccessTokenRequest) GetSubject() string {
|
||||
return r.subject
|
||||
}
|
||||
|
||||
func (r *deviceAccessTokenRequest) GetAudience() []string {
|
||||
return r.audience
|
||||
}
|
||||
|
||||
func (r *deviceAccessTokenRequest) GetScopes() []string {
|
||||
return r.scopes
|
||||
}
|
||||
|
||||
func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
ctx, span := tracer.Start(r.Context(), "DeviceAccessToken")
|
||||
defer span.End()
|
||||
|
@ -229,7 +212,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
|
||||
tokenRequest, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -243,11 +226,6 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
|||
WithDescription("confidential client requires authentication")
|
||||
}
|
||||
|
||||
tokenRequest := &deviceAccessTokenRequest{
|
||||
subject: state.Subject,
|
||||
audience: []string{clientID},
|
||||
scopes: state.Scopes,
|
||||
}
|
||||
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -265,6 +243,50 @@ func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.
|
|||
return req, nil
|
||||
}
|
||||
|
||||
// DeviceAuthorizationState describes the current state of
|
||||
// the device authorization flow.
|
||||
// It implements the [IDTokenRequest] interface.
|
||||
type DeviceAuthorizationState struct {
|
||||
ClientID string
|
||||
Audience []string
|
||||
Scopes []string
|
||||
Expires time.Time // The time after we consider the authorization request timed-out
|
||||
Done bool // The user authenticated and approved the authorization request
|
||||
Denied bool // The user authenticated and denied the authorization request
|
||||
|
||||
// The following fields are populated after Done == true
|
||||
Subject string
|
||||
AMR []string
|
||||
AuthTime time.Time
|
||||
}
|
||||
|
||||
func (r *DeviceAuthorizationState) GetAMR() []string {
|
||||
return r.AMR
|
||||
}
|
||||
|
||||
func (r *DeviceAuthorizationState) GetAudience() []string {
|
||||
if !strs.Contains(r.Audience, r.ClientID) {
|
||||
r.Audience = append(r.Audience, r.ClientID)
|
||||
}
|
||||
return r.Audience
|
||||
}
|
||||
|
||||
func (r *DeviceAuthorizationState) GetAuthTime() time.Time {
|
||||
return r.AuthTime
|
||||
}
|
||||
|
||||
func (r *DeviceAuthorizationState) GetClientID() string {
|
||||
return r.ClientID
|
||||
}
|
||||
|
||||
func (r *DeviceAuthorizationState) GetScopes() []string {
|
||||
return r.Scopes
|
||||
}
|
||||
|
||||
func (r *DeviceAuthorizationState) GetSubject() string {
|
||||
return r.Subject
|
||||
}
|
||||
|
||||
func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
|
||||
storage, err := assertDeviceStorage(exchanger.Storage())
|
||||
if err != nil {
|
||||
|
@ -291,15 +313,32 @@ func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode str
|
|||
}
|
||||
|
||||
func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) {
|
||||
/* TODO(v4):
|
||||
Change the TokenRequest argument type to *DeviceAuthorizationState.
|
||||
Breaking change that can not be done for v3.
|
||||
*/
|
||||
ctx, span := tracer.Start(ctx, "CreateDeviceTokenResponse")
|
||||
defer span.End()
|
||||
|
||||
accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &oidc.AccessTokenResponse{
|
||||
response := &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: uint64(validity.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO(v4): remove type assertion
|
||||
if idTokenRequest, ok := tokenRequest.(IDTokenRequest); ok && strs.Contains(tokenRequest.GetScopes(), oidc.ScopeOpenID) {
|
||||
response.IDToken, err = CreateIDToken(ctx, IssuerFromContext(ctx), idTokenRequest, client.IDTokenLifetime(), accessToken, "", creator.Storage(), client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
|
|
@ -453,3 +453,96 @@ func TestCheckDeviceAuthorizationState(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDeviceTokenResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenRequest op.TokenRequest
|
||||
wantAccessToken bool
|
||||
wantRefreshToken bool
|
||||
wantIDToken bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "access token",
|
||||
tokenRequest: &op.DeviceAuthorizationState{
|
||||
ClientID: "client1",
|
||||
Subject: "id1",
|
||||
AMR: []string{"password"},
|
||||
AuthTime: time.Now(),
|
||||
},
|
||||
wantAccessToken: true,
|
||||
},
|
||||
{
|
||||
name: "access and refresh tokens",
|
||||
tokenRequest: &op.DeviceAuthorizationState{
|
||||
ClientID: "client1",
|
||||
Subject: "id1",
|
||||
AMR: []string{"password"},
|
||||
AuthTime: time.Now(),
|
||||
Scopes: []string{oidc.ScopeOfflineAccess},
|
||||
},
|
||||
wantAccessToken: true,
|
||||
wantRefreshToken: true,
|
||||
},
|
||||
{
|
||||
name: "access and id token",
|
||||
tokenRequest: &op.DeviceAuthorizationState{
|
||||
ClientID: "client1",
|
||||
Subject: "id1",
|
||||
AMR: []string{"password"},
|
||||
AuthTime: time.Now(),
|
||||
Scopes: []string{oidc.ScopeOpenID},
|
||||
},
|
||||
wantAccessToken: true,
|
||||
wantIDToken: true,
|
||||
},
|
||||
{
|
||||
name: "access, refresh and id token",
|
||||
tokenRequest: &op.DeviceAuthorizationState{
|
||||
ClientID: "client1",
|
||||
Subject: "id1",
|
||||
AMR: []string{"password"},
|
||||
AuthTime: time.Now(),
|
||||
Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID},
|
||||
},
|
||||
wantAccessToken: true,
|
||||
wantRefreshToken: true,
|
||||
wantIDToken: true,
|
||||
},
|
||||
{
|
||||
name: "id token creation error",
|
||||
tokenRequest: &op.DeviceAuthorizationState{
|
||||
ClientID: "client1",
|
||||
Subject: "foobar",
|
||||
AMR: []string{"password"},
|
||||
AuthTime: time.Now(),
|
||||
Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, err := testProvider.Storage().GetClientByClientID(context.Background(), "native")
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := op.CreateDeviceTokenResponse(context.Background(), tt.tokenRequest, testProvider, client)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.InDelta(t, 300, got.ExpiresIn, 2)
|
||||
if tt.wantAccessToken {
|
||||
assert.NotEmpty(t, got.AccessToken, "access token")
|
||||
}
|
||||
if tt.wantRefreshToken {
|
||||
assert.NotEmpty(t, got.RefreshToken, "refresh token")
|
||||
}
|
||||
if tt.wantIDToken {
|
||||
assert.NotEmpty(t, got.IDToken, "id token")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -213,32 +213,12 @@ func AuthMethodsRevocationEndpoint(c Configuration) []oidc.AuthMethod {
|
|||
}
|
||||
|
||||
func SupportedClaims(c Configuration) []string {
|
||||
return []string{ // TODO: config
|
||||
"sub",
|
||||
"aud",
|
||||
"exp",
|
||||
"iat",
|
||||
"iss",
|
||||
"auth_time",
|
||||
"nonce",
|
||||
"acr",
|
||||
"amr",
|
||||
"c_hash",
|
||||
"at_hash",
|
||||
"act",
|
||||
"scopes",
|
||||
"client_id",
|
||||
"azp",
|
||||
"preferred_username",
|
||||
"name",
|
||||
"family_name",
|
||||
"given_name",
|
||||
"locale",
|
||||
"email",
|
||||
"email_verified",
|
||||
"phone_number",
|
||||
"phone_number_verified",
|
||||
provider, ok := c.(*Provider)
|
||||
if ok && provider.config.SupportedClaims != nil {
|
||||
return provider.config.SupportedClaims
|
||||
}
|
||||
|
||||
return DefaultSupportedClaims
|
||||
}
|
||||
|
||||
func CodeChallengeMethods(c Configuration) []oidc.CodeChallengeMethod {
|
||||
|
|
|
@ -157,13 +157,29 @@ func (e StatusError) Is(err error) bool {
|
|||
e.statusCode == target.statusCode
|
||||
}
|
||||
|
||||
// WriteError asserts for a StatusError containing an [oidc.Error].
|
||||
// If no StatusError is found, the status code will default to [http.StatusBadRequest].
|
||||
// If no [oidc.Error] was found in the parent, the error type defaults to [oidc.ServerError].
|
||||
// WriteError asserts for a [StatusError] containing an [oidc.Error].
|
||||
// If no `StatusError` is found, the status code will default to [http.StatusBadRequest].
|
||||
// If no `oidc.Error` was found in the parent, the error type defaults to [oidc.ServerError].
|
||||
// When there was no `StatusError` and the `oidc.Error` is of type `oidc.ServerError`,
|
||||
// the status code will be set to [http.StatusInternalServerError]
|
||||
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
|
||||
statusError := AsStatusError(err, http.StatusBadRequest)
|
||||
e := oidc.DefaultToServerError(statusError.parent, statusError.parent.Error())
|
||||
|
||||
logger.Log(r.Context(), e.LogLevel(), "request error", "oidc_error", e)
|
||||
httphelper.MarshalJSONWithStatus(w, e, statusError.statusCode)
|
||||
var statusError StatusError
|
||||
if errors.As(err, &statusError) {
|
||||
writeError(w, r,
|
||||
oidc.DefaultToServerError(statusError.parent, statusError.parent.Error()),
|
||||
statusError.statusCode, logger,
|
||||
)
|
||||
return
|
||||
}
|
||||
statusCode := http.StatusBadRequest
|
||||
e := oidc.DefaultToServerError(err, err.Error())
|
||||
if e.ErrorType == oidc.ServerError {
|
||||
statusCode = http.StatusInternalServerError
|
||||
}
|
||||
writeError(w, r, e, statusCode, logger)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, r *http.Request, err *oidc.Error, statusCode int, logger *slog.Logger) {
|
||||
logger.Log(r.Context(), err.LogLevel(), "request error", "oidc_error", err, "status_code", statusCode)
|
||||
httphelper.MarshalJSONWithStatus(w, err, statusCode)
|
||||
}
|
||||
|
|
|
@ -579,7 +579,7 @@ func TestWriteError(t *testing.T) {
|
|||
{
|
||||
name: "not a status or oidc error",
|
||||
err: io.ErrClosedPipe,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantBody: `{
|
||||
"error":"server_error",
|
||||
"error_description":"io: read/write on closed pipe"
|
||||
|
@ -592,6 +592,7 @@ func TestWriteError(t *testing.T) {
|
|||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
},
|
||||
"status_code":500,
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
|
@ -611,6 +612,7 @@ func TestWriteError(t *testing.T) {
|
|||
"parent":"io: read/write on closed pipe",
|
||||
"type":"server_error"
|
||||
},
|
||||
"status_code":500,
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
|
@ -629,6 +631,7 @@ func TestWriteError(t *testing.T) {
|
|||
"description":"oops",
|
||||
"type":"invalid_request"
|
||||
},
|
||||
"status_code":400,
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
|
@ -650,6 +653,7 @@ func TestWriteError(t *testing.T) {
|
|||
"description":"oops",
|
||||
"type":"unauthorized_client"
|
||||
},
|
||||
"status_code":401,
|
||||
"time":"not"
|
||||
}`,
|
||||
},
|
||||
|
|
|
@ -4,6 +4,7 @@ package mock
|
|||
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/zitadel/oidc/v3/pkg/op Storage
|
||||
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/zitadel/oidc/v3/pkg/op Authorizer
|
||||
//go:generate mockgen -package mock -destination ./client.mock.go github.com/zitadel/oidc/v3/pkg/op Client
|
||||
//go:generate mockgen -package mock -destination ./glob.mock.go github.com/zitadel/oidc/v3/pkg/op HasRedirectGlobs
|
||||
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/zitadel/oidc/v3/pkg/op Configuration
|
||||
//go:generate mockgen -package mock -destination ./discovery.mock.go github.com/zitadel/oidc/v3/pkg/op DiscoverStorage
|
||||
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/zitadel/oidc/v3/pkg/op SigningKey,Key
|
||||
|
|
24
pkg/op/mock/glob.go
Normal file
24
pkg/op/mock/glob.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
func NewHasRedirectGlobs(t *testing.T) op.HasRedirectGlobs {
|
||||
return NewMockHasRedirectGlobs(gomock.NewController(t))
|
||||
}
|
||||
|
||||
func NewHasRedirectGlobsWithConfig(t *testing.T, uri []string, appType op.ApplicationType, responseTypes []oidc.ResponseType, devMode bool) op.HasRedirectGlobs {
|
||||
c := NewHasRedirectGlobs(t)
|
||||
m := c.(*MockHasRedirectGlobs)
|
||||
m.EXPECT().RedirectURIs().AnyTimes().Return(uri)
|
||||
m.EXPECT().RedirectURIGlobs().AnyTimes().Return(uri)
|
||||
m.EXPECT().ApplicationType().AnyTimes().Return(appType)
|
||||
m.EXPECT().ResponseTypes().AnyTimes().Return(responseTypes)
|
||||
m.EXPECT().DevMode().AnyTimes().Return(devMode)
|
||||
return c
|
||||
}
|
289
pkg/op/mock/glob.mock.go
Normal file
289
pkg/op/mock/glob.mock.go
Normal file
|
@ -0,0 +1,289 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/oidc/v3/pkg/op (interfaces: HasRedirectGlobs)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
oidc "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
op "github.com/zitadel/oidc/v3/pkg/op"
|
||||
)
|
||||
|
||||
// MockHasRedirectGlobs is a mock of HasRedirectGlobs interface.
|
||||
type MockHasRedirectGlobs struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockHasRedirectGlobsMockRecorder
|
||||
}
|
||||
|
||||
// MockHasRedirectGlobsMockRecorder is the mock recorder for MockHasRedirectGlobs.
|
||||
type MockHasRedirectGlobsMockRecorder struct {
|
||||
mock *MockHasRedirectGlobs
|
||||
}
|
||||
|
||||
// NewMockHasRedirectGlobs creates a new mock instance.
|
||||
func NewMockHasRedirectGlobs(ctrl *gomock.Controller) *MockHasRedirectGlobs {
|
||||
mock := &MockHasRedirectGlobs{ctrl: ctrl}
|
||||
mock.recorder = &MockHasRedirectGlobsMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockHasRedirectGlobs) EXPECT() *MockHasRedirectGlobsMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AccessTokenType mocks base method.
|
||||
func (m *MockHasRedirectGlobs) AccessTokenType() op.AccessTokenType {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AccessTokenType")
|
||||
ret0, _ := ret[0].(op.AccessTokenType)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AccessTokenType indicates an expected call of AccessTokenType.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) AccessTokenType() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockHasRedirectGlobs)(nil).AccessTokenType))
|
||||
}
|
||||
|
||||
// ApplicationType mocks base method.
|
||||
func (m *MockHasRedirectGlobs) ApplicationType() op.ApplicationType {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ApplicationType")
|
||||
ret0, _ := ret[0].(op.ApplicationType)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ApplicationType indicates an expected call of ApplicationType.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) ApplicationType() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ApplicationType))
|
||||
}
|
||||
|
||||
// AuthMethod mocks base method.
|
||||
func (m *MockHasRedirectGlobs) AuthMethod() oidc.AuthMethod {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthMethod")
|
||||
ret0, _ := ret[0].(oidc.AuthMethod)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AuthMethod indicates an expected call of AuthMethod.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) AuthMethod() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethod", reflect.TypeOf((*MockHasRedirectGlobs)(nil).AuthMethod))
|
||||
}
|
||||
|
||||
// ClockSkew mocks base method.
|
||||
func (m *MockHasRedirectGlobs) ClockSkew() time.Duration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClockSkew")
|
||||
ret0, _ := ret[0].(time.Duration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClockSkew indicates an expected call of ClockSkew.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) ClockSkew() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClockSkew", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ClockSkew))
|
||||
}
|
||||
|
||||
// DevMode mocks base method.
|
||||
func (m *MockHasRedirectGlobs) DevMode() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DevMode")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DevMode indicates an expected call of DevMode.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) DevMode() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DevMode", reflect.TypeOf((*MockHasRedirectGlobs)(nil).DevMode))
|
||||
}
|
||||
|
||||
// GetID mocks base method.
|
||||
func (m *MockHasRedirectGlobs) GetID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetID indicates an expected call of GetID.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) GetID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockHasRedirectGlobs)(nil).GetID))
|
||||
}
|
||||
|
||||
// GrantTypes mocks base method.
|
||||
func (m *MockHasRedirectGlobs) GrantTypes() []oidc.GrantType {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GrantTypes")
|
||||
ret0, _ := ret[0].([]oidc.GrantType)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GrantTypes indicates an expected call of GrantTypes.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) GrantTypes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GrantTypes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).GrantTypes))
|
||||
}
|
||||
|
||||
// IDTokenLifetime mocks base method.
|
||||
func (m *MockHasRedirectGlobs) IDTokenLifetime() time.Duration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IDTokenLifetime")
|
||||
ret0, _ := ret[0].(time.Duration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IDTokenLifetime indicates an expected call of IDTokenLifetime.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) IDTokenLifetime() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IDTokenLifetime))
|
||||
}
|
||||
|
||||
// IDTokenUserinfoClaimsAssertion mocks base method.
|
||||
func (m *MockHasRedirectGlobs) IDTokenUserinfoClaimsAssertion() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IDTokenUserinfoClaimsAssertion")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IDTokenUserinfoClaimsAssertion indicates an expected call of IDTokenUserinfoClaimsAssertion.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) IDTokenUserinfoClaimsAssertion() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenUserinfoClaimsAssertion", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IDTokenUserinfoClaimsAssertion))
|
||||
}
|
||||
|
||||
// IsScopeAllowed mocks base method.
|
||||
func (m *MockHasRedirectGlobs) IsScopeAllowed(arg0 string) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsScopeAllowed", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IsScopeAllowed indicates an expected call of IsScopeAllowed.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) IsScopeAllowed(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsScopeAllowed", reflect.TypeOf((*MockHasRedirectGlobs)(nil).IsScopeAllowed), arg0)
|
||||
}
|
||||
|
||||
// LoginURL mocks base method.
|
||||
func (m *MockHasRedirectGlobs) LoginURL(arg0 string) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoginURL", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LoginURL indicates an expected call of LoginURL.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) LoginURL(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockHasRedirectGlobs)(nil).LoginURL), arg0)
|
||||
}
|
||||
|
||||
// PostLogoutRedirectURIGlobs mocks base method.
|
||||
func (m *MockHasRedirectGlobs) PostLogoutRedirectURIGlobs() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PostLogoutRedirectURIGlobs")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PostLogoutRedirectURIGlobs indicates an expected call of PostLogoutRedirectURIGlobs.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) PostLogoutRedirectURIGlobs() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIGlobs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).PostLogoutRedirectURIGlobs))
|
||||
}
|
||||
|
||||
// PostLogoutRedirectURIs mocks base method.
|
||||
func (m *MockHasRedirectGlobs) PostLogoutRedirectURIs() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PostLogoutRedirectURIs")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PostLogoutRedirectURIs indicates an expected call of PostLogoutRedirectURIs.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) PostLogoutRedirectURIs() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostLogoutRedirectURIs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).PostLogoutRedirectURIs))
|
||||
}
|
||||
|
||||
// RedirectURIGlobs mocks base method.
|
||||
func (m *MockHasRedirectGlobs) RedirectURIGlobs() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RedirectURIGlobs")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RedirectURIGlobs indicates an expected call of RedirectURIGlobs.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) RedirectURIGlobs() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIGlobs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RedirectURIGlobs))
|
||||
}
|
||||
|
||||
// RedirectURIs mocks base method.
|
||||
func (m *MockHasRedirectGlobs) RedirectURIs() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RedirectURIs")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RedirectURIs indicates an expected call of RedirectURIs.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) RedirectURIs() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIs", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RedirectURIs))
|
||||
}
|
||||
|
||||
// ResponseTypes mocks base method.
|
||||
func (m *MockHasRedirectGlobs) ResponseTypes() []oidc.ResponseType {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ResponseTypes")
|
||||
ret0, _ := ret[0].([]oidc.ResponseType)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ResponseTypes indicates an expected call of ResponseTypes.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) ResponseTypes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseTypes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).ResponseTypes))
|
||||
}
|
||||
|
||||
// RestrictAdditionalAccessTokenScopes mocks base method.
|
||||
func (m *MockHasRedirectGlobs) RestrictAdditionalAccessTokenScopes() func([]string) []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RestrictAdditionalAccessTokenScopes")
|
||||
ret0, _ := ret[0].(func([]string) []string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RestrictAdditionalAccessTokenScopes indicates an expected call of RestrictAdditionalAccessTokenScopes.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) RestrictAdditionalAccessTokenScopes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalAccessTokenScopes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RestrictAdditionalAccessTokenScopes))
|
||||
}
|
||||
|
||||
// RestrictAdditionalIdTokenScopes mocks base method.
|
||||
func (m *MockHasRedirectGlobs) RestrictAdditionalIdTokenScopes() func([]string) []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RestrictAdditionalIdTokenScopes")
|
||||
ret0, _ := ret[0].(func([]string) []string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RestrictAdditionalIdTokenScopes indicates an expected call of RestrictAdditionalIdTokenScopes.
|
||||
func (mr *MockHasRedirectGlobsMockRecorder) RestrictAdditionalIdTokenScopes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestrictAdditionalIdTokenScopes", reflect.TypeOf((*MockHasRedirectGlobs)(nil).RestrictAdditionalIdTokenScopes))
|
||||
}
|
150
pkg/op/op.go
150
pkg/op/op.go
|
@ -45,6 +45,33 @@ var (
|
|||
DeviceAuthorization: NewEndpoint(defaultDeviceAuthzEndpoint),
|
||||
}
|
||||
|
||||
DefaultSupportedClaims = []string{
|
||||
"sub",
|
||||
"aud",
|
||||
"exp",
|
||||
"iat",
|
||||
"iss",
|
||||
"auth_time",
|
||||
"nonce",
|
||||
"acr",
|
||||
"amr",
|
||||
"c_hash",
|
||||
"at_hash",
|
||||
"act",
|
||||
"scopes",
|
||||
"client_id",
|
||||
"azp",
|
||||
"preferred_username",
|
||||
"name",
|
||||
"family_name",
|
||||
"given_name",
|
||||
"locale",
|
||||
"email",
|
||||
"email_verified",
|
||||
"phone_number",
|
||||
"phone_number_verified",
|
||||
}
|
||||
|
||||
defaultCORSOptions = cors.Options{
|
||||
AllowCredentials: true,
|
||||
AllowedHeaders: []string{
|
||||
|
@ -97,9 +124,19 @@ type OpenIDProvider interface {
|
|||
|
||||
type HttpInterceptor func(http.Handler) http.Handler
|
||||
|
||||
type corsOptioner interface {
|
||||
CORSOptions() *cors.Options
|
||||
}
|
||||
|
||||
func CreateRouter(o OpenIDProvider, interceptors ...HttpInterceptor) chi.Router {
|
||||
router := chi.NewRouter()
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
if co, ok := o.(corsOptioner); ok {
|
||||
if opts := co.CORSOptions(); opts != nil {
|
||||
router.Use(cors.New(*opts).Handler)
|
||||
}
|
||||
} else {
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
}
|
||||
router.Use(intercept(o.IssuerFromRequest, interceptors...))
|
||||
router.HandleFunc(healthEndpoint, healthHandler)
|
||||
router.HandleFunc(readinessEndpoint, readyHandler(o.Probes()))
|
||||
|
@ -136,6 +173,7 @@ type Config struct {
|
|||
GrantTypeRefreshToken bool
|
||||
RequestObjectSupported bool
|
||||
SupportedUILocales []language.Tag
|
||||
SupportedClaims []string
|
||||
DeviceAuthorization DeviceAuthorizationConfig
|
||||
}
|
||||
|
||||
|
@ -173,28 +211,62 @@ type Endpoints struct {
|
|||
// Successful logins should mark the request as authorized and redirect back to to
|
||||
// op.AuthCallbackURL(provider) which is probably /callback. On the redirect back
|
||||
// to the AuthCallbackURL, the request id should be passed as the "id" parameter.
|
||||
//
|
||||
// Deprecated: use [NewProvider] with an issuer function direct.
|
||||
func NewOpenIDProvider(issuer string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
|
||||
return newProvider(config, storage, StaticIssuer(issuer), opOpts...)
|
||||
return NewProvider(config, storage, StaticIssuer(issuer), opOpts...)
|
||||
}
|
||||
|
||||
// NewForwardedOpenIDProvider tries to establishes the issuer from the request Host.
|
||||
//
|
||||
// Deprecated: use [NewProvider] with an issuer function direct.
|
||||
func NewDynamicOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
|
||||
return newProvider(config, storage, IssuerFromHost(path), opOpts...)
|
||||
return NewProvider(config, storage, IssuerFromHost(path), opOpts...)
|
||||
}
|
||||
|
||||
// NewForwardedOpenIDProvider tries to establish the Issuer from a Forwarded request header, if it is set.
|
||||
// See [IssuerFromForwardedOrHost] for details.
|
||||
//
|
||||
// Deprecated: use [NewProvider] with an issuer function direct.
|
||||
func NewForwardedOpenIDProvider(path string, config *Config, storage Storage, opOpts ...Option) (*Provider, error) {
|
||||
return newProvider(config, storage, IssuerFromForwardedOrHost(path), opOpts...)
|
||||
return NewProvider(config, storage, IssuerFromForwardedOrHost(path), opOpts...)
|
||||
}
|
||||
|
||||
func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) {
|
||||
// NewProvider creates a provider with a router on it's embedded http.Handler.
|
||||
// Issuer is a function that must return the issuer on every request.
|
||||
// Typically [StaticIssuer], [IssuerFromHost] or [IssuerFromForwardedOrHost] can be used.
|
||||
//
|
||||
// The router handles a suite of endpoints (some paths can be overridden):
|
||||
//
|
||||
// /healthz
|
||||
// /ready
|
||||
// /.well-known/openid-configuration
|
||||
// /oauth/token
|
||||
// /oauth/introspect
|
||||
// /callback
|
||||
// /authorize
|
||||
// /userinfo
|
||||
// /revoke
|
||||
// /end_session
|
||||
// /keys
|
||||
// /device_authorization
|
||||
//
|
||||
// This does not include login. Login is handled with a redirect that includes the
|
||||
// request ID. The redirect for logins is specified per-client by Client.LoginURL().
|
||||
// Successful logins should mark the request as authorized and redirect back to to
|
||||
// op.AuthCallbackURL(provider) which is probably /callback. On the redirect back
|
||||
// to the AuthCallbackURL, the request id should be passed as the "id" parameter.
|
||||
func NewProvider(config *Config, storage Storage, issuer func(insecure bool) (IssuerFromRequest, error), opOpts ...Option) (_ *Provider, err error) {
|
||||
keySet := &OpenIDKeySet{storage}
|
||||
o := &Provider{
|
||||
config: config,
|
||||
storage: storage,
|
||||
endpoints: DefaultEndpoints,
|
||||
timer: make(<-chan time.Time),
|
||||
logger: slog.Default(),
|
||||
config: config,
|
||||
storage: storage,
|
||||
accessTokenKeySet: keySet,
|
||||
idTokenHinKeySet: keySet,
|
||||
endpoints: DefaultEndpoints,
|
||||
timer: make(<-chan time.Time),
|
||||
corsOpts: &defaultCORSOptions,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
for _, optFunc := range opOpts {
|
||||
|
@ -207,19 +279,11 @@ func newProvider(config *Config, storage Storage, issuer func(bool) (IssuerFromR
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o.Handler = CreateRouter(o, o.interceptors...)
|
||||
|
||||
o.decoder = schema.NewDecoder()
|
||||
o.decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
o.encoder = oidc.NewEncoder()
|
||||
|
||||
o.crypto = NewAESCrypto(config.CryptoKey)
|
||||
|
||||
// Avoid potential race conditions by calling these early
|
||||
_ = o.openIDKeySet() // sets keySet
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
|
@ -230,7 +294,8 @@ type Provider struct {
|
|||
insecure bool
|
||||
endpoints *Endpoints
|
||||
storage Storage
|
||||
keySet *openIDKeySet
|
||||
accessTokenKeySet oidc.KeySet
|
||||
idTokenHinKeySet oidc.KeySet
|
||||
crypto Crypto
|
||||
decoder *schema.Decoder
|
||||
encoder *schema.Encoder
|
||||
|
@ -238,6 +303,7 @@ type Provider struct {
|
|||
timer <-chan time.Time
|
||||
accessTokenVerifierOpts []AccessTokenVerifierOpt
|
||||
idTokenHintVerifierOpts []IDTokenHintVerifierOpt
|
||||
corsOpts *cors.Options
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
|
@ -365,7 +431,7 @@ func (o *Provider) Encoder() httphelper.Encoder {
|
|||
}
|
||||
|
||||
func (o *Provider) IDTokenHintVerifier(ctx context.Context) *IDTokenHintVerifier {
|
||||
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.idTokenHintVerifierOpts...)
|
||||
return NewIDTokenHintVerifier(IssuerFromContext(ctx), o.idTokenHinKeySet, o.idTokenHintVerifierOpts...)
|
||||
}
|
||||
|
||||
func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier {
|
||||
|
@ -373,14 +439,7 @@ func (o *Provider) JWTProfileVerifier(ctx context.Context) *JWTProfileVerifier {
|
|||
}
|
||||
|
||||
func (o *Provider) AccessTokenVerifier(ctx context.Context) *AccessTokenVerifier {
|
||||
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.openIDKeySet(), o.accessTokenVerifierOpts...)
|
||||
}
|
||||
|
||||
func (o *Provider) openIDKeySet() oidc.KeySet {
|
||||
if o.keySet == nil {
|
||||
o.keySet = &openIDKeySet{o.Storage()}
|
||||
}
|
||||
return o.keySet
|
||||
return NewAccessTokenVerifier(IssuerFromContext(ctx), o.accessTokenKeySet, o.accessTokenVerifierOpts...)
|
||||
}
|
||||
|
||||
func (o *Provider) Crypto() Crypto {
|
||||
|
@ -397,6 +456,10 @@ func (o *Provider) Probes() []ProbesFn {
|
|||
}
|
||||
}
|
||||
|
||||
func (o *Provider) CORSOptions() *cors.Options {
|
||||
return o.corsOpts
|
||||
}
|
||||
|
||||
func (o *Provider) Logger() *slog.Logger {
|
||||
return o.logger
|
||||
}
|
||||
|
@ -406,13 +469,13 @@ func (o *Provider) HttpHandler() http.Handler {
|
|||
return o
|
||||
}
|
||||
|
||||
type openIDKeySet struct {
|
||||
type OpenIDKeySet struct {
|
||||
Storage
|
||||
}
|
||||
|
||||
// VerifySignature implements the oidc.KeySet interface
|
||||
// providing an implementation for the keys stored in the OP Storage interface
|
||||
func (o *openIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
func (o *OpenIDKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
keySet, err := o.Storage.KeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching keys: %w", err)
|
||||
|
@ -543,6 +606,15 @@ func WithHttpInterceptors(interceptors ...HttpInterceptor) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithAccessTokenKeySet allows passing a KeySet with public keys for Access Token verification.
|
||||
// The default KeySet uses the [Storage] interface
|
||||
func WithAccessTokenKeySet(keySet oidc.KeySet) Option {
|
||||
return func(o *Provider) error {
|
||||
o.accessTokenKeySet = keySet
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithAccessTokenVerifierOpts(opts ...AccessTokenVerifierOpt) Option {
|
||||
return func(o *Provider) error {
|
||||
o.accessTokenVerifierOpts = opts
|
||||
|
@ -550,6 +622,15 @@ func WithAccessTokenVerifierOpts(opts ...AccessTokenVerifierOpt) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithIDTokenHintKeySet allows passing a KeySet with public keys for ID Token Hint verification.
|
||||
// The default KeySet uses the [Storage] interface.
|
||||
func WithIDTokenHintKeySet(keySet oidc.KeySet) Option {
|
||||
return func(o *Provider) error {
|
||||
o.idTokenHinKeySet = keySet
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
|
||||
return func(o *Provider) error {
|
||||
o.idTokenHintVerifierOpts = opts
|
||||
|
@ -557,6 +638,13 @@ func WithIDTokenHintVerifierOpts(opts ...IDTokenHintVerifierOpt) Option {
|
|||
}
|
||||
}
|
||||
|
||||
func WithCORSOptions(opts *cors.Options) Option {
|
||||
return func(o *Provider) error {
|
||||
o.corsOpts = opts
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger lets a logger other than slog.Default().
|
||||
//
|
||||
// EXPERIMENTAL: Will change to log/slog import after we drop support for Go 1.20
|
||||
|
@ -573,6 +661,6 @@ func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handle
|
|||
for i := len(interceptors) - 1; i >= 0; i-- {
|
||||
handler = interceptors[i](handler)
|
||||
}
|
||||
return cors.New(defaultCORSOptions).Handler(issuerInterceptor.Handler(handler))
|
||||
return issuerInterceptor.Handler(handler)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ var (
|
|||
AuthMethodPrivateKeyJWT: true,
|
||||
GrantTypeRefreshToken: true,
|
||||
RequestObjectSupported: true,
|
||||
SupportedClaims: op.DefaultSupportedClaims,
|
||||
SupportedUILocales: []language.Tag{language.English},
|
||||
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||
Lifetime: 5 * time.Minute,
|
||||
|
@ -57,8 +58,12 @@ func init() {
|
|||
}
|
||||
|
||||
func newTestProvider(config *op.Config) op.OpenIDProvider {
|
||||
provider, err := op.NewOpenIDProvider(testIssuer, config,
|
||||
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
|
||||
storage := storage.NewStorage(storage.NewUserStore(testIssuer))
|
||||
keySet := &op.OpenIDKeySet{storage}
|
||||
provider, err := op.NewOpenIDProvider(testIssuer, config, storage,
|
||||
op.WithAllowInsecure(),
|
||||
op.WithAccessTokenKeySet(keySet),
|
||||
op.WithIDTokenHintKeySet(keySet),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
|
@ -127,7 +127,7 @@ type Server interface {
|
|||
// Introspect handles the OAuth 2.0 Token Introspection endpoint.
|
||||
// https://datatracker.ietf.org/doc/html/rfc7662
|
||||
// The recommended Response Data type is [oidc.IntrospectionResponse].
|
||||
Introspect(context.Context, *ClientRequest[oidc.IntrospectionRequest]) (*Response, error)
|
||||
Introspect(context.Context, *Request[IntrospectionRequest]) (*Response, error)
|
||||
|
||||
// UserInfo handles the UserInfo endpoint and returns Claims about the authenticated End-User.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
|
@ -329,7 +329,7 @@ func (UnimplementedServer) DeviceToken(ctx context.Context, r *ClientRequest[oid
|
|||
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
|
||||
}
|
||||
|
||||
func (UnimplementedServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
|
||||
func (UnimplementedServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) {
|
||||
return nil, unimplementedError(r)
|
||||
}
|
||||
|
||||
|
|
|
@ -25,9 +25,11 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption)
|
|||
decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
ws := &webServer{
|
||||
router: chi.NewRouter(),
|
||||
server: server,
|
||||
endpoints: endpoints,
|
||||
decoder: decoder,
|
||||
corsOpts: &defaultCORSOptions,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
|
||||
|
@ -36,6 +38,10 @@ func RegisterServer(server Server, endpoints Endpoints, options ...ServerOption)
|
|||
}
|
||||
|
||||
ws.createRouter()
|
||||
ws.handler = ws.router
|
||||
if ws.corsOpts != nil {
|
||||
ws.handler = cors.New(*ws.corsOpts).Handler(ws.router)
|
||||
}
|
||||
return ws
|
||||
}
|
||||
|
||||
|
@ -45,7 +51,14 @@ type ServerOption func(s *webServer)
|
|||
// the Server's router.
|
||||
func WithHTTPMiddleware(m ...func(http.Handler) http.Handler) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.middleware = m
|
||||
s.router.Use(m...)
|
||||
}
|
||||
}
|
||||
|
||||
// WithSetRouter allows customization or the Server's router.
|
||||
func WithSetRouter(set func(chi.Router)) ServerOption {
|
||||
return func(s *webServer) {
|
||||
set(s.router)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,6 +70,13 @@ func WithDecoder(decoder httphelper.Decoder) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithServerCORSOptions sets the CORS policy for the Server's router.
|
||||
func WithServerCORSOptions(opts *cors.Options) ServerOption {
|
||||
return func(s *webServer) {
|
||||
s.corsOpts = opts
|
||||
}
|
||||
}
|
||||
|
||||
// WithFallbackLogger overrides the fallback logger, which
|
||||
// is used when no logger was found in the context.
|
||||
// Defaults to [slog.Default].
|
||||
|
@ -67,12 +87,17 @@ func WithFallbackLogger(logger *slog.Logger) ServerOption {
|
|||
}
|
||||
|
||||
type webServer struct {
|
||||
http.Handler
|
||||
server Server
|
||||
middleware []func(http.Handler) http.Handler
|
||||
endpoints Endpoints
|
||||
decoder httphelper.Decoder
|
||||
logger *slog.Logger
|
||||
server Server
|
||||
router *chi.Mux
|
||||
handler http.Handler
|
||||
endpoints Endpoints
|
||||
decoder httphelper.Decoder
|
||||
corsOpts *cors.Options
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *webServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
||||
|
@ -83,27 +108,23 @@ func (s *webServer) getLogger(ctx context.Context) *slog.Logger {
|
|||
}
|
||||
|
||||
func (s *webServer) createRouter() {
|
||||
router := chi.NewRouter()
|
||||
router.Use(cors.New(defaultCORSOptions).Handler)
|
||||
router.Use(s.middleware...)
|
||||
router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
||||
router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
||||
s.router.HandleFunc(healthEndpoint, simpleHandler(s, s.server.Health))
|
||||
s.router.HandleFunc(readinessEndpoint, simpleHandler(s, s.server.Ready))
|
||||
s.router.HandleFunc(oidc.DiscoveryEndpoint, simpleHandler(s, s.server.Discovery))
|
||||
|
||||
s.endpointRoute(router, s.endpoints.Authorization, s.authorizeHandler)
|
||||
s.endpointRoute(router, s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
|
||||
s.endpointRoute(router, s.endpoints.Token, s.tokensHandler)
|
||||
s.endpointRoute(router, s.endpoints.Introspection, s.withClient(s.introspectionHandler))
|
||||
s.endpointRoute(router, s.endpoints.Userinfo, s.userInfoHandler)
|
||||
s.endpointRoute(router, s.endpoints.Revocation, s.withClient(s.revocationHandler))
|
||||
s.endpointRoute(router, s.endpoints.EndSession, s.endSessionHandler)
|
||||
s.endpointRoute(router, s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
|
||||
s.Handler = router
|
||||
s.endpointRoute(s.endpoints.Authorization, s.authorizeHandler)
|
||||
s.endpointRoute(s.endpoints.DeviceAuthorization, s.withClient(s.deviceAuthorizationHandler))
|
||||
s.endpointRoute(s.endpoints.Token, s.tokensHandler)
|
||||
s.endpointRoute(s.endpoints.Introspection, s.introspectionHandler)
|
||||
s.endpointRoute(s.endpoints.Userinfo, s.userInfoHandler)
|
||||
s.endpointRoute(s.endpoints.Revocation, s.withClient(s.revocationHandler))
|
||||
s.endpointRoute(s.endpoints.EndSession, s.endSessionHandler)
|
||||
s.endpointRoute(s.endpoints.JwksURI, simpleHandler(s, s.server.Keys))
|
||||
}
|
||||
|
||||
func (s *webServer) endpointRoute(router *chi.Mux, e *Endpoint, hf http.HandlerFunc) {
|
||||
func (s *webServer) endpointRoute(e *Endpoint, hf http.HandlerFunc) {
|
||||
if e != nil {
|
||||
router.HandleFunc(e.Relative(), hf)
|
||||
s.router.HandleFunc(e.Relative(), hf)
|
||||
s.logger.Info("registered route", "endpoint", e.Relative())
|
||||
}
|
||||
}
|
||||
|
@ -128,7 +149,21 @@ func (s *webServer) withClient(handler clientHandler) http.HandlerFunc {
|
|||
}
|
||||
|
||||
func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
|
||||
if err = r.ParseForm(); err != nil {
|
||||
cc, err := s.parseClientCredentials(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
|
||||
Method: r.Method,
|
||||
URL: r.URL,
|
||||
Header: r.Header,
|
||||
Form: r.Form,
|
||||
Data: cc,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *webServer) parseClientCredentials(r *http.Request) (_ *ClientCredentials, err error) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(err)
|
||||
}
|
||||
cc := new(ClientCredentials)
|
||||
|
@ -152,13 +187,7 @@ func (s *webServer) verifyRequestClient(r *http.Request) (_ Client, err error) {
|
|||
if cc.ClientAssertion != "" && cc.ClientAssertionType != oidc.ClientAssertionTypeJWTAssertion {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("invalid client_assertion_type %s", cc.ClientAssertionType)
|
||||
}
|
||||
return s.server.VerifyClient(r.Context(), &Request[ClientCredentials]{
|
||||
Method: r.Method,
|
||||
URL: r.URL,
|
||||
Header: r.Header,
|
||||
Form: r.Form,
|
||||
Data: cc,
|
||||
})
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (s *webServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -370,8 +399,13 @@ func (s *webServer) deviceTokenHandler(w http.ResponseWriter, r *http.Request, c
|
|||
resp.writeOut(w)
|
||||
}
|
||||
|
||||
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request, client Client) {
|
||||
if client.AuthMethod() == oidc.AuthMethodNone {
|
||||
func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
cc, err := s.parseClientCredentials(r)
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
if cc.ClientSecret == "" && cc.ClientAssertion == "" {
|
||||
WriteError(w, r, oidc.ErrInvalidClient().WithDescription("client must be authenticated"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
|
@ -384,7 +418,7 @@ func (s *webServer) introspectionHandler(w http.ResponseWriter, r *http.Request,
|
|||
WriteError(w, r, oidc.ErrInvalidRequest().WithDescription("token missing"), s.getLogger(r.Context()))
|
||||
return
|
||||
}
|
||||
resp, err := s.server.Introspect(r.Context(), newClientRequest(r, request, client))
|
||||
resp, err := s.server.Introspect(r.Context(), newRequest(r, &IntrospectionRequest{cc, request}))
|
||||
if err != nil {
|
||||
WriteError(w, r, err, s.getLogger(r.Context()))
|
||||
return
|
||||
|
|
|
@ -32,7 +32,7 @@ func jwtProfile() (string, error) {
|
|||
}
|
||||
|
||||
func TestServerRoutes(t *testing.T) {
|
||||
server := op.NewLegacyServer(testProvider, *op.DefaultEndpoints)
|
||||
server := op.RegisterLegacyServer(op.NewLegacyServer(testProvider, *op.DefaultEndpoints))
|
||||
|
||||
storage := testProvider.Storage().(routesTestStorage)
|
||||
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
|
||||
|
|
|
@ -365,14 +365,14 @@ func Test_webServer_authorizeHandler(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
name: "authorize error",
|
||||
name: "server error",
|
||||
fields: fields{
|
||||
server: &requestVerifier{},
|
||||
decoder: testDecoder,
|
||||
},
|
||||
r: httptest.NewRequest(http.MethodPost, "/authorize", strings.NewReader("foo=bar")),
|
||||
want: webServerResult{
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantBody: `{"error":"server_error"}`,
|
||||
},
|
||||
},
|
||||
|
@ -1001,14 +1001,12 @@ func Test_webServer_introspectionHandler(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
decoder httphelper.Decoder
|
||||
client Client
|
||||
r *http.Request
|
||||
want webServerResult
|
||||
}{
|
||||
{
|
||||
name: "decoder error",
|
||||
decoder: schema.NewDecoder(),
|
||||
client: newClient(clientTypeUserAgent),
|
||||
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
||||
want: webServerResult{
|
||||
wantStatus: http.StatusBadRequest,
|
||||
|
@ -1018,8 +1016,7 @@ func Test_webServer_introspectionHandler(t *testing.T) {
|
|||
{
|
||||
name: "public client",
|
||||
decoder: testDecoder,
|
||||
client: newClient(clientTypeNative),
|
||||
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
||||
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123")),
|
||||
want: webServerResult{
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{"error":"invalid_client", "error_description":"client must be authenticated"}`,
|
||||
|
@ -1028,8 +1025,7 @@ func Test_webServer_introspectionHandler(t *testing.T) {
|
|||
{
|
||||
name: "token missing",
|
||||
decoder: testDecoder,
|
||||
client: newClient(clientTypeWeb),
|
||||
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")),
|
||||
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET")),
|
||||
want: webServerResult{
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: `{"error":"invalid_request", "error_description":"token missing"}`,
|
||||
|
@ -1038,8 +1034,7 @@ func Test_webServer_introspectionHandler(t *testing.T) {
|
|||
{
|
||||
name: "unimplemented Introspect called",
|
||||
decoder: testDecoder,
|
||||
client: newClient(clientTypeWeb),
|
||||
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("token=xxx")),
|
||||
r: httptest.NewRequest(http.MethodPost, "/", strings.NewReader("client_id=123&client_secret=SECRET&token=xxx")),
|
||||
want: webServerResult{
|
||||
wantStatus: UnimplementedStatusCode,
|
||||
wantBody: `{"error":"server_error", "error_description":"/ not implemented on this server"}`,
|
||||
|
@ -1053,7 +1048,7 @@ func Test_webServer_introspectionHandler(t *testing.T) {
|
|||
decoder: tt.decoder,
|
||||
logger: slog.Default(),
|
||||
}
|
||||
runWebServerClientTest(t, s.introspectionHandler, tt.r, tt.client, tt.want)
|
||||
runWebServerTest(t, s.introspectionHandler, tt.r, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1242,7 +1237,7 @@ func Test_webServer_simpleHandler(t *testing.T) {
|
|||
},
|
||||
r: httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 11<<20))),
|
||||
want: webServerResult{
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantBody: `{"error":"server_error", "error_description":"io: read/write on closed pipe"}`,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -10,37 +10,77 @@ import (
|
|||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
||||
// LegacyServer is an implementation of [Server[] that
|
||||
// simply wraps a [OpenIDProvider].
|
||||
// ExtendedLegacyServer allows embedding [LegacyServer] in a struct,
|
||||
// so that its methods can be individually overridden.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type ExtendedLegacyServer interface {
|
||||
Server
|
||||
Provider() OpenIDProvider
|
||||
Endpoints() Endpoints
|
||||
AuthCallbackURL() func(context.Context, string) string
|
||||
}
|
||||
|
||||
// RegisterLegacyServer registers a [LegacyServer] or an extension thereof.
|
||||
// It takes care of registering the IssuerFromRequest middleware
|
||||
// and Authorization Callback Routes.
|
||||
// Neither are part of the bare [Server] interface.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
func RegisterLegacyServer(s ExtendedLegacyServer, options ...ServerOption) http.Handler {
|
||||
provider := s.Provider()
|
||||
options = append(options,
|
||||
WithHTTPMiddleware(intercept(provider.IssuerFromRequest)),
|
||||
WithSetRouter(func(r chi.Router) {
|
||||
r.HandleFunc(s.Endpoints().Authorization.Relative()+authCallbackPathSuffix, authorizeCallbackHandler(provider))
|
||||
}),
|
||||
)
|
||||
return RegisterServer(s, s.Endpoints(), options...)
|
||||
}
|
||||
|
||||
// LegacyServer is an implementation of [Server] that
|
||||
// simply wraps an [OpenIDProvider].
|
||||
// It can be used to transition from the former Provider/Storage
|
||||
// interfaces to the new Server interface.
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
type LegacyServer struct {
|
||||
UnimplementedServer
|
||||
provider OpenIDProvider
|
||||
endpoints Endpoints
|
||||
}
|
||||
|
||||
// NewLegacyServer wraps provider in a `Server` and returns a handler which is
|
||||
// the Server's router.
|
||||
// NewLegacyServer wraps provider in a `Server` implementation
|
||||
//
|
||||
// Only non-nil endpoints will be registered on the router.
|
||||
// Nil endpoints are disabled.
|
||||
//
|
||||
// The passed endpoints is also set to the provider,
|
||||
// to be consistent with the discovery config.
|
||||
// The passed endpoints is also used for the discovery config,
|
||||
// and endpoints already set to the provider are ignored.
|
||||
// Any `With*Endpoint()` option used on the provider is
|
||||
// therefore ineffective.
|
||||
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) http.Handler {
|
||||
server := RegisterServer(&LegacyServer{
|
||||
//
|
||||
// EXPERIMENTAL: may change until v4
|
||||
func NewLegacyServer(provider OpenIDProvider, endpoints Endpoints) *LegacyServer {
|
||||
return &LegacyServer{
|
||||
provider: provider,
|
||||
endpoints: endpoints,
|
||||
}, endpoints, WithHTTPMiddleware(intercept(provider.IssuerFromRequest)))
|
||||
}
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", server)
|
||||
router.HandleFunc(authCallbackPath(provider), authorizeCallbackHandler(provider))
|
||||
func (s *LegacyServer) Provider() OpenIDProvider {
|
||||
return s.provider
|
||||
}
|
||||
|
||||
return router
|
||||
func (s *LegacyServer) Endpoints() Endpoints {
|
||||
return s.endpoints
|
||||
}
|
||||
|
||||
// AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login
|
||||
func (s *LegacyServer) AuthCallbackURL() func(context.Context, string) string {
|
||||
return func(ctx context.Context, requestID string) string {
|
||||
return s.endpoints.Authorization.Absolute(IssuerFromContext(ctx)) + authCallbackPathSuffix + "?id=" + requestID
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Health(_ context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
|
@ -51,7 +91,7 @@ func (s *LegacyServer) Ready(ctx context.Context, r *Request[struct{}]) (*Respon
|
|||
for _, probe := range s.provider.Probes() {
|
||||
// shouldn't we run probes in Go routines?
|
||||
if err := probe(ctx); err != nil {
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
return nil, AsStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
return NewResponse(Status{Status: "ok"}), nil
|
||||
|
@ -66,7 +106,7 @@ func (s *LegacyServer) Discovery(ctx context.Context, r *Request[struct{}]) (*Re
|
|||
func (s *LegacyServer) Keys(ctx context.Context, r *Request[struct{}]) (*Response, error) {
|
||||
keys, err := s.provider.Storage().KeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
return nil, AsStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
return NewResponse(jsonWebKeySet(keys)), nil
|
||||
}
|
||||
|
@ -87,7 +127,7 @@ func (s *LegacyServer) VerifyAuthRequest(ctx context.Context, r *Request[oidc.Au
|
|||
}
|
||||
}
|
||||
if r.Data.ClientID == "" {
|
||||
return nil, ErrAuthReqMissingClientID
|
||||
return nil, oidc.ErrInvalidRequest().WithParent(ErrAuthReqMissingClientID).WithDescription(ErrAuthReqMissingClientID.Error())
|
||||
}
|
||||
client, err := s.provider.Storage().GetClientByClientID(ctx, r.Data.ClientID)
|
||||
if err != nil {
|
||||
|
@ -115,7 +155,7 @@ func (s *LegacyServer) Authorize(ctx context.Context, r *ClientRequest[oidc.Auth
|
|||
func (s *LegacyServer) DeviceAuthorization(ctx context.Context, r *ClientRequest[oidc.DeviceAuthorizationRequest]) (*Response, error) {
|
||||
response, err := createDeviceAuthorization(ctx, r.Data, r.Client.GetID(), s.provider)
|
||||
if err != nil {
|
||||
return nil, NewStatusError(err, http.StatusInternalServerError)
|
||||
return nil, AsStatusError(err, http.StatusInternalServerError)
|
||||
}
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
|
@ -165,11 +205,14 @@ func (s *LegacyServer) CodeExchange(ctx context.Context, r *ClientRequest[oidc.A
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.Client.AuthMethod() == oidc.AuthMethodNone {
|
||||
if r.Client.AuthMethod() == oidc.AuthMethodNone || r.Data.CodeVerifier != "" {
|
||||
if err = AuthorizeCodeChallenge(r.Data.CodeVerifier, authReq.GetCodeChallenge()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if r.Data.RedirectURI != authReq.GetRedirectURI() {
|
||||
return nil, oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
|
||||
}
|
||||
resp, err := CreateTokenResponse(ctx, authReq, r.Client, s.provider, true, r.Data.Code, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -205,7 +248,7 @@ func (s *LegacyServer) JWTProfile(ctx context.Context, r *Request[oidc.JWTProfil
|
|||
}
|
||||
tokenRequest, err := VerifyJWTAssertion(ctx, r.Data.Assertion, exchanger.JWTProfileVerifier(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, oidc.ErrInvalidRequest().WithParent(err).WithDescription("assertion invalid")
|
||||
}
|
||||
|
||||
tokenRequest.Scopes, err = exchanger.Storage().ValidateJWTProfileScopes(ctx, tokenRequest.Issuer, r.Data.Scope)
|
||||
|
@ -251,7 +294,7 @@ func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientR
|
|||
}
|
||||
|
||||
func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
|
||||
if !s.provider.GrantTypeClientCredentialsSupported() {
|
||||
if !s.provider.GrantTypeDeviceCodeSupported() {
|
||||
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
|
||||
}
|
||||
// use a limited context timeout shorter as the default
|
||||
|
@ -259,15 +302,10 @@ func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.De
|
|||
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||
defer cancel()
|
||||
|
||||
state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
|
||||
tokenRequest, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenRequest := &deviceAccessTokenRequest{
|
||||
subject: state.Subject,
|
||||
audience: []string{r.Client.GetID()},
|
||||
scopes: state.Scopes,
|
||||
}
|
||||
resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -275,13 +313,30 @@ func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.De
|
|||
return NewResponse(resp), nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Introspect(ctx context.Context, r *ClientRequest[oidc.IntrospectionRequest]) (*Response, error) {
|
||||
func (s *LegacyServer) authenticateResourceClient(ctx context.Context, cc *ClientCredentials) (string, error) {
|
||||
if cc.ClientAssertion != "" {
|
||||
if jp, ok := s.provider.(ClientJWTProfile); ok {
|
||||
return ClientJWTAuth(ctx, oidc.ClientAssertionParams{ClientAssertion: cc.ClientAssertion}, jp)
|
||||
}
|
||||
return "", oidc.ErrInvalidClient().WithDescription("client_assertion not supported")
|
||||
}
|
||||
if err := s.provider.Storage().AuthorizeClientIDSecret(ctx, cc.ClientID, cc.ClientSecret); err != nil {
|
||||
return "", oidc.ErrUnauthorizedClient().WithParent(err)
|
||||
}
|
||||
return cc.ClientID, nil
|
||||
}
|
||||
|
||||
func (s *LegacyServer) Introspect(ctx context.Context, r *Request[IntrospectionRequest]) (*Response, error) {
|
||||
clientID, err := s.authenticateResourceClient(ctx, r.Data.ClientCredentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response := new(oidc.IntrospectionResponse)
|
||||
tokenID, subject, ok := getTokenIDAndSubject(ctx, s.provider, r.Data.Token)
|
||||
if !ok {
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
err := s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, r.Client.GetID())
|
||||
err = s.provider.Storage().SetIntrospectionFromToken(ctx, response, tokenID, subject, clientID)
|
||||
if err != nil {
|
||||
return NewResponse(response), nil
|
||||
}
|
||||
|
@ -336,9 +391,14 @@ func (s *LegacyServer) EndSession(ctx context.Context, r *Request[oidc.EndSessio
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID)
|
||||
redirect := session.RedirectURI
|
||||
if fromRequest, ok := s.provider.Storage().(CanTerminateSessionFromRequest); ok {
|
||||
redirect, err = fromRequest.TerminateSessionFromRequest(ctx, session)
|
||||
} else {
|
||||
err = s.provider.Storage().TerminateSession(ctx, session.UserID, session.ClientID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRedirect(session.RedirectURI), nil
|
||||
return NewRedirect(redirect), nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package op
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
|
@ -68,7 +69,7 @@ func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest,
|
|||
}
|
||||
if req.IdTokenHint != "" {
|
||||
claims, err := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, req.IdTokenHint, ender.IDTokenHintVerifier(ctx))
|
||||
if err != nil {
|
||||
if err != nil && !errors.As(err, &IDTokenHintExpiredError{}) {
|
||||
return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(err)
|
||||
}
|
||||
session.UserID = claims.GetSubject()
|
||||
|
|
|
@ -168,15 +168,6 @@ type EndSessionRequest struct {
|
|||
|
||||
var ErrDuplicateUserCode = errors.New("user code already exists")
|
||||
|
||||
type DeviceAuthorizationState struct {
|
||||
ClientID string
|
||||
Scopes []string
|
||||
Expires time.Time
|
||||
Done bool
|
||||
Subject string
|
||||
Denied bool
|
||||
}
|
||||
|
||||
type DeviceAuthorizationStorage interface {
|
||||
// StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
|
||||
// User code will be used by the user to complete the login flow and must be unique.
|
||||
|
|
|
@ -84,6 +84,8 @@ func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool
|
|||
return req.GetRequestedTokenType() == oidc.RefreshTokenType
|
||||
case RefreshTokenRequest:
|
||||
return true
|
||||
case *DeviceAuthorizationState:
|
||||
return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -65,3 +65,8 @@ func ParseTokenIntrospectionRequest(r *http.Request, introspector Introspector)
|
|||
|
||||
return req.Token, clientID, nil
|
||||
}
|
||||
|
||||
type IntrospectionRequest struct {
|
||||
*ClientCredentials
|
||||
*oidc.IntrospectionRequest
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package op
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
)
|
||||
|
@ -27,8 +28,23 @@ func NewIDTokenHintVerifier(issuer string, keySet oidc.KeySet, opts ...IDTokenHi
|
|||
return verifier
|
||||
}
|
||||
|
||||
type IDTokenHintExpiredError struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (e IDTokenHintExpiredError) Unwrap() error {
|
||||
return e.error
|
||||
}
|
||||
|
||||
func (e IDTokenHintExpiredError) Is(err error) bool {
|
||||
return errors.Is(err, e.error)
|
||||
}
|
||||
|
||||
// VerifyIDTokenHint validates the id token according to
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.
|
||||
// In case of an expired token both the Claims and first encountered expiry related error
|
||||
// is returned of type [IDTokenHintExpiredError]. In that case the caller can choose to still
|
||||
// trust the token for cases like logout, as signature and other verifications succeeded.
|
||||
func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTokenHintVerifier) (claims C, err error) {
|
||||
var nilClaims C
|
||||
|
||||
|
@ -49,20 +65,20 @@ func VerifyIDTokenHint[C oidc.Claims](ctx context.Context, token string, v *IDTo
|
|||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthorizationContextClassReference(claims, v.ACR); err != nil {
|
||||
return nilClaims, err
|
||||
}
|
||||
|
||||
if err = oidc.CheckExpiration(claims, v.Offset); err != nil {
|
||||
return claims, IDTokenHintExpiredError{err}
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuedAt(claims, v.MaxAgeIAT, v.Offset); err != nil {
|
||||
return claims, IDTokenHintExpiredError{err}
|
||||
}
|
||||
|
||||
if err = oidc.CheckAuthTime(claims, v.MaxAge); err != nil {
|
||||
return nilClaims, err
|
||||
return claims, IDTokenHintExpiredError{err}
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package op
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -57,6 +58,13 @@ func TestNewIDTokenHintVerifier(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_IDTokenHintExpiredError(t *testing.T) {
|
||||
var err error = IDTokenHintExpiredError{oidc.ErrExpired}
|
||||
assert.True(t, errors.Unwrap(err) == oidc.ErrExpired)
|
||||
assert.ErrorIs(t, err, oidc.ErrExpired)
|
||||
assert.ErrorAs(t, err, &IDTokenHintExpiredError{})
|
||||
}
|
||||
|
||||
func TestVerifyIDTokenHint(t *testing.T) {
|
||||
verifier := &IDTokenHintVerifier{
|
||||
Issuer: tu.ValidIssuer,
|
||||
|
@ -71,21 +79,23 @@ func TestVerifyIDTokenHint(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
tokenClaims func() (string, *oidc.IDTokenClaims)
|
||||
wantErr bool
|
||||
wantClaims bool
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
tokenClaims: tu.ValidIDToken,
|
||||
wantClaims: true,
|
||||
},
|
||||
{
|
||||
name: "parse err",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return "~~~~", nil },
|
||||
wantErr: true,
|
||||
wantErr: oidc.ErrParse,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) { return tu.InvalidSignatureToken, nil },
|
||||
wantErr: true,
|
||||
wantErr: oidc.ErrSignatureUnsupportedAlg,
|
||||
},
|
||||
{
|
||||
name: "wrong issuer",
|
||||
|
@ -96,29 +106,7 @@ func TestVerifyIDTokenHint(t *testing.T) {
|
|||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong IAT",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, -time.Hour, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
wantErr: oidc.ErrIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "wrong acr",
|
||||
|
@ -129,7 +117,31 @@ func TestVerifyIDTokenHint(t *testing.T) {
|
|||
"else", tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
wantErr: oidc.ErrAcrInvalid,
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration.Add(-time.Hour), tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantClaims: true,
|
||||
wantErr: IDTokenHintExpiredError{oidc.ErrExpired},
|
||||
},
|
||||
{
|
||||
name: "IAT too old",
|
||||
tokenClaims: func() (string, *oidc.IDTokenClaims) {
|
||||
return tu.NewIDToken(
|
||||
tu.ValidIssuer, tu.ValidSubject, tu.ValidAudience,
|
||||
tu.ValidExpiration, tu.ValidAuthTime, tu.ValidNonce,
|
||||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, time.Hour, "",
|
||||
)
|
||||
},
|
||||
wantClaims: true,
|
||||
wantErr: IDTokenHintExpiredError{oidc.ErrIatToOld},
|
||||
},
|
||||
{
|
||||
name: "expired auth",
|
||||
|
@ -140,7 +152,8 @@ func TestVerifyIDTokenHint(t *testing.T) {
|
|||
tu.ValidACR, tu.ValidAMR, tu.ValidClientID, tu.ValidSkew, "",
|
||||
)
|
||||
},
|
||||
wantErr: true,
|
||||
wantClaims: true,
|
||||
wantErr: IDTokenHintExpiredError{oidc.ErrAuthTimeToOld},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -148,14 +161,12 @@ func TestVerifyIDTokenHint(t *testing.T) {
|
|||
token, want := tt.tokenClaims()
|
||||
|
||||
got, err := VerifyIDTokenHint[*oidc.IDTokenClaims](context.Background(), token, verifier)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
if tt.wantClaims {
|
||||
assert.Equal(t, got, want, "claims")
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, got, want)
|
||||
assert.Nil(t, got, "claims")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,11 +17,21 @@ import (
|
|||
type JWTProfileVerifier struct {
|
||||
oidc.Verifier
|
||||
Storage JWTProfileKeyStorage
|
||||
keySet oidc.KeySet
|
||||
CheckSubject func(request *oidc.JWTTokenRequest) error
|
||||
}
|
||||
|
||||
// NewJWTProfileVerifier creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication)
|
||||
func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier {
|
||||
return newJWTProfileVerifier(storage, nil, issuer, maxAgeIAT, offset, opts...)
|
||||
}
|
||||
|
||||
// NewJWTProfileVerifierKeySet creates a oidc.Verifier for JWT Profile assertions (authorization grant and client authentication)
|
||||
func NewJWTProfileVerifierKeySet(keySet oidc.KeySet, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier {
|
||||
return newJWTProfileVerifier(nil, keySet, issuer, maxAgeIAT, offset, opts...)
|
||||
}
|
||||
|
||||
func newJWTProfileVerifier(storage JWTProfileKeyStorage, keySet oidc.KeySet, issuer string, maxAgeIAT, offset time.Duration, opts ...JWTProfileVerifierOption) *JWTProfileVerifier {
|
||||
j := &JWTProfileVerifier{
|
||||
Verifier: oidc.Verifier{
|
||||
Issuer: issuer,
|
||||
|
@ -29,6 +39,7 @@ func NewJWTProfileVerifier(storage JWTProfileKeyStorage, issuer string, maxAgeIA
|
|||
Offset: offset,
|
||||
},
|
||||
Storage: storage,
|
||||
keySet: keySet,
|
||||
CheckSubject: SubjectIsIssuer,
|
||||
}
|
||||
|
||||
|
@ -78,7 +89,10 @@ func VerifyJWTAssertion(ctx context.Context, assertion string, v *JWTProfileVeri
|
|||
return nil, err
|
||||
}
|
||||
|
||||
keySet := &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer}
|
||||
keySet := v.keySet
|
||||
if keySet == nil {
|
||||
keySet = &jwtProfileKeySet{storage: v.Storage, clientID: request.Issuer}
|
||||
}
|
||||
if err = oidc.CheckSignature(ctx, assertion, payload, request, nil, keySet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue