initial commit
This commit is contained in:
commit
6d0890e280
68 changed files with 5986 additions and 0 deletions
29
.github/workflows/release.yml
vendored
Normal file
29
.github/workflows/release.yml
vendored
Normal file
|
@ -0,0 +1,29 @@
|
|||
name: Release
|
||||
on: push
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-18.04
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.11', '1.12', '1.13']
|
||||
name: Go ${{ matrix.go }} test
|
||||
steps:
|
||||
- uses: actions/checkout@master
|
||||
- name: Setup go
|
||||
uses: actions/setup-go@v1
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- run: go test -race ./pkg/...
|
||||
release:
|
||||
runs-on: ubuntu-18.04
|
||||
needs: [test]
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- name: Source checkout
|
||||
uses: actions/checkout@v1
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Create Version
|
||||
uses: caos/semantic-release@v0.2.4
|
||||
|
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, build with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
**/__debug_bin
|
||||
.vscode
|
||||
.DS_Store
|
7
.releaserc.js
Normal file
7
.releaserc.js
Normal file
|
@ -0,0 +1,7 @@
|
|||
module.exports = {
|
||||
branch: 'master',
|
||||
plugins: [
|
||||
"@semantic-release/commit-analyzer",
|
||||
"@semantic-release/release-notes-generator"
|
||||
]
|
||||
};
|
201
LICENSE
Normal file
201
LICENSE
Normal file
|
@ -0,0 +1,201 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
8
README.md
Normal file
8
README.md
Normal file
|
@ -0,0 +1,8 @@
|
|||
# oidc
|
||||
|
||||

|
||||

|
||||
[](https://GitHub.com/caos/oidc/releases/)
|
||||
[](https://github.com/caos/oidc/blob/master/LICENSE)
|
||||
|
||||
OpenID Connect SDK (client and server) for Go
|
42
SECURITY.md
Normal file
42
SECURITY.md
Normal file
|
@ -0,0 +1,42 @@
|
|||
# Security Policy
|
||||
|
||||
At caos we are extremely grateful for security aware people that disclose vulnerabilities to us and the open source community. All reports will be investigated by our team.
|
||||
|
||||
## Supported Versions
|
||||
|
||||
After the initial Release the following version support will apply
|
||||
|
||||
| Version | Supported |
|
||||
| ------- | ------------------ |
|
||||
| 1.x.x | :white_check_mark: (note yet available) |
|
||||
| 0.x.x | :x: |
|
||||
|
||||
## Reporting a vulnerability
|
||||
|
||||
To file a incident, please disclose by email to security@caos.ch with the security details.
|
||||
|
||||
At the moment GPG encryption is no yet supported, however you may sign your message at will.
|
||||
|
||||
### When should I report a vulnerability
|
||||
|
||||
* You think you discovered a ...
|
||||
* ... potential security vulnerability in the SDK
|
||||
* ... vulnerability in another project that this SDK bases on
|
||||
* For projects with their own vulnerability reporting and disclosure process, please report it directly there
|
||||
|
||||
### When should I NOT report a vulnerability
|
||||
|
||||
* You need help applying security related updates
|
||||
* Your issue is not security related
|
||||
|
||||
## Security Vulnerability Response
|
||||
|
||||
TBD
|
||||
|
||||
## Public Disclosure
|
||||
|
||||
All accepted and mitigated vulnerabilitys will be published on the [Github Security Page](https://github.com/caos/oidc/security/advisories)
|
||||
|
||||
### Timing
|
||||
|
||||
We think it is crucial to publish advisories `ASAP` as mitigations are ready. But due to the unknown nature of the discloures the time frame can range from 7 to 90 days.
|
1
doc.go
Normal file
1
doc.go
Normal file
|
@ -0,0 +1 @@
|
|||
package oidc
|
90
example/client/api/api.go
Normal file
90
example/client/api/api.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
package main
|
||||
|
||||
// import (
|
||||
// "encoding/json"
|
||||
// "fmt"
|
||||
// "log"
|
||||
// "net/http"
|
||||
// "os"
|
||||
|
||||
// "github.com/caos/oidc/pkg/oidc"
|
||||
// "github.com/caos/oidc/pkg/oidc/rp"
|
||||
// "github.com/caos/utils/logging"
|
||||
// )
|
||||
|
||||
// const (
|
||||
// publicURL string = "/public"
|
||||
// protectedURL string = "/protected"
|
||||
// protectedExchangeURL string = "/protected/exchange"
|
||||
// )
|
||||
|
||||
func main() {
|
||||
// clientID := os.Getenv("CLIENT_ID")
|
||||
// clientSecret := os.Getenv("CLIENT_SECRET")
|
||||
// issuer := os.Getenv("ISSUER")
|
||||
// port := os.Getenv("PORT")
|
||||
|
||||
// // ctx := context.Background()
|
||||
|
||||
// providerConfig := &oidc.ProviderConfig{
|
||||
// ClientID: clientID,
|
||||
// ClientSecret: clientSecret,
|
||||
// Issuer: issuer,
|
||||
// }
|
||||
// provider, err := rp.NewDefaultProvider(providerConfig)
|
||||
// logging.Log("APP-nx6PeF").OnError(err).Panic("error creating provider")
|
||||
|
||||
// http.HandleFunc(publicURL, func(w http.ResponseWriter, r *http.Request) {
|
||||
// w.Write([]byte("OK"))
|
||||
// })
|
||||
|
||||
// http.HandleFunc(protectedURL, func(w http.ResponseWriter, r *http.Request) {
|
||||
// ok, token := checkToken(w, r)
|
||||
// if !ok {
|
||||
// return
|
||||
// }
|
||||
// resp, err := provider.Introspect(r.Context(), token)
|
||||
// if err != nil {
|
||||
// http.Error(w, err.Error(), http.StatusForbidden)
|
||||
// return
|
||||
// }
|
||||
// data, err := json.Marshal(resp)
|
||||
// if err != nil {
|
||||
// http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
// w.Write(data)
|
||||
// })
|
||||
|
||||
// http.HandleFunc(protectedExchangeURL, func(w http.ResponseWriter, r *http.Request) {
|
||||
// ok, token := checkToken(w, r)
|
||||
// if !ok {
|
||||
// return
|
||||
// }
|
||||
// tokens, err := provider.DelegationTokenExchange(r.Context(), token, oidc.WithResource([]string{"Test"}))
|
||||
// if err != nil {
|
||||
// http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
// return
|
||||
// }
|
||||
|
||||
// data, err := json.Marshal(tokens)
|
||||
// if err != nil {
|
||||
// http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
// w.Write(data)
|
||||
// })
|
||||
|
||||
// lis := fmt.Sprintf("127.0.0.1:%s", port)
|
||||
// log.Printf("listening on http://%s/", lis)
|
||||
// log.Fatal(http.ListenAndServe(lis, nil))
|
||||
// }
|
||||
|
||||
// func checkToken(w http.ResponseWriter, r *http.Request) (bool, string) {
|
||||
// token := r.Header.Get("authorization")
|
||||
// if token == "" {
|
||||
// http.Error(w, "Auth header missing", http.StatusUnauthorized)
|
||||
// return false, ""
|
||||
// }
|
||||
// return true, token
|
||||
}
|
96
example/client/app/app.go
Normal file
96
example/client/app/app.go
Normal file
|
@ -0,0 +1,96 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/rp"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
callbackPath string = "/auth/callback"
|
||||
key []byte = []byte("test1234test1234")
|
||||
)
|
||||
|
||||
func main() {
|
||||
clientID := os.Getenv("CLIENT_ID")
|
||||
clientSecret := os.Getenv("CLIENT_SECRET")
|
||||
issuer := os.Getenv("ISSUER")
|
||||
port := os.Getenv("PORT")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
rpConfig := &rp.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Issuer: issuer,
|
||||
CallbackURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath),
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
}
|
||||
cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure())
|
||||
provider, err := rp.NewDefaultRP(rpConfig, rp.WithCookieHandler(cookieHandler)) //rp.WithPKCE(cookieHandler)) //,
|
||||
if err != nil {
|
||||
logrus.Fatalf("error creating provider %s", err.Error())
|
||||
}
|
||||
|
||||
// state := "foobar"
|
||||
state := uuid.New().String()
|
||||
|
||||
http.Handle("/login", provider.AuthURLHandler(state))
|
||||
// http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
// http.Redirect(w, r, provider.AuthURL(state), http.StatusFound)
|
||||
// })
|
||||
|
||||
// http.HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
// tokens, err := provider.CodeExchange(ctx, r.URL.Query().Get("code"))
|
||||
// if err != nil {
|
||||
// http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
// return
|
||||
// }
|
||||
// data, err := json.Marshal(tokens)
|
||||
// if err != nil {
|
||||
// http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
// return
|
||||
// }
|
||||
// w.Write(data)
|
||||
// })
|
||||
|
||||
marshal := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) {
|
||||
_ = state
|
||||
data, err := json.Marshal(tokens)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
http.Handle(callbackPath, provider.CodeExchangeHandler(marshal))
|
||||
|
||||
http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
tokens, err := provider.ClientCredentials(ctx, "scope")
|
||||
if err != nil {
|
||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.Marshal(tokens)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(data)
|
||||
})
|
||||
lis := fmt.Sprintf("127.0.0.1:%s", port)
|
||||
logrus.Infof("listening on http://%s/", lis)
|
||||
logrus.Fatal(http.ListenAndServe("127.0.0.1:5556", nil))
|
||||
}
|
1
example/doc.go
Normal file
1
example/doc.go
Normal file
|
@ -0,0 +1 @@
|
|||
package example
|
250
example/internal/mock/storage.go
Normal file
250
example/internal/mock/storage.go
Normal file
|
@ -0,0 +1,250 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
type AuthStorage struct {
|
||||
key *rsa.PrivateKey
|
||||
}
|
||||
|
||||
func NewAuthStorage() op.Storage {
|
||||
reader := rand.Reader
|
||||
bitSize := 2048
|
||||
key, err := rsa.GenerateKey(reader, bitSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &AuthStorage{
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
type AuthRequest struct {
|
||||
ID string
|
||||
ResponseType oidc.ResponseType
|
||||
RedirectURI string
|
||||
Nonce string
|
||||
ClientID string
|
||||
CodeChallenge *oidc.CodeChallenge
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetACR() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetAMR() []string {
|
||||
return []string{
|
||||
"password",
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetAudience() []string {
|
||||
return []string{
|
||||
a.ClientID,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetAuthTime() time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetClientID() string {
|
||||
return a.ClientID
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetCode() string {
|
||||
return "code"
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
|
||||
return a.CodeChallenge
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetNonce() string {
|
||||
return a.Nonce
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetRedirectURI() string {
|
||||
return a.RedirectURI
|
||||
// return "http://localhost:5556/auth/callback"
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetResponseType() oidc.ResponseType {
|
||||
return a.ResponseType
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetScopes() []string {
|
||||
return []string{
|
||||
"openid",
|
||||
"profile",
|
||||
"email",
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetState() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetSubject() string {
|
||||
return "sub"
|
||||
}
|
||||
|
||||
func (a *AuthRequest) Done() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
var (
|
||||
a = &AuthRequest{}
|
||||
t bool
|
||||
)
|
||||
|
||||
func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||
a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI}
|
||||
if authReq.CodeChallenge != "" {
|
||||
a.CodeChallenge = &oidc.CodeChallenge{
|
||||
Challenge: authReq.CodeChallenge,
|
||||
Method: authReq.CodeChallengeMethod,
|
||||
}
|
||||
}
|
||||
t = false
|
||||
return a, nil
|
||||
}
|
||||
func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) {
|
||||
return a, nil
|
||||
}
|
||||
func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error {
|
||||
t = true
|
||||
return nil
|
||||
}
|
||||
func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequest, error) {
|
||||
if id != "id" || t {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
func (s *AuthStorage) GetSigningKey(_ context.Context) (*jose.SigningKey, error) {
|
||||
return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil
|
||||
}
|
||||
func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) {
|
||||
return s.key, nil
|
||||
}
|
||||
func (s *AuthStorage) SaveKeyPair(ctx context.Context) (*jose.SigningKey, error) {
|
||||
return s.GetSigningKey(ctx)
|
||||
}
|
||||
func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) {
|
||||
pubkey := s.key.Public()
|
||||
return &jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) {
|
||||
if id == "none" {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
var appType op.ApplicationType
|
||||
var authMethod op.AuthMethod
|
||||
var accessTokenType op.AccessTokenType
|
||||
if id == "web" {
|
||||
appType = op.ApplicationTypeWeb
|
||||
authMethod = op.AuthMethodBasic
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
} else if id == "native" {
|
||||
appType = op.ApplicationTypeNative
|
||||
authMethod = op.AuthMethodNone
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
} else {
|
||||
appType = op.ApplicationTypeUserAgent
|
||||
authMethod = op.AuthMethodNone
|
||||
accessTokenType = op.AccessTokenTypeJWT
|
||||
}
|
||||
return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
|
||||
}
|
||||
|
||||
func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthStorage) GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error) {
|
||||
return &oidc.Userinfo{
|
||||
Subject: a.GetSubject(),
|
||||
Address: &oidc.UserinfoAddress{
|
||||
StreetAddress: "Hjkhkj 789\ndsf",
|
||||
},
|
||||
UserinfoEmail: oidc.UserinfoEmail{
|
||||
Email: "test",
|
||||
EmailVerified: true,
|
||||
},
|
||||
UserinfoPhone: oidc.UserinfoPhone{
|
||||
PhoneNumber: "sadsa",
|
||||
PhoneNumberVerified: true,
|
||||
},
|
||||
UserinfoProfile: oidc.UserinfoProfile{
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
// Claims: map[string]interface{}{
|
||||
// "test": "test",
|
||||
// "hkjh": "",
|
||||
// },
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ConfClient struct {
|
||||
applicationType op.ApplicationType
|
||||
authMethod op.AuthMethod
|
||||
ID string
|
||||
accessTokenType op.AccessTokenType
|
||||
}
|
||||
|
||||
func (c *ConfClient) GetID() string {
|
||||
return c.ID
|
||||
}
|
||||
func (c *ConfClient) RedirectURIs() []string {
|
||||
return []string{
|
||||
"https://registered.com/callback",
|
||||
"http://localhost:9999/callback",
|
||||
"http://localhost:5556/auth/callback",
|
||||
"custom://callback",
|
||||
"https://localhost:8443/test/a/instructions-example/callback",
|
||||
"https://op.certification.openid.net:62064/authz_cb",
|
||||
"https://op.certification.openid.net:62064/authz_post",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConfClient) LoginURL(id string) string {
|
||||
return "login?id=" + id
|
||||
}
|
||||
|
||||
func (c *ConfClient) ApplicationType() op.ApplicationType {
|
||||
return c.applicationType
|
||||
}
|
||||
|
||||
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
||||
return c.authMethod
|
||||
}
|
||||
|
||||
func (c *ConfClient) AccessTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) IDTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
|
||||
return c.accessTokenType
|
||||
}
|
65
example/server/default/default.go
Normal file
65
example/server/default/default.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/caos/oidc/example/internal/mock"
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
config := &op.Config{
|
||||
Issuer: "http://localhost:9998/",
|
||||
CryptoKey: sha256.Sum256([]byte("test")),
|
||||
Port: "9998",
|
||||
}
|
||||
storage := mock.NewAuthStorage()
|
||||
handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint("test"))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
router := handler.HttpHandler().Handler.(*mux.Router)
|
||||
router.Methods("GET").Path("/login").HandlerFunc(HandleLogin)
|
||||
router.Methods("POST").Path("/login").HandlerFunc(HandleCallback)
|
||||
op.Start(ctx, handler)
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
func HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
tpl := `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>Login</title>
|
||||
</head>
|
||||
<body>
|
||||
<form method="POST" action="/login">
|
||||
<input name="client"/>
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>`
|
||||
t, err := template.New("login").Parse(tpl)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
err = t.Execute(w, nil)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
r.ParseForm()
|
||||
client := r.FormValue("client")
|
||||
http.Redirect(w, r, "/authorize/"+client, http.StatusFound)
|
||||
}
|
25
go.mod
Normal file
25
go.mod
Normal file
|
@ -0,0 +1,25 @@
|
|||
module github.com/caos/oidc
|
||||
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/golang/mock v1.3.1
|
||||
github.com/golang/protobuf v1.3.2 // indirect
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/gorilla/mux v1.7.3
|
||||
github.com/gorilla/schema v1.1.0
|
||||
github.com/gorilla/securecookie v1.1.1
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
github.com/sirupsen/logrus v1.4.2
|
||||
github.com/stretchr/testify v1.4.0
|
||||
golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect
|
||||
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933
|
||||
golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c
|
||||
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect
|
||||
golang.org/x/text v0.3.2
|
||||
google.golang.org/appengine v1.6.5 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.4.0
|
||||
gopkg.in/yaml.v2 v2.2.3 // indirect
|
||||
)
|
74
go.sum
Normal file
74
go.sum
Normal file
|
@ -0,0 +1,74 @@
|
|||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=
|
||||
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
|
||||
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
||||
github.com/gorilla/schema v1.1.0 h1:CamqUDOFUBqzrvxuz2vEwo8+SUdwsluFh7IlzJh30LY=
|
||||
github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
|
||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 h1:anGSYQpPhQwXlwsu5wmfq0nWkCNaMEMUwAv13Y92hd8=
|
||||
golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 h1:e6HwijUxhDe+hPNjZQQn9bA5PW3vNmnN64U2ZW759Lk=
|
||||
golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c h1:HjRaKPaiWks0f5tA6ELVF7ZfqSppfPwOEEAvsrKUTO4=
|
||||
golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 h1:ZBzSG/7F4eNKz2L3GE9o300RX0Az1Bw5HF7PDraD+qU=
|
||||
golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A=
|
||||
gopkg.in/square/go-jose.v2 v2.4.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.3 h1:fvjTMHxHEw/mxHbtzPi3JCcKXQRAnQTBRo6YCJSVHKI=
|
||||
gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
151
pkg/oidc/authorization.go
Normal file
151
pkg/oidc/authorization.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
const (
|
||||
ScopeOpenID = "openid"
|
||||
|
||||
ResponseTypeCode ResponseType = "code"
|
||||
ResponseTypeIDToken ResponseType = "id_token token"
|
||||
ResponseTypeIDTokenOnly ResponseType = "id_token"
|
||||
|
||||
DisplayPage Display = "page"
|
||||
DisplayPopup Display = "popup"
|
||||
DisplayTouch Display = "touch"
|
||||
DisplayWAP Display = "wap"
|
||||
|
||||
PromptNone Prompt = "none"
|
||||
PromptLogin Prompt = "login"
|
||||
PromptConsent Prompt = "consent"
|
||||
PromptSelectAccount Prompt = "select_account"
|
||||
|
||||
GrantTypeCode GrantType = "authorization_code"
|
||||
|
||||
BearerToken = "Bearer"
|
||||
)
|
||||
|
||||
var displayValues = map[string]Display{
|
||||
"page": DisplayPage,
|
||||
"popup": DisplayPopup,
|
||||
"touch": DisplayTouch,
|
||||
"wap": DisplayWAP,
|
||||
}
|
||||
|
||||
//AuthRequest according to:
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
//
|
||||
type AuthRequest struct {
|
||||
ID string
|
||||
Scopes Scopes `schema:"scope"`
|
||||
ResponseType ResponseType `schema:"response_type"`
|
||||
ClientID string `schema:"client_id"`
|
||||
RedirectURI string `schema:"redirect_uri"` //TODO: type
|
||||
|
||||
State string `schema:"state"`
|
||||
|
||||
// ResponseMode TODO: ?
|
||||
|
||||
Nonce string `schema:"nonce"`
|
||||
Display Display `schema:"display"`
|
||||
Prompt Prompt `schema:"prompt"`
|
||||
MaxAge uint32 `schema:"max_age"`
|
||||
UILocales Locales `schema:"ui_locales"`
|
||||
IDTokenHint string `schema:"id_token_hint"`
|
||||
LoginHint string `schema:"login_hint"`
|
||||
ACRValues []string `schema:"acr_values"`
|
||||
|
||||
CodeChallenge string `schema:"code_challenge"`
|
||||
CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"`
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetRedirectURI() string {
|
||||
return a.RedirectURI
|
||||
}
|
||||
func (a *AuthRequest) GetResponseType() ResponseType {
|
||||
return a.ResponseType
|
||||
}
|
||||
func (a *AuthRequest) GetState() string {
|
||||
return a.State
|
||||
}
|
||||
|
||||
type TokenRequest interface {
|
||||
// GrantType GrantType `schema:"grant_type"`
|
||||
GrantType() GrantType
|
||||
}
|
||||
|
||||
type TokenRequestType GrantType
|
||||
|
||||
type AccessTokenRequest struct {
|
||||
Code string `schema:"code"`
|
||||
RedirectURI string `schema:"redirect_uri"`
|
||||
ClientID string `schema:"client_id"`
|
||||
ClientSecret string `schema:"client_secret"`
|
||||
CodeVerifier string `schema:"code_verifier"`
|
||||
}
|
||||
|
||||
func (a *AccessTokenRequest) GrantType() GrantType {
|
||||
return GrantTypeCode
|
||||
}
|
||||
|
||||
type AccessTokenResponse struct {
|
||||
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
|
||||
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
|
||||
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
|
||||
}
|
||||
|
||||
type TokenExchangeRequest struct {
|
||||
subjectToken string `schema:"subject_token"`
|
||||
subjectTokenType string `schema:"subject_token_type"`
|
||||
actorToken string `schema:"actor_token"`
|
||||
actorTokenType string `schema:"actor_token_type"`
|
||||
resource []string `schema:"resource"`
|
||||
audience []string `schema:"audience"`
|
||||
Scope []string `schema:"scope"`
|
||||
requestedTokenType string `schema:"requested_token_type"`
|
||||
}
|
||||
|
||||
type Scopes []string
|
||||
|
||||
func (s *Scopes) UnmarshalText(text []byte) error {
|
||||
scopes := strings.Split(string(text), " ")
|
||||
*s = Scopes(scopes)
|
||||
return nil
|
||||
}
|
||||
|
||||
type ResponseType string
|
||||
|
||||
type Display string
|
||||
|
||||
func (d *Display) UnmarshalText(text []byte) error {
|
||||
var ok bool
|
||||
display := string(text)
|
||||
*d, ok = displayValues[display]
|
||||
if !ok {
|
||||
return errors.New("")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Prompt string
|
||||
|
||||
type Locales []language.Tag
|
||||
|
||||
func (l *Locales) UnmarshalText(text []byte) error {
|
||||
locales := strings.Split(string(text), " ")
|
||||
for _, locale := range locales {
|
||||
tag, err := language.Parse(locale)
|
||||
if err == nil && !tag.IsRoot() {
|
||||
*l = append(*l, tag)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type GrantType string
|
33
pkg/oidc/code_challenge.go
Normal file
33
pkg/oidc/code_challenge.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
CodeChallengeMethodPlain CodeChallengeMethod = "plain"
|
||||
CodeChallengeMethodS256 CodeChallengeMethod = "S256"
|
||||
)
|
||||
|
||||
type CodeChallengeMethod string
|
||||
|
||||
type CodeChallenge struct {
|
||||
Challenge string
|
||||
Method CodeChallengeMethod
|
||||
}
|
||||
|
||||
func NewSHACodeChallenge(code string) string {
|
||||
return utils.HashString(sha256.New(), code)
|
||||
}
|
||||
|
||||
func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool {
|
||||
if c == nil {
|
||||
return false //TODO: ?
|
||||
}
|
||||
if c.Method == CodeChallengeMethodS256 {
|
||||
codeVerifier = NewSHACodeChallenge(codeVerifier)
|
||||
}
|
||||
return codeVerifier == c.Challenge
|
||||
}
|
24
pkg/oidc/discovery.go
Normal file
24
pkg/oidc/discovery.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package oidc
|
||||
|
||||
const (
|
||||
DiscoveryEndpoint = "/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
type DiscoveryConfiguration struct {
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"`
|
||||
TokenEndpoint string `json:"token_endpoint,omitempty"`
|
||||
IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"`
|
||||
UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"`
|
||||
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
|
||||
CheckSessionIframe string `json:"check_session_iframe,omitempty"`
|
||||
JwksURI string `json:"jwks_uri,omitempty"`
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
|
||||
ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
|
||||
SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
|
||||
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"`
|
||||
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
|
||||
ClaimsSupported []string `json:"claims_supported,omitempty"`
|
||||
}
|
33
pkg/oidc/grants/client_credentials.go
Normal file
33
pkg/oidc/grants/client_credentials.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package grants
|
||||
|
||||
import "strings"
|
||||
|
||||
type clientCredentialsGrantBasic struct {
|
||||
grantType string `schema:"grant_type"`
|
||||
scope string `schema:"scope"`
|
||||
}
|
||||
|
||||
type clientCredentialsGrant struct {
|
||||
*clientCredentialsGrantBasic
|
||||
clientID string `schema:"client_id"`
|
||||
clientSecret string `schema:"client_secret"`
|
||||
}
|
||||
|
||||
//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant
|
||||
//sneding client_id and client_secret as basic auth header
|
||||
func ClientCredentialsGrantBasic(scopes ...string) *clientCredentialsGrantBasic {
|
||||
return &clientCredentialsGrantBasic{
|
||||
grantType: "client_credentials",
|
||||
scope: strings.Join(scopes, " "),
|
||||
}
|
||||
}
|
||||
|
||||
//ClientCredentialsGrantValues creates an oauth2 `Client Credentials` Grant
|
||||
//sneding client_id and client_secret as form values
|
||||
func ClientCredentialsGrantValues(clientID, clientSecret string, scopes ...string) *clientCredentialsGrant {
|
||||
return &clientCredentialsGrant{
|
||||
clientCredentialsGrantBasic: ClientCredentialsGrantBasic(scopes...),
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
}
|
||||
}
|
75
pkg/oidc/grants/tokenexchange/tokenexchange.go
Normal file
75
pkg/oidc/grants/tokenexchange/tokenexchange.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package tokenexchange
|
||||
|
||||
const (
|
||||
AccessTokenType = "urn:ietf:params:oauth:token-type:access_token"
|
||||
RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token"
|
||||
IDTokenType = "urn:ietf:params:oauth:token-type:id_token"
|
||||
JWTTokenType = "urn:ietf:params:oauth:token-type:jwt"
|
||||
DelegationTokenType = AccessTokenType
|
||||
|
||||
TokenExchangeGrantType = "urn:ietf:params:oauth:grant-type:token-exchange"
|
||||
)
|
||||
|
||||
type TokenExchangeRequest struct {
|
||||
grantType string `schema:"grant_type"`
|
||||
subjectToken string `schema:"subject_token"`
|
||||
subjectTokenType string `schema:"subject_token_type"`
|
||||
actorToken string `schema:"actor_token"`
|
||||
actorTokenType string `schema:"actor_token_type"`
|
||||
resource []string `schema:"resource"`
|
||||
audience []string `schema:"audience"`
|
||||
scope []string `schema:"scope"`
|
||||
requestedTokenType string `schema:"requested_token_type"`
|
||||
}
|
||||
|
||||
func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest {
|
||||
t := &TokenExchangeRequest{
|
||||
grantType: TokenExchangeGrantType,
|
||||
subjectToken: subjectToken,
|
||||
subjectTokenType: subjectTokenType,
|
||||
requestedTokenType: AccessTokenType,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type TokenExchangeOption func(*TokenExchangeRequest)
|
||||
|
||||
func WithActorToken(token, tokenType string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.actorToken = token
|
||||
req.actorTokenType = tokenType
|
||||
}
|
||||
}
|
||||
|
||||
func WithAudience(audience []string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.audience = audience
|
||||
}
|
||||
}
|
||||
|
||||
func WithGrantType(grantType string) TokenExchangeOption {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.grantType = grantType
|
||||
}
|
||||
}
|
||||
|
||||
func WithRequestedTokenType(tokenType string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.requestedTokenType = tokenType
|
||||
}
|
||||
}
|
||||
|
||||
func WithResource(resource []string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.resource = resource
|
||||
}
|
||||
}
|
||||
|
||||
func WithScope(scope []string) func(*TokenExchangeRequest) {
|
||||
return func(req *TokenExchangeRequest) {
|
||||
req.scope = scope
|
||||
}
|
||||
}
|
22
pkg/oidc/keyset.go
Normal file
22
pkg/oidc/keyset.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
|
||||
// of JSON web tokens. This is expected to be backed by a remote key set through
|
||||
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
|
||||
type KeySet interface {
|
||||
// VerifySignature parses the JSON web token, verifies the signature, and returns
|
||||
// the raw payload. Header and claim fields are validated by other parts of the
|
||||
// package. For example, the KeySet does not need to check values such as signature
|
||||
// algorithm, issuer, and audience since the IDTokenVerifier validates these values
|
||||
// independently.
|
||||
//
|
||||
// If VerifySignature makes HTTP requests to verify the token, it's expected to
|
||||
// use any HTTP client associated with the context through ClientContext.
|
||||
VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error)
|
||||
}
|
196
pkg/oidc/token.go
Normal file
196
pkg/oidc/token.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
type Tokens struct {
|
||||
*oauth2.Token
|
||||
IDTokenClaims *IDTokenClaims
|
||||
IDToken string
|
||||
}
|
||||
|
||||
type AccessTokenClaims struct {
|
||||
Issuer string
|
||||
Subject string
|
||||
Audiences []string
|
||||
Expiration time.Time
|
||||
IssuedAt time.Time
|
||||
NotBefore time.Time
|
||||
JWTID string
|
||||
AuthorizedParty string
|
||||
Nonce string
|
||||
AuthTime time.Time
|
||||
CodeHash string
|
||||
AuthenticationContextClassReference string
|
||||
AuthenticationMethodsReferences []string
|
||||
SessionID string
|
||||
Scopes []string
|
||||
ClientID string
|
||||
AccessTokenUseNumber int
|
||||
}
|
||||
|
||||
type IDTokenClaims struct {
|
||||
Issuer string
|
||||
Subject string
|
||||
Audiences []string
|
||||
Expiration time.Time
|
||||
NotBefore time.Time
|
||||
IssuedAt time.Time
|
||||
JWTID string
|
||||
UpdatedAt time.Time
|
||||
AuthorizedParty string
|
||||
Nonce string
|
||||
AuthTime time.Time
|
||||
AccessTokenHash string
|
||||
CodeHash string
|
||||
AuthenticationContextClassReference string
|
||||
AuthenticationMethodsReferences []string
|
||||
ClientID string
|
||||
|
||||
Signature jose.SignatureAlgorithm //TODO: ???
|
||||
}
|
||||
|
||||
type jsonToken struct {
|
||||
Issuer string `json:"iss,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
Audiences []string `json:"aud,omitempty"`
|
||||
Expiration int64 `json:"exp,omitempty"`
|
||||
NotBefore int64 `json:"nbf,omitempty"`
|
||||
IssuedAt int64 `json:"iat,omitempty"`
|
||||
JWTID string `json:"jti,omitempty"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty"`
|
||||
AuthorizedParty string `json:"azp,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
AccessTokenHash string `json:"at_hash,omitempty"`
|
||||
CodeHash string `json:"c_hash,omitempty"`
|
||||
AuthenticationContextClassReference string `json:"acr,omitempty"`
|
||||
AuthenticationMethodsReferences []string `json:"amr,omitempty"`
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
Actor interface{} `json:"act,omitempty"` //TODO: impl
|
||||
Scopes string `json:"scope,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl
|
||||
AccessTokenUseNumber int `json:"at_use_nbr,omitempty"`
|
||||
}
|
||||
|
||||
func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) {
|
||||
j := jsonToken{
|
||||
Issuer: t.Issuer,
|
||||
Subject: t.Subject,
|
||||
Audiences: t.Audiences,
|
||||
Expiration: timeToJSON(t.Expiration),
|
||||
NotBefore: timeToJSON(t.NotBefore),
|
||||
IssuedAt: timeToJSON(t.IssuedAt),
|
||||
JWTID: t.JWTID,
|
||||
AuthorizedParty: t.AuthorizedParty,
|
||||
Nonce: t.Nonce,
|
||||
AuthTime: timeToJSON(t.AuthTime),
|
||||
CodeHash: t.CodeHash,
|
||||
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
|
||||
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
||||
SessionID: t.SessionID,
|
||||
Scopes: strings.Join(t.Scopes, " "),
|
||||
ClientID: t.ClientID,
|
||||
AccessTokenUseNumber: t.AccessTokenUseNumber,
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error {
|
||||
var j jsonToken
|
||||
if err := json.Unmarshal(b, &j); err != nil {
|
||||
return err
|
||||
}
|
||||
audience := j.Audiences
|
||||
if len(audience) == 1 {
|
||||
audience = strings.Split(audience[0], " ")
|
||||
}
|
||||
t.Issuer = j.Issuer
|
||||
t.Subject = j.Subject
|
||||
t.Audiences = audience
|
||||
t.Expiration = time.Unix(j.Expiration, 0).UTC()
|
||||
t.NotBefore = time.Unix(j.NotBefore, 0).UTC()
|
||||
t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC()
|
||||
t.JWTID = j.JWTID
|
||||
t.AuthorizedParty = j.AuthorizedParty
|
||||
t.Nonce = j.Nonce
|
||||
t.AuthTime = time.Unix(j.AuthTime, 0).UTC()
|
||||
t.CodeHash = j.CodeHash
|
||||
t.AuthenticationContextClassReference = j.AuthenticationContextClassReference
|
||||
t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences
|
||||
t.SessionID = j.SessionID
|
||||
t.Scopes = strings.Split(j.Scopes, " ")
|
||||
t.ClientID = j.ClientID
|
||||
t.AccessTokenUseNumber = j.AccessTokenUseNumber
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) MarshalJSON() ([]byte, error) {
|
||||
j := jsonToken{
|
||||
Issuer: t.Issuer,
|
||||
Subject: t.Subject,
|
||||
Audiences: t.Audiences,
|
||||
Expiration: timeToJSON(t.Expiration),
|
||||
NotBefore: timeToJSON(t.NotBefore),
|
||||
IssuedAt: timeToJSON(t.IssuedAt),
|
||||
JWTID: t.JWTID,
|
||||
UpdatedAt: timeToJSON(t.UpdatedAt),
|
||||
AuthorizedParty: t.AuthorizedParty,
|
||||
Nonce: t.Nonce,
|
||||
AuthTime: timeToJSON(t.AuthTime),
|
||||
AccessTokenHash: t.AccessTokenHash,
|
||||
CodeHash: t.CodeHash,
|
||||
AuthenticationContextClassReference: t.AuthenticationContextClassReference,
|
||||
AuthenticationMethodsReferences: t.AuthenticationMethodsReferences,
|
||||
ClientID: t.ClientID,
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
func (t *IDTokenClaims) UnmarshalJSON(b []byte) error {
|
||||
var i jsonToken
|
||||
if err := json.Unmarshal(b, &i); err != nil {
|
||||
return err
|
||||
}
|
||||
audience := i.Audiences
|
||||
if len(audience) == 1 {
|
||||
audience = strings.Split(audience[0], " ")
|
||||
}
|
||||
t.Issuer = i.Issuer
|
||||
t.Subject = i.Subject
|
||||
t.Audiences = audience
|
||||
t.Expiration = time.Unix(i.Expiration, 0).UTC()
|
||||
t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC()
|
||||
t.AuthTime = time.Unix(i.AuthTime, 0).UTC()
|
||||
t.Nonce = i.Nonce
|
||||
t.AuthenticationContextClassReference = i.AuthenticationContextClassReference
|
||||
t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences
|
||||
t.AuthorizedParty = i.AuthorizedParty
|
||||
t.AccessTokenHash = i.AccessTokenHash
|
||||
t.CodeHash = i.CodeHash
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) {
|
||||
hash, err := utils.GetHashAlgorithm(sigAlgorithm)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return utils.HashString(hash, claim), nil
|
||||
}
|
||||
|
||||
func timeToJSON(t time.Time) int64 {
|
||||
if t.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
120
pkg/oidc/userinfo.go
Normal file
120
pkg/oidc/userinfo.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
type Userinfo struct {
|
||||
Subject string
|
||||
Address *UserinfoAddress
|
||||
UserinfoProfile
|
||||
UserinfoEmail
|
||||
UserinfoPhone
|
||||
|
||||
claims map[string]interface{}
|
||||
}
|
||||
|
||||
type UserinfoPhone struct {
|
||||
PhoneNumber string
|
||||
PhoneNumberVerified bool
|
||||
}
|
||||
type UserinfoProfile struct {
|
||||
Name string
|
||||
GivenName string
|
||||
FamilyName string
|
||||
MiddleName string
|
||||
Nickname string
|
||||
Profile string
|
||||
Picture string
|
||||
Website string
|
||||
Gender Gender
|
||||
Birthdate string
|
||||
Zoneinfo string
|
||||
Locale language.Tag
|
||||
UpdatedAt time.Time
|
||||
PreferredUsername string
|
||||
}
|
||||
|
||||
type Gender string
|
||||
|
||||
type UserinfoAddress struct {
|
||||
Formatted string
|
||||
StreetAddress string
|
||||
Locality string
|
||||
Region string
|
||||
PostalCode string
|
||||
Country string
|
||||
}
|
||||
|
||||
type UserinfoEmail struct {
|
||||
Email string
|
||||
EmailVerified bool
|
||||
}
|
||||
|
||||
func marshalUserinfoProfile(i UserinfoProfile, claims map[string]interface{}) {
|
||||
claims["name"] = i.Name
|
||||
claims["given_name"] = i.GivenName
|
||||
claims["family_name"] = i.FamilyName
|
||||
claims["middle_name"] = i.MiddleName
|
||||
claims["nickname"] = i.Nickname
|
||||
claims["profile"] = i.Profile
|
||||
claims["picture"] = i.Picture
|
||||
claims["website"] = i.Website
|
||||
claims["gender"] = i.Gender
|
||||
claims["birthdate"] = i.Birthdate
|
||||
claims["Zoneinfo"] = i.Zoneinfo
|
||||
claims["locale"] = i.Locale.String()
|
||||
claims["updated_at"] = i.UpdatedAt.UTC().Unix()
|
||||
claims["preferred_username"] = i.PreferredUsername
|
||||
}
|
||||
|
||||
func marshalUserinfoEmail(i UserinfoEmail, claims map[string]interface{}) {
|
||||
if i.Email != "" {
|
||||
claims["email"] = i.Email
|
||||
}
|
||||
if i.EmailVerified {
|
||||
claims["email_verified"] = i.EmailVerified
|
||||
}
|
||||
}
|
||||
|
||||
func marshalUserinfoAddress(i *UserinfoAddress, claims map[string]interface{}) {
|
||||
if i == nil {
|
||||
return
|
||||
}
|
||||
address := make(map[string]interface{})
|
||||
if i.Formatted != "" {
|
||||
address["formatted"] = i.Formatted
|
||||
}
|
||||
if i.StreetAddress != "" {
|
||||
address["street_address"] = i.StreetAddress
|
||||
}
|
||||
claims["address"] = address
|
||||
}
|
||||
|
||||
func marshalUserinfoPhone(i UserinfoPhone, claims map[string]interface{}) {
|
||||
claims["phone_number"] = i.PhoneNumber
|
||||
claims["phone_number_verified"] = i.PhoneNumberVerified
|
||||
}
|
||||
|
||||
func (i *Userinfo) MarshalJSON() ([]byte, error) {
|
||||
claims := i.claims
|
||||
if claims == nil {
|
||||
claims = make(map[string]interface{})
|
||||
}
|
||||
claims["sub"] = i.Subject
|
||||
marshalUserinfoAddress(i.Address, claims)
|
||||
marshalUserinfoEmail(i.UserinfoEmail, claims)
|
||||
marshalUserinfoPhone(i.UserinfoPhone, claims)
|
||||
marshalUserinfoProfile(i.UserinfoProfile, claims)
|
||||
return json.Marshal(claims)
|
||||
}
|
||||
|
||||
func (i *Userinfo) UnmmarshalJSON(data []byte) error {
|
||||
if err := json.Unmarshal(data, i); err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, i.claims)
|
||||
}
|
198
pkg/op/authrequest.go
Normal file
198
pkg/op/authrequest.go
Normal file
|
@ -0,0 +1,198 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type Authorizer interface {
|
||||
Storage() Storage
|
||||
Decoder() *schema.Decoder
|
||||
Encoder() *schema.Encoder
|
||||
Signer() Signer
|
||||
Crypto() Crypto
|
||||
Issuer() string
|
||||
}
|
||||
|
||||
type ValidationAuthorizer interface {
|
||||
Authorizer
|
||||
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error
|
||||
}
|
||||
|
||||
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
authReq := new(oidc.AuthRequest)
|
||||
err = authorizer.Decoder().Decode(authReq, r.Form)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
validation := ValidateAuthRequest
|
||||
if validater, ok := authorizer.(ValidationAuthorizer); ok {
|
||||
validation = validater.ValidateAuthRequest
|
||||
}
|
||||
if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, req, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
RedirectToLogin(req.GetID(), client, w, r)
|
||||
}
|
||||
|
||||
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error {
|
||||
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil {
|
||||
return err
|
||||
}
|
||||
// if NeedsExistingSession(authReq) {
|
||||
// session, err := storage.CheckSession(authReq.IDTokenHint)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateAuthReqScopes(scopes []string) error {
|
||||
if len(scopes) == 0 {
|
||||
return ErrInvalidRequest("scope missing")
|
||||
}
|
||||
if !utils.Contains(scopes, oidc.ScopeOpenID) {
|
||||
return ErrInvalidRequest("scope openid missing")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error {
|
||||
if uri == "" {
|
||||
return ErrInvalidRequestRedirectURI("redirect_uri must not be empty")
|
||||
}
|
||||
client, err := storage.GetClientByClientID(ctx, client_id)
|
||||
if err != nil {
|
||||
return ErrServerError(err.Error())
|
||||
}
|
||||
if !utils.Contains(client.RedirectURIs(), uri) {
|
||||
return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
|
||||
}
|
||||
if strings.HasPrefix(uri, "https://") {
|
||||
return nil
|
||||
}
|
||||
if responseType == oidc.ResponseTypeCode {
|
||||
if strings.HasPrefix(uri, "http://") && IsConfidentialType(client) {
|
||||
return nil
|
||||
}
|
||||
if client.ApplicationType() == ApplicationTypeNative {
|
||||
return nil
|
||||
}
|
||||
return ErrInvalidRequest("redirect_uri not allowed")
|
||||
} else {
|
||||
if client.ApplicationType() != ApplicationTypeNative {
|
||||
return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
|
||||
}
|
||||
if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) {
|
||||
return ErrInvalidRequestRedirectURI("redirect_uri not allowed")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateAuthReqResponseType(responseType oidc.ResponseType) error {
|
||||
if responseType == "" {
|
||||
return ErrInvalidRequest("response_type empty")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) {
|
||||
login := client.LoginURL(authReqID)
|
||||
http.Redirect(w, r, login, http.StatusFound)
|
||||
}
|
||||
|
||||
func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
|
||||
params := mux.Vars(r)
|
||||
id := params["id"]
|
||||
|
||||
authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id)
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, nil, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
if !authReq.Done() {
|
||||
AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
AuthResponse(authReq, authorizer, w, r)
|
||||
}
|
||||
|
||||
func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) {
|
||||
client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
if authReq.GetResponseType() == oidc.ResponseTypeCode {
|
||||
AuthResponseCode(w, r, authReq, authorizer)
|
||||
return
|
||||
}
|
||||
AuthResponseToken(w, r, authReq, authorizer, client)
|
||||
return
|
||||
}
|
||||
|
||||
func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) {
|
||||
code, err := BuildAuthRequestCode(authReq, authorizer.Crypto())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code)
|
||||
if authReq.GetState() != "" {
|
||||
callback = callback + "&state=" + authReq.GetState()
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
||||
func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) {
|
||||
createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly
|
||||
resp, err := CreateTokenResponse(authReq, client, authorizer, createAccessToken, "")
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
params, err := utils.URLEncodeResponse(resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
|
||||
return
|
||||
}
|
||||
callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params)
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
||||
func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) {
|
||||
return crypto.Encrypt(authReq.GetID())
|
||||
}
|
296
pkg/op/authrequest_test.go
Normal file
296
pkg/op/authrequest_test.go
Normal file
|
@ -0,0 +1,296 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
"github.com/caos/oidc/pkg/op/mock"
|
||||
)
|
||||
|
||||
func TestAuthorize(t *testing.T) {
|
||||
// testCallback := func(t *testing.T, clienID string) callbackHandler {
|
||||
// return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) {
|
||||
// // require.Equal(t, clientID, client.)
|
||||
// }
|
||||
// }
|
||||
// testErr := func(t *testing.T, expected error) errorHandler {
|
||||
// return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
|
||||
// require.Equal(t, expected, err)
|
||||
// }
|
||||
// }
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
authorizer op.Authorizer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
{
|
||||
"parsing fails",
|
||||
args{
|
||||
httptest.NewRecorder(),
|
||||
&http.Request{Method: "POST", Body: nil},
|
||||
mock.NewAuthorizerExpectValid(t, true),
|
||||
// testCallback(t, ""),
|
||||
// testErr(t, ErrInvalidRequest("cannot parse form")),
|
||||
},
|
||||
},
|
||||
{
|
||||
"decoding fails",
|
||||
args{
|
||||
httptest.NewRecorder(),
|
||||
func() *http.Request {
|
||||
r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo"))
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return r
|
||||
}(),
|
||||
mock.NewAuthorizerExpectValid(t, true),
|
||||
// testCallback(t, ""),
|
||||
// testErr(t, ErrInvalidRequest("cannot parse auth request")),
|
||||
},
|
||||
},
|
||||
// {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthRequest(t *testing.T) {
|
||||
type args struct {
|
||||
authRequest *oidc.AuthRequest
|
||||
storage op.Storage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
//TODO:
|
||||
// {
|
||||
// "oauth2 spec"
|
||||
// }
|
||||
{
|
||||
"scope missing fails",
|
||||
args{&oidc.AuthRequest{}, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"scope openid missing fails",
|
||||
args{&oidc.AuthRequest{Scopes: []string{"profile"}}, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"response_type missing fails",
|
||||
args{&oidc.AuthRequest{Scopes: []string{"openid"}}, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"client_id missing fails",
|
||||
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"redirect_uri missing fails",
|
||||
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, nil},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthReqScopes(t *testing.T) {
|
||||
type args struct {
|
||||
scopes []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"scopes missing fails", args{}, true,
|
||||
},
|
||||
{
|
||||
"scope openid missing fails", args{[]string{"email"}}, true,
|
||||
},
|
||||
{
|
||||
"scope ok", args{[]string{"openid"}}, false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAuthReqRedirectURI(t *testing.T) {
|
||||
type args struct {
|
||||
uri string
|
||||
clientID string
|
||||
responseType oidc.ResponseType
|
||||
storage op.OPStorage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"empty fails",
|
||||
args{"", "", oidc.ResponseTypeCode, nil},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"unregistered fails",
|
||||
args{"https://unregistered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"storage error fails",
|
||||
args{"https://registered.com/callback", "non_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectInvalidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"code flow registered http not confidential fails",
|
||||
args{"http://registered.com/callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"code flow registered http confidential ok",
|
||||
args{"http://registered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"code flow registered custom not native fails",
|
||||
args{"custom://callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"code flow registered custom native ok",
|
||||
args{"http://registered.com/callback", "native_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"implicit flow registered ok",
|
||||
args{"https://registered.com/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"implicit flow registered http localhost native ok",
|
||||
args{"http://localhost:9999/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"implicit flow registered http localhost user agent fails",
|
||||
args{"http://localhost:9999/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"implicit flow http non localhost fails",
|
||||
args{"http://registered.com/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"implicit flow custom fails",
|
||||
args{"custom://callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := op.ValidateAuthReqRedirectURI(nil, tt.args.uri, tt.args.clientID, tt.args.responseType, tt.args.storage); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateRedirectURI() error = %v, wantErr %v", err.Error(), tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectToLogin(t *testing.T) {
|
||||
type args struct {
|
||||
authReqID string
|
||||
client op.Client
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
{
|
||||
"redirect ok",
|
||||
args{
|
||||
"id",
|
||||
mock.NewClientExpectAny(t, op.ApplicationTypeNative),
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("GET", "/authorize", nil),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.RedirectToLogin(tt.args.authReqID, tt.args.client, tt.args.w, tt.args.r)
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
require.Equal(t, http.StatusFound, rec.Code)
|
||||
require.Equal(t, "/login?id=id", rec.Header().Get("location"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeCallback(t *testing.T) {
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
authorizer op.Authorizer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthResponse(t *testing.T) {
|
||||
type args struct {
|
||||
authReq op.AuthRequest
|
||||
authorizer op.Authorizer
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r)
|
||||
})
|
||||
}
|
||||
}
|
33
pkg/op/client.go
Normal file
33
pkg/op/client.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package op
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
ApplicationTypeWeb ApplicationType = iota
|
||||
ApplicationTypeUserAgent
|
||||
ApplicationTypeNative
|
||||
|
||||
AccessTokenTypeBearer AccessTokenType = iota
|
||||
AccessTokenTypeJWT
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
GetID() string
|
||||
RedirectURIs() []string
|
||||
ApplicationType() ApplicationType
|
||||
GetAuthMethod() AuthMethod
|
||||
LoginURL(string) string
|
||||
AccessTokenType() AccessTokenType
|
||||
AccessTokenLifetime() time.Duration
|
||||
IDTokenLifetime() time.Duration
|
||||
}
|
||||
|
||||
func IsConfidentialType(c Client) bool {
|
||||
return c.ApplicationType() == ApplicationTypeWeb
|
||||
}
|
||||
|
||||
type ApplicationType int
|
||||
|
||||
type AuthMethod string
|
||||
|
||||
type AccessTokenType int
|
54
pkg/op/config.go
Normal file
54
pkg/op/config.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Configuration interface {
|
||||
Issuer() string
|
||||
AuthorizationEndpoint() Endpoint
|
||||
TokenEndpoint() Endpoint
|
||||
UserinfoEndpoint() Endpoint
|
||||
KeysEndpoint() Endpoint
|
||||
|
||||
AuthMethodPostSupported() bool
|
||||
|
||||
Port() string
|
||||
}
|
||||
|
||||
func ValidateIssuer(issuer string) error {
|
||||
if issuer == "" {
|
||||
return errors.New("missing issuer")
|
||||
}
|
||||
u, err := url.Parse(issuer)
|
||||
if err != nil {
|
||||
return errors.New("invalid url for issuer")
|
||||
}
|
||||
if u.Host == "" {
|
||||
return errors.New("host for issuer missing")
|
||||
}
|
||||
if u.Scheme != "https" {
|
||||
if !devLocalAllowed(u) {
|
||||
return errors.New("scheme for issuer must be `https`")
|
||||
}
|
||||
}
|
||||
if u.Fragment != "" || len(u.Query()) > 0 {
|
||||
return errors.New("no fragments or query allowed for issuer")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func devLocalAllowed(url *url.URL) bool {
|
||||
_, b := os.LookupEnv("CAOS_OIDC_DEV")
|
||||
if !b {
|
||||
return b
|
||||
}
|
||||
return url.Scheme == "http" &&
|
||||
url.Host == "localhost" ||
|
||||
url.Host == "127.0.0.1" ||
|
||||
url.Host == "::1" ||
|
||||
strings.HasPrefix(url.Host, "localhost:")
|
||||
}
|
94
pkg/op/config_test.go
Normal file
94
pkg/op/config_test.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package op
|
||||
|
||||
import "testing"
|
||||
|
||||
import "os"
|
||||
|
||||
func TestValidateIssuer(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"missing issuer fails",
|
||||
args{""},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid url for issuer fails",
|
||||
args{":issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid url for issuer fails",
|
||||
args{":issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host for issuer missing fails",
|
||||
args{"https:///issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host for not https fails",
|
||||
args{"http://issuer.com"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with fragment fails",
|
||||
args{"https://issuer.com/#issuer"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with query fails",
|
||||
args{"https://issuer.com?issuer=me"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"host with https ok",
|
||||
args{"https://issuer.com"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"localhost with http ok",
|
||||
args{"http://localhost:9999"},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIssuerDevLocalAllowed(t *testing.T) {
|
||||
type args struct {
|
||||
issuer string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"localhost with http ok",
|
||||
args{"http://localhost:9999"},
|
||||
false,
|
||||
},
|
||||
}
|
||||
os.Setenv("CAOS_OIDC_DEV", "")
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
26
pkg/op/crypto.go
Normal file
26
pkg/op/crypto.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type Crypto interface {
|
||||
Encrypt(string) (string, error)
|
||||
Decrypt(string) (string, error)
|
||||
}
|
||||
|
||||
type aesCrypto struct {
|
||||
key string
|
||||
}
|
||||
|
||||
func NewAESCrypto(key [32]byte) Crypto {
|
||||
return &aesCrypto{key: string(key[:32])}
|
||||
}
|
||||
|
||||
func (c *aesCrypto) Encrypt(s string) (string, error) {
|
||||
return utils.EncryptAES(s, c.key)
|
||||
}
|
||||
|
||||
func (c *aesCrypto) Decrypt(s string) (string, error) {
|
||||
return utils.DecryptAES(s, c.key)
|
||||
}
|
224
pkg/op/default_op.go
Normal file
224
pkg/op/default_op.go
Normal file
|
@ -0,0 +1,224 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAuthorizationEndpoint = "authorize"
|
||||
defaulTokenEndpoint = "oauth/token"
|
||||
defaultIntrospectEndpoint = "introspect"
|
||||
defaultUserinfoEndpoint = "userinfo"
|
||||
defaultKeysEndpoint = "keys"
|
||||
|
||||
AuthMethodBasic AuthMethod = "client_secret_basic"
|
||||
AuthMethodPost = "client_secret_post"
|
||||
AuthMethodNone = "none"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultEndpoints = &endpoints{
|
||||
Authorization: defaultAuthorizationEndpoint,
|
||||
Token: defaulTokenEndpoint,
|
||||
IntrospectionEndpoint: defaultIntrospectEndpoint,
|
||||
Userinfo: defaultUserinfoEndpoint,
|
||||
JwksURI: defaultKeysEndpoint,
|
||||
}
|
||||
)
|
||||
|
||||
type DefaultOP struct {
|
||||
config *Config
|
||||
endpoints *endpoints
|
||||
discoveryConfig *oidc.DiscoveryConfiguration
|
||||
storage Storage
|
||||
signer Signer
|
||||
crypto Crypto
|
||||
http *http.Server
|
||||
decoder *schema.Decoder
|
||||
encoder *schema.Encoder
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Issuer string
|
||||
CryptoKey [32]byte
|
||||
// ScopesSupported: oidc.SupportedScopes,
|
||||
// ResponseTypesSupported: responseTypes,
|
||||
// GrantTypesSupported: oidc.SupportedGrantTypes,
|
||||
// ClaimsSupported: oidc.SupportedClaims,
|
||||
// IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm},
|
||||
// SubjectTypesSupported: []string{"public"},
|
||||
// TokenEndpointAuthMethodsSupported:
|
||||
Port string
|
||||
}
|
||||
|
||||
type endpoints struct {
|
||||
Authorization Endpoint
|
||||
Token Endpoint
|
||||
IntrospectionEndpoint Endpoint
|
||||
Userinfo Endpoint
|
||||
EndSessionEndpoint Endpoint
|
||||
CheckSessionIframe Endpoint
|
||||
JwksURI Endpoint
|
||||
}
|
||||
|
||||
type DefaultOPOpts func(o *DefaultOP) error
|
||||
|
||||
func WithCustomAuthEndpoint(endpoint Endpoint) DefaultOPOpts {
|
||||
return func(o *DefaultOP) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
o.endpoints.Authorization = endpoint
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomTokenEndpoint(endpoint Endpoint) DefaultOPOpts {
|
||||
return func(o *DefaultOP) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
o.endpoints.Token = endpoint
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts {
|
||||
return func(o *DefaultOP) error {
|
||||
if err := endpoint.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
o.endpoints.Userinfo = endpoint
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) {
|
||||
err := ValidateIssuer(config.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &DefaultOP{
|
||||
config: config,
|
||||
storage: storage,
|
||||
endpoints: DefaultEndpoints,
|
||||
}
|
||||
|
||||
p.signer, err = NewDefaultSigner(ctx, storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, optFunc := range opOpts {
|
||||
if err := optFunc(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
p.discoveryConfig = CreateDiscoveryConfig(p, p.signer)
|
||||
|
||||
router := CreateRouter(p)
|
||||
p.http = &http.Server{
|
||||
Addr: ":" + config.Port,
|
||||
Handler: router,
|
||||
}
|
||||
p.decoder = schema.NewDecoder()
|
||||
p.decoder.IgnoreUnknownKeys(true)
|
||||
|
||||
p.encoder = schema.NewEncoder()
|
||||
|
||||
p.crypto = NewAESCrypto(config.CryptoKey)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Issuer() string {
|
||||
return p.config.Issuer
|
||||
}
|
||||
|
||||
func (p *DefaultOP) AuthorizationEndpoint() Endpoint {
|
||||
return p.endpoints.Authorization
|
||||
}
|
||||
|
||||
func (p *DefaultOP) TokenEndpoint() Endpoint {
|
||||
return Endpoint(p.endpoints.Token)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) UserinfoEndpoint() Endpoint {
|
||||
return Endpoint(p.endpoints.Userinfo)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) KeysEndpoint() Endpoint {
|
||||
return Endpoint(p.endpoints.JwksURI)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) AuthMethodPostSupported() bool {
|
||||
return true //TODO: config
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Port() string {
|
||||
return p.config.Port
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HttpHandler() *http.Server {
|
||||
return p.http
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||
Discover(w, p.discoveryConfig)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Decoder() *schema.Decoder {
|
||||
return p.decoder
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Encoder() *schema.Encoder {
|
||||
return p.encoder
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Storage() Storage {
|
||||
return p.storage
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Signer() Signer {
|
||||
return p.signer
|
||||
}
|
||||
|
||||
func (p *DefaultOP) Crypto() Crypto {
|
||||
return p.crypto
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleKeys(w http.ResponseWriter, r *http.Request) {
|
||||
Keys(w, r, p)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||
Authorize(w, r, p)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) {
|
||||
AuthorizeCallback(w, r, p)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) {
|
||||
reqType := r.FormValue("grant_type")
|
||||
if reqType == "" {
|
||||
ExchangeRequestError(w, r, ErrInvalidRequest("grant_type missing"))
|
||||
return
|
||||
}
|
||||
if reqType == string(oidc.GrantTypeCode) {
|
||||
CodeExchange(w, r, p)
|
||||
return
|
||||
}
|
||||
TokenExchange(w, r, p)
|
||||
}
|
||||
|
||||
func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) {
|
||||
Userinfo(w, r, p)
|
||||
}
|
49
pkg/op/default_op_test.go
Normal file
49
pkg/op/default_op_test.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
func TestDefaultOP_HandleDiscovery(t *testing.T) {
|
||||
type fields struct {
|
||||
config *Config
|
||||
endpoints *endpoints
|
||||
discoveryConfig *oidc.DiscoveryConfiguration
|
||||
storage Storage
|
||||
http *http.Server
|
||||
}
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want string
|
||||
wantCode int
|
||||
}{
|
||||
{"OK", fields{config: nil, endpoints: nil, discoveryConfig: &oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"}}, args{httptest.NewRecorder(), nil}, `{"issuer":"https://issuer.com"}`, 200},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &DefaultOP{
|
||||
config: tt.fields.config,
|
||||
endpoints: tt.fields.endpoints,
|
||||
discoveryConfig: tt.fields.discoveryConfig,
|
||||
storage: tt.fields.storage,
|
||||
http: tt.fields.http,
|
||||
}
|
||||
p.HandleDiscovery(tt.args.w, tt.args.r)
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
require.Equal(t, tt.want, rec.Body.String())
|
||||
require.Equal(t, tt.wantCode, rec.Code)
|
||||
})
|
||||
}
|
||||
}
|
119
pkg/op/discovery.go
Normal file
119
pkg/op/discovery.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) {
|
||||
utils.MarshalJSON(w, config)
|
||||
}
|
||||
|
||||
func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration {
|
||||
return &oidc.DiscoveryConfiguration{
|
||||
Issuer: c.Issuer(),
|
||||
AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()),
|
||||
TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()),
|
||||
// IntrospectionEndpoint: c.Intro().Absolute(c.Issuer()),
|
||||
UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()),
|
||||
// EndSessionEndpoint: c.TokenEndpoint().Absolute(c.Issuer())(c.EndSessionEndpoint),
|
||||
// CheckSessionIframe: c.TokenEndpoint().Absolute(c.Issuer())(c.CheckSessionIframe),
|
||||
JwksURI: c.KeysEndpoint().Absolute(c.Issuer()),
|
||||
ScopesSupported: Scopes(c),
|
||||
ResponseTypesSupported: ResponseTypes(c),
|
||||
GrantTypesSupported: GrantTypes(c),
|
||||
ClaimsSupported: SupportedClaims(c),
|
||||
IDTokenSigningAlgValuesSupported: SigAlgorithms(s),
|
||||
SubjectTypesSupported: SubjectTypes(c),
|
||||
TokenEndpointAuthMethodsSupported: AuthMethods(c),
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
ScopeOpenID = "openid"
|
||||
ScopeProfile = "profile"
|
||||
ScopeEmail = "email"
|
||||
ScopePhone = "phone"
|
||||
ScopeAddress = "address"
|
||||
)
|
||||
|
||||
var DefaultSupportedScopes = []string{
|
||||
ScopeOpenID,
|
||||
ScopeProfile,
|
||||
ScopeEmail,
|
||||
ScopePhone,
|
||||
ScopeAddress,
|
||||
}
|
||||
|
||||
func Scopes(c Configuration) []string {
|
||||
return DefaultSupportedScopes //TODO: config
|
||||
}
|
||||
|
||||
func ResponseTypes(c Configuration) []string {
|
||||
return []string{
|
||||
"code",
|
||||
"id_token",
|
||||
// "code token",
|
||||
// "code id_token",
|
||||
"id_token token",
|
||||
// "code id_token token"
|
||||
}
|
||||
}
|
||||
|
||||
func GrantTypes(c Configuration) []string {
|
||||
return []string{
|
||||
"client_credentials",
|
||||
"authorization_code",
|
||||
// "password",
|
||||
"urn:ietf:params:oauth:grant-type:token-exchange",
|
||||
}
|
||||
}
|
||||
|
||||
func SupportedClaims(c Configuration) []string {
|
||||
return []string{ //TODO: config
|
||||
"sub",
|
||||
"aud",
|
||||
"exp",
|
||||
"iat",
|
||||
"iss",
|
||||
"auth_time",
|
||||
"nonce",
|
||||
"acr",
|
||||
"amr",
|
||||
"c_hash",
|
||||
"at_hash",
|
||||
"act",
|
||||
"scopes",
|
||||
"client_id",
|
||||
"azp",
|
||||
"preferred_username",
|
||||
"name",
|
||||
"family_name",
|
||||
"given_name",
|
||||
"locale",
|
||||
"email",
|
||||
"email_verified",
|
||||
"phone_number",
|
||||
"phone_number_verified",
|
||||
}
|
||||
}
|
||||
|
||||
func SigAlgorithms(s Signer) []string {
|
||||
return []string{string(s.SignatureAlgorithm())}
|
||||
}
|
||||
|
||||
func SubjectTypes(c Configuration) []string {
|
||||
return []string{"public"} //TODO: config
|
||||
}
|
||||
|
||||
func AuthMethods(c Configuration) []string {
|
||||
authMethods := []string{
|
||||
string(AuthMethodBasic),
|
||||
}
|
||||
if c.AuthMethodPostSupported() {
|
||||
authMethods = append(authMethods, string(AuthMethodPost))
|
||||
}
|
||||
return authMethods
|
||||
}
|
235
pkg/op/discovery_test.go
Normal file
235
pkg/op/discovery_test.go
Normal file
|
@ -0,0 +1,235 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
"github.com/caos/oidc/pkg/op/mock"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func TestDiscover(t *testing.T) {
|
||||
type args struct {
|
||||
w http.ResponseWriter
|
||||
config *oidc.DiscoveryConfiguration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
}{
|
||||
{
|
||||
"OK",
|
||||
args{
|
||||
httptest.NewRecorder(),
|
||||
&oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
op.Discover(tt.args.w, tt.args.config)
|
||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDiscoveryConfig(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
s op.Signer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *oidc.DiscoveryConfiguration
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_scopes(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"default Scopes",
|
||||
args{},
|
||||
op.DefaultSupportedScopes,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.Scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("scopes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ResponseTypes(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.ResponseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("responseTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GrantTypes(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.GrantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("grantTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedClaims(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.SupportedClaims(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SupportedClaims() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SigAlgorithms(t *testing.T) {
|
||||
m := mock.NewMockSigner(gomock.NewController((t)))
|
||||
type args struct {
|
||||
s op.Signer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"",
|
||||
args{func() op.Signer {
|
||||
m.EXPECT().SignatureAlgorithm().Return(jose.RS256)
|
||||
return m
|
||||
}()},
|
||||
[]string{"RS256"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.SigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SubjectTypes(t *testing.T) {
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"none",
|
||||
args{},
|
||||
[]string{"public"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.SubjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("subjectTypes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AuthMethods(t *testing.T) {
|
||||
m := mock.NewMockConfiguration(gomock.NewController((t)))
|
||||
type args struct {
|
||||
c op.Configuration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"imlicit basic",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().AuthMethodPostSupported().Return(false)
|
||||
return m
|
||||
}()},
|
||||
[]string{string(op.AuthMethodBasic)},
|
||||
},
|
||||
{
|
||||
"basic and post",
|
||||
args{func() op.Configuration {
|
||||
m.EXPECT().AuthMethodPostSupported().Return(true)
|
||||
return m
|
||||
}()},
|
||||
[]string{string(op.AuthMethodBasic), string(op.AuthMethodPost)},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := op.AuthMethods(tt.args.c); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("authMethods() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
25
pkg/op/endpoint.go
Normal file
25
pkg/op/endpoint.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package op
|
||||
|
||||
import "strings"
|
||||
|
||||
type Endpoint string
|
||||
|
||||
func (e Endpoint) Relative() string {
|
||||
return relativeEndpoint(string(e))
|
||||
}
|
||||
|
||||
func (e Endpoint) Absolute(host string) string {
|
||||
return absoluteEndpoint(host, string(e))
|
||||
}
|
||||
|
||||
func (e Endpoint) Validate() error {
|
||||
return nil //TODO:
|
||||
}
|
||||
|
||||
func absoluteEndpoint(host, endpoint string) string {
|
||||
return strings.TrimSuffix(host, "/") + relativeEndpoint(endpoint)
|
||||
}
|
||||
|
||||
func relativeEndpoint(endpoint string) string {
|
||||
return "/" + strings.TrimPrefix(endpoint, "/")
|
||||
}
|
95
pkg/op/endpoint_test.go
Normal file
95
pkg/op/endpoint_test.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package op_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func TestEndpoint_Relative(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"without starting /",
|
||||
op.Endpoint("test"),
|
||||
"/test",
|
||||
},
|
||||
{
|
||||
"with starting /",
|
||||
op.Endpoint("/test"),
|
||||
"/test",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.e.Relative(); got != tt.want {
|
||||
t.Errorf("Endpoint.Relative() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndpoint_Absolute(t *testing.T) {
|
||||
type args struct {
|
||||
host string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"no /",
|
||||
op.Endpoint("test"),
|
||||
args{"https://host"},
|
||||
"https://host/test",
|
||||
},
|
||||
{
|
||||
"endpoint without /",
|
||||
op.Endpoint("test"),
|
||||
args{"https://host/"},
|
||||
"https://host/test",
|
||||
},
|
||||
{
|
||||
"host without /",
|
||||
op.Endpoint("/test"),
|
||||
args{"https://host"},
|
||||
"https://host/test",
|
||||
},
|
||||
{
|
||||
"both /",
|
||||
op.Endpoint("/test"),
|
||||
args{"https://host/"},
|
||||
"https://host/test",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.e.Absolute(tt.args.host); got != tt.want {
|
||||
t.Errorf("Endpoint.Absolute() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: impl test
|
||||
func TestEndpoint_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
e op.Endpoint
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.e.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
99
pkg/op/error.go
Normal file
99
pkg/op/error.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
InvalidRequest errorType = "invalid_request"
|
||||
ServerError errorType = "server_error"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidRequest = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: InvalidRequest,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
ErrInvalidRequestRedirectURI = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: InvalidRequest,
|
||||
Description: description,
|
||||
redirectDisabled: true,
|
||||
}
|
||||
}
|
||||
ErrServerError = func(description string) *OAuthError {
|
||||
return &OAuthError{
|
||||
ErrorType: ServerError,
|
||||
Description: description,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
type errorType string
|
||||
|
||||
type ErrAuthRequest interface {
|
||||
GetRedirectURI() string
|
||||
GetResponseType() oidc.ResponseType
|
||||
GetState() string
|
||||
}
|
||||
|
||||
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) {
|
||||
if authReq == nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
e, ok := err.(*OAuthError)
|
||||
if !ok {
|
||||
e = new(OAuthError)
|
||||
e.ErrorType = ServerError
|
||||
e.Description = err.Error()
|
||||
}
|
||||
e.state = authReq.GetState()
|
||||
if authReq.GetRedirectURI() == "" || e.redirectDisabled {
|
||||
http.Error(w, e.Description, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
params, err := utils.URLEncodeResponse(e, encoder)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
url := authReq.GetRedirectURI()
|
||||
responseType := authReq.GetResponseType()
|
||||
if responseType == "" || responseType == oidc.ResponseTypeCode {
|
||||
url += "?" + params
|
||||
} else {
|
||||
url += "#" + params
|
||||
}
|
||||
http.Redirect(w, r, url, http.StatusFound)
|
||||
}
|
||||
|
||||
func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
e, ok := err.(*OAuthError)
|
||||
if !ok {
|
||||
e = new(OAuthError)
|
||||
e.ErrorType = ServerError
|
||||
e.Description = err.Error()
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
utils.MarshalJSON(w, e)
|
||||
}
|
||||
|
||||
type OAuthError struct {
|
||||
ErrorType errorType `json:"error" schema:"error"`
|
||||
Description string `json:"description" schema:"description"`
|
||||
state string `json:"state" schema:"state"`
|
||||
redirectDisabled bool
|
||||
}
|
||||
|
||||
func (e *OAuthError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", e.ErrorType, e.Description)
|
||||
}
|
19
pkg/op/keys.go
Normal file
19
pkg/op/keys.go
Normal file
|
@ -0,0 +1,19 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type KeyProvider interface {
|
||||
Storage() Storage
|
||||
}
|
||||
|
||||
func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) {
|
||||
keySet, err := k.Storage().GetKeySet(r.Context())
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
utils.MarshalJSON(w, keySet)
|
||||
}
|
119
pkg/op/mock/authorizer.mock.go
Normal file
119
pkg/op/mock/authorizer.mock.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Authorizer)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
schema "github.com/gorilla/schema"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockAuthorizer is a mock of Authorizer interface
|
||||
type MockAuthorizer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockAuthorizerMockRecorder
|
||||
}
|
||||
|
||||
// MockAuthorizerMockRecorder is the mock recorder for MockAuthorizer
|
||||
type MockAuthorizerMockRecorder struct {
|
||||
mock *MockAuthorizer
|
||||
}
|
||||
|
||||
// NewMockAuthorizer creates a new mock instance
|
||||
func NewMockAuthorizer(ctrl *gomock.Controller) *MockAuthorizer {
|
||||
mock := &MockAuthorizer{ctrl: ctrl}
|
||||
mock.recorder = &MockAuthorizerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockAuthorizer) EXPECT() *MockAuthorizerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Crypto mocks base method
|
||||
func (m *MockAuthorizer) Crypto() op.Crypto {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Crypto")
|
||||
ret0, _ := ret[0].(op.Crypto)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Crypto indicates an expected call of Crypto
|
||||
func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Crypto", reflect.TypeOf((*MockAuthorizer)(nil).Crypto))
|
||||
}
|
||||
|
||||
// Decoder mocks base method
|
||||
func (m *MockAuthorizer) Decoder() *schema.Decoder {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Decoder")
|
||||
ret0, _ := ret[0].(*schema.Decoder)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Decoder indicates an expected call of Decoder
|
||||
func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decoder", reflect.TypeOf((*MockAuthorizer)(nil).Decoder))
|
||||
}
|
||||
|
||||
// Encoder mocks base method
|
||||
func (m *MockAuthorizer) Encoder() *schema.Encoder {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Encoder")
|
||||
ret0, _ := ret[0].(*schema.Encoder)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Encoder indicates an expected call of Encoder
|
||||
func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder))
|
||||
}
|
||||
|
||||
// Issuer mocks base method
|
||||
func (m *MockAuthorizer) Issuer() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Issuer")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Issuer indicates an expected call of Issuer
|
||||
func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer))
|
||||
}
|
||||
|
||||
// Signer mocks base method
|
||||
func (m *MockAuthorizer) Signer() op.Signer {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Signer")
|
||||
ret0, _ := ret[0].(op.Signer)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Signer indicates an expected call of Signer
|
||||
func (mr *MockAuthorizerMockRecorder) Signer() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockAuthorizer)(nil).Signer))
|
||||
}
|
||||
|
||||
// Storage mocks base method
|
||||
func (m *MockAuthorizer) Storage() op.Storage {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Storage")
|
||||
ret0, _ := ret[0].(op.Storage)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Storage indicates an expected call of Storage
|
||||
func (mr *MockAuthorizerMockRecorder) Storage() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Storage", reflect.TypeOf((*MockAuthorizer)(nil).Storage))
|
||||
}
|
89
pkg/op/mock/authorizer.mock.impl.go
Normal file
89
pkg/op/mock/authorizer.mock.impl.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/schema"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
oidc "github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func NewAuthorizer(t *testing.T) op.Authorizer {
|
||||
return NewMockAuthorizer(gomock.NewController(t))
|
||||
}
|
||||
|
||||
func NewAuthorizerExpectValid(t *testing.T, wantErr bool) op.Authorizer {
|
||||
m := NewAuthorizer(t)
|
||||
ExpectDecoder(m)
|
||||
ExpectEncoder(m)
|
||||
ExpectSigner(m, t)
|
||||
ExpectStorage(m, t)
|
||||
// ExpectErrorHandler(m, t, wantErr)
|
||||
return m
|
||||
}
|
||||
|
||||
// func NewAuthorizerExpectDecoderFails(t *testing.T) op.Authorizer {
|
||||
// m := NewAuthorizer(t)
|
||||
// ExpectDecoderFails(m)
|
||||
// ExpectEncoder(m)
|
||||
// ExpectSigner(m, t)
|
||||
// ExpectStorage(m, t)
|
||||
// ExpectErrorHandler(m, t)
|
||||
// return m
|
||||
// }
|
||||
|
||||
func ExpectDecoder(a op.Authorizer) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Decoder().AnyTimes().Return(schema.NewDecoder())
|
||||
}
|
||||
|
||||
func ExpectEncoder(a op.Authorizer) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Encoder().AnyTimes().Return(schema.NewEncoder())
|
||||
}
|
||||
|
||||
func ExpectSigner(a op.Authorizer, t *testing.T) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Signer().DoAndReturn(
|
||||
func() op.Signer {
|
||||
return &Sig{}
|
||||
})
|
||||
}
|
||||
|
||||
// func ExpectErrorHandler(a op.Authorizer, t *testing.T, wantErr bool) {
|
||||
// mockA := a.(*MockAuthorizer)
|
||||
// mockA.EXPECT().ErrorHandler().AnyTimes().
|
||||
// Return(func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
|
||||
// if wantErr {
|
||||
// require.Error(t, err)
|
||||
// return
|
||||
// }
|
||||
// require.NoError(t, err)
|
||||
// })
|
||||
// }
|
||||
|
||||
type Sig struct{}
|
||||
|
||||
func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return jose.HS256
|
||||
}
|
||||
|
||||
func ExpectStorage(a op.Authorizer, t *testing.T) {
|
||||
mockA := a.(*MockAuthorizer)
|
||||
mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t))
|
||||
}
|
||||
|
||||
// func NewMockSignerAny(t *testing.T) op.Signer {
|
||||
// m := NewMockSigner(gomock.NewController(t))
|
||||
// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil)
|
||||
// return m
|
||||
// }
|
29
pkg/op/mock/client.go
Normal file
29
pkg/op/mock/client.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func NewClient(t *testing.T) op.Client {
|
||||
return NewMockClient(gomock.NewController(t))
|
||||
}
|
||||
|
||||
func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client {
|
||||
c := NewClient(t)
|
||||
m := c.(*MockClient)
|
||||
m.EXPECT().RedirectURIs().AnyTimes().Return([]string{
|
||||
"https://registered.com/callback",
|
||||
"http://registered.com/callback",
|
||||
"http://localhost:9999/callback",
|
||||
"custom://callback"})
|
||||
m.EXPECT().ApplicationType().AnyTimes().Return(appType)
|
||||
m.EXPECT().LoginURL(gomock.Any()).AnyTimes().DoAndReturn(
|
||||
func(id string) string {
|
||||
return "login?id=" + id
|
||||
})
|
||||
return c
|
||||
}
|
147
pkg/op/mock/client.mock.go
Normal file
147
pkg/op/mock/client.mock.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Client)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
)
|
||||
|
||||
// MockClient is a mock of Client interface
|
||||
type MockClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientMockRecorder
|
||||
}
|
||||
|
||||
// MockClientMockRecorder is the mock recorder for MockClient
|
||||
type MockClientMockRecorder struct {
|
||||
mock *MockClient
|
||||
}
|
||||
|
||||
// NewMockClient creates a new mock instance
|
||||
func NewMockClient(ctrl *gomock.Controller) *MockClient {
|
||||
mock := &MockClient{ctrl: ctrl}
|
||||
mock.recorder = &MockClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockClient) EXPECT() *MockClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AccessTokenLifetime mocks base method
|
||||
func (m *MockClient) AccessTokenLifetime() time.Duration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AccessTokenLifetime")
|
||||
ret0, _ := ret[0].(time.Duration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AccessTokenLifetime indicates an expected call of AccessTokenLifetime
|
||||
func (mr *MockClientMockRecorder) AccessTokenLifetime() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenLifetime", reflect.TypeOf((*MockClient)(nil).AccessTokenLifetime))
|
||||
}
|
||||
|
||||
// AccessTokenType mocks base method
|
||||
func (m *MockClient) AccessTokenType() op.AccessTokenType {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AccessTokenType")
|
||||
ret0, _ := ret[0].(op.AccessTokenType)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AccessTokenType indicates an expected call of AccessTokenType
|
||||
func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType))
|
||||
}
|
||||
|
||||
// ApplicationType mocks base method
|
||||
func (m *MockClient) ApplicationType() op.ApplicationType {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ApplicationType")
|
||||
ret0, _ := ret[0].(op.ApplicationType)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ApplicationType indicates an expected call of ApplicationType
|
||||
func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType))
|
||||
}
|
||||
|
||||
// GetAuthMethod mocks base method
|
||||
func (m *MockClient) GetAuthMethod() op.AuthMethod {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthMethod")
|
||||
ret0, _ := ret[0].(op.AuthMethod)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAuthMethod indicates an expected call of GetAuthMethod
|
||||
func (mr *MockClientMockRecorder) GetAuthMethod() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthMethod", reflect.TypeOf((*MockClient)(nil).GetAuthMethod))
|
||||
}
|
||||
|
||||
// GetID mocks base method
|
||||
func (m *MockClient) GetID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetID indicates an expected call of GetID
|
||||
func (mr *MockClientMockRecorder) GetID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID))
|
||||
}
|
||||
|
||||
// IDTokenLifetime mocks base method
|
||||
func (m *MockClient) IDTokenLifetime() time.Duration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IDTokenLifetime")
|
||||
ret0, _ := ret[0].(time.Duration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IDTokenLifetime indicates an expected call of IDTokenLifetime
|
||||
func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime))
|
||||
}
|
||||
|
||||
// LoginURL mocks base method
|
||||
func (m *MockClient) LoginURL(arg0 string) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoginURL", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LoginURL indicates an expected call of LoginURL
|
||||
func (mr *MockClientMockRecorder) LoginURL(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockClient)(nil).LoginURL), arg0)
|
||||
}
|
||||
|
||||
// RedirectURIs mocks base method
|
||||
func (m *MockClient) RedirectURIs() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RedirectURIs")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RedirectURIs indicates an expected call of RedirectURIs
|
||||
func (mr *MockClientMockRecorder) RedirectURIs() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIs", reflect.TypeOf((*MockClient)(nil).RedirectURIs))
|
||||
}
|
132
pkg/op/mock/configuration.mock.go
Normal file
132
pkg/op/mock/configuration.mock.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Configuration)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockConfiguration is a mock of Configuration interface
|
||||
type MockConfiguration struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockConfigurationMockRecorder
|
||||
}
|
||||
|
||||
// MockConfigurationMockRecorder is the mock recorder for MockConfiguration
|
||||
type MockConfigurationMockRecorder struct {
|
||||
mock *MockConfiguration
|
||||
}
|
||||
|
||||
// NewMockConfiguration creates a new mock instance
|
||||
func NewMockConfiguration(ctrl *gomock.Controller) *MockConfiguration {
|
||||
mock := &MockConfiguration{ctrl: ctrl}
|
||||
mock.recorder = &MockConfigurationMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockConfiguration) EXPECT() *MockConfigurationMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AuthMethodPostSupported mocks base method
|
||||
func (m *MockConfiguration) AuthMethodPostSupported() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthMethodPostSupported")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AuthMethodPostSupported indicates an expected call of AuthMethodPostSupported
|
||||
func (mr *MockConfigurationMockRecorder) AuthMethodPostSupported() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethodPostSupported", reflect.TypeOf((*MockConfiguration)(nil).AuthMethodPostSupported))
|
||||
}
|
||||
|
||||
// AuthorizationEndpoint mocks base method
|
||||
func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthorizationEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AuthorizationEndpoint indicates an expected call of AuthorizationEndpoint
|
||||
func (mr *MockConfigurationMockRecorder) AuthorizationEndpoint() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).AuthorizationEndpoint))
|
||||
}
|
||||
|
||||
// Issuer mocks base method
|
||||
func (m *MockConfiguration) Issuer() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Issuer")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Issuer indicates an expected call of Issuer
|
||||
func (mr *MockConfigurationMockRecorder) Issuer() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockConfiguration)(nil).Issuer))
|
||||
}
|
||||
|
||||
// KeysEndpoint mocks base method
|
||||
func (m *MockConfiguration) KeysEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeysEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// KeysEndpoint indicates an expected call of KeysEndpoint
|
||||
func (mr *MockConfigurationMockRecorder) KeysEndpoint() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeysEndpoint", reflect.TypeOf((*MockConfiguration)(nil).KeysEndpoint))
|
||||
}
|
||||
|
||||
// Port mocks base method
|
||||
func (m *MockConfiguration) Port() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Port")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Port indicates an expected call of Port
|
||||
func (mr *MockConfigurationMockRecorder) Port() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Port", reflect.TypeOf((*MockConfiguration)(nil).Port))
|
||||
}
|
||||
|
||||
// TokenEndpoint mocks base method
|
||||
func (m *MockConfiguration) TokenEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TokenEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// TokenEndpoint indicates an expected call of TokenEndpoint
|
||||
func (mr *MockConfigurationMockRecorder) TokenEndpoint() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpoint", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpoint))
|
||||
}
|
||||
|
||||
// UserinfoEndpoint mocks base method
|
||||
func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UserinfoEndpoint")
|
||||
ret0, _ := ret[0].(op.Endpoint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UserinfoEndpoint indicates an expected call of UserinfoEndpoint
|
||||
func (mr *MockConfigurationMockRecorder) UserinfoEndpoint() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserinfoEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserinfoEndpoint))
|
||||
}
|
7
pkg/op/mock/generate.go
Normal file
7
pkg/op/mock/generate.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package mock
|
||||
|
||||
//go:generate mockgen -package mock -destination ./storage.mock.go github.com/caos/oidc/pkg/op Storage
|
||||
//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/caos/oidc/pkg/op Authorizer
|
||||
//go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client
|
||||
//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration
|
||||
//go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer
|
79
pkg/op/mock/signer.mock.go
Normal file
79
pkg/op/mock/signer.mock.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Signer)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
oidc "github.com/caos/oidc/pkg/oidc"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
go_jose_v2 "gopkg.in/square/go-jose.v2"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockSigner is a mock of Signer interface
|
||||
type MockSigner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSignerMockRecorder
|
||||
}
|
||||
|
||||
// MockSignerMockRecorder is the mock recorder for MockSigner
|
||||
type MockSignerMockRecorder struct {
|
||||
mock *MockSigner
|
||||
}
|
||||
|
||||
// NewMockSigner creates a new mock instance
|
||||
func NewMockSigner(ctrl *gomock.Controller) *MockSigner {
|
||||
mock := &MockSigner{ctrl: ctrl}
|
||||
mock.recorder = &MockSignerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockSigner) EXPECT() *MockSignerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// SignAccessToken mocks base method
|
||||
func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignAccessToken", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SignAccessToken indicates an expected call of SignAccessToken
|
||||
func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0)
|
||||
}
|
||||
|
||||
// SignIDToken mocks base method
|
||||
func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignIDToken", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SignIDToken indicates an expected call of SignIDToken
|
||||
func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0)
|
||||
}
|
||||
|
||||
// SignatureAlgorithm mocks base method
|
||||
func (m *MockSigner) SignatureAlgorithm() go_jose_v2.SignatureAlgorithm {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SignatureAlgorithm")
|
||||
ret0, _ := ret[0].(go_jose_v2.SignatureAlgorithm)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SignatureAlgorithm indicates an expected call of SignatureAlgorithm
|
||||
func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm))
|
||||
}
|
170
pkg/op/mock/storage.mock.go
Normal file
170
pkg/op/mock/storage.mock.go
Normal file
|
@ -0,0 +1,170 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/oidc/pkg/op (interfaces: Storage)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
oidc "github.com/caos/oidc/pkg/oidc"
|
||||
op "github.com/caos/oidc/pkg/op"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
go_jose_v2 "gopkg.in/square/go-jose.v2"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockStorage is a mock of Storage interface
|
||||
type MockStorage struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStorageMockRecorder
|
||||
}
|
||||
|
||||
// MockStorageMockRecorder is the mock recorder for MockStorage
|
||||
type MockStorageMockRecorder struct {
|
||||
mock *MockStorage
|
||||
}
|
||||
|
||||
// NewMockStorage creates a new mock instance
|
||||
func NewMockStorage(ctrl *gomock.Controller) *MockStorage {
|
||||
mock := &MockStorage{ctrl: ctrl}
|
||||
mock.recorder = &MockStorageMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockStorage) EXPECT() *MockStorageMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AuthRequestByID mocks base method
|
||||
func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AuthRequestByID indicates an expected call of AuthRequestByID
|
||||
func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1)
|
||||
}
|
||||
|
||||
// AuthorizeClientIDSecret mocks base method
|
||||
func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret
|
||||
func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// CreateAuthRequest mocks base method
|
||||
func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateAuthRequest indicates an expected call of CreateAuthRequest
|
||||
func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// DeleteAuthRequest mocks base method
|
||||
func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest
|
||||
func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetClientByClientID mocks base method
|
||||
func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1)
|
||||
ret0, _ := ret[0].(op.Client)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetClientByClientID indicates an expected call of GetClientByClientID
|
||||
func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetKeySet mocks base method
|
||||
func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetKeySet", arg0)
|
||||
ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetKeySet indicates an expected call of GetKeySet
|
||||
func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0)
|
||||
}
|
||||
|
||||
// GetSigningKey mocks base method
|
||||
func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSigningKey", arg0)
|
||||
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetSigningKey indicates an expected call of GetSigningKey
|
||||
func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0)
|
||||
}
|
||||
|
||||
// GetUserinfoFromScopes mocks base method
|
||||
func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 []string) (*oidc.Userinfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1)
|
||||
ret0, _ := ret[0].(*oidc.Userinfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes
|
||||
func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1)
|
||||
}
|
||||
|
||||
// SaveKeyPair mocks base method
|
||||
func (m *MockStorage) SaveKeyPair(arg0 context.Context) (*go_jose_v2.SigningKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveKeyPair", arg0)
|
||||
ret0, _ := ret[0].(*go_jose_v2.SigningKey)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SaveKeyPair indicates an expected call of SaveKeyPair
|
||||
func (mr *MockStorageMockRecorder) SaveKeyPair(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveKeyPair), arg0)
|
||||
}
|
142
pkg/op/mock/storage.mock.impl.go
Normal file
142
pkg/op/mock/storage.mock.impl.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/caos/oidc/pkg/op"
|
||||
)
|
||||
|
||||
func NewStorage(t *testing.T) op.Storage {
|
||||
return NewMockStorage(gomock.NewController(t))
|
||||
}
|
||||
|
||||
func NewMockStorageExpectValidClientID(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectValidClientID(m)
|
||||
return m
|
||||
}
|
||||
|
||||
func NewMockStorageExpectInvalidClientID(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectInvalidClientID(m)
|
||||
return m
|
||||
}
|
||||
|
||||
func NewMockStorageAny(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
mockS := m.(*MockStorage)
|
||||
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).AnyTimes().Return(&ConfClient{}, nil)
|
||||
mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil)
|
||||
return m
|
||||
}
|
||||
|
||||
func NewMockStorageSigningKeyError(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectSigningKeyError(m)
|
||||
return m
|
||||
}
|
||||
|
||||
func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectSigningKeyInvalid(m)
|
||||
return m
|
||||
}
|
||||
func NewMockStorageSigningKey(t *testing.T) op.Storage {
|
||||
m := NewStorage(t)
|
||||
ExpectSigningKey(m)
|
||||
return m
|
||||
}
|
||||
|
||||
func ExpectInvalidClientID(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).Return(nil, errors.New("client not found"))
|
||||
}
|
||||
|
||||
func ExpectValidClientID(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, id string) (op.Client, error) {
|
||||
var appType op.ApplicationType
|
||||
var authMethod op.AuthMethod
|
||||
var accessTokenType op.AccessTokenType
|
||||
switch id {
|
||||
case "web_client":
|
||||
appType = op.ApplicationTypeWeb
|
||||
authMethod = op.AuthMethodBasic
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
case "native_client":
|
||||
appType = op.ApplicationTypeNative
|
||||
authMethod = op.AuthMethodNone
|
||||
accessTokenType = op.AccessTokenTypeBearer
|
||||
case "useragent_client":
|
||||
appType = op.ApplicationTypeUserAgent
|
||||
authMethod = op.AuthMethodBasic
|
||||
accessTokenType = op.AccessTokenTypeJWT
|
||||
}
|
||||
return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func ExpectSigningKeyError(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetSigningKey(gomock.Any()).Return(nil, errors.New("error"))
|
||||
}
|
||||
|
||||
func ExpectSigningKeyInvalid(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{}, nil)
|
||||
}
|
||||
|
||||
func ExpectSigningKey(s op.Storage) {
|
||||
mockS := s.(*MockStorage)
|
||||
mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, nil)
|
||||
}
|
||||
|
||||
type ConfClient struct {
|
||||
id string
|
||||
appType op.ApplicationType
|
||||
authMethod op.AuthMethod
|
||||
accessTokenType op.AccessTokenType
|
||||
}
|
||||
|
||||
func (c *ConfClient) RedirectURIs() []string {
|
||||
return []string{
|
||||
"https://registered.com/callback",
|
||||
"http://registered.com/callback",
|
||||
"http://localhost:9999/callback",
|
||||
"custom://callback",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConfClient) LoginURL(id string) string {
|
||||
return "login?id=" + id
|
||||
}
|
||||
|
||||
func (c *ConfClient) ApplicationType() op.ApplicationType {
|
||||
return c.appType
|
||||
}
|
||||
|
||||
func (c *ConfClient) GetAuthMethod() op.AuthMethod {
|
||||
return c.authMethod
|
||||
}
|
||||
|
||||
func (c *ConfClient) GetID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *ConfClient) AccessTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) IDTokenLifetime() time.Duration {
|
||||
return time.Duration(5 * time.Minute)
|
||||
}
|
||||
func (c *ConfClient) AccessTokenType() op.AccessTokenType {
|
||||
return c.accessTokenType
|
||||
}
|
51
pkg/op/op.go
Normal file
51
pkg/op/op.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
type OpenIDProvider interface {
|
||||
Configuration
|
||||
HandleDiscovery(w http.ResponseWriter, r *http.Request)
|
||||
HandleAuthorize(w http.ResponseWriter, r *http.Request)
|
||||
HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request)
|
||||
HandleExchange(w http.ResponseWriter, r *http.Request)
|
||||
HandleUserinfo(w http.ResponseWriter, r *http.Request)
|
||||
HandleKeys(w http.ResponseWriter, r *http.Request)
|
||||
HttpHandler() *http.Server
|
||||
}
|
||||
|
||||
func CreateRouter(o OpenIDProvider) *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery)
|
||||
router.HandleFunc(o.AuthorizationEndpoint().Relative(), o.HandleAuthorize)
|
||||
router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", o.HandleAuthorizeCallback)
|
||||
router.HandleFunc(o.TokenEndpoint().Relative(), o.HandleExchange)
|
||||
router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo)
|
||||
router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys)
|
||||
return router
|
||||
}
|
||||
|
||||
func Start(ctx context.Context, o OpenIDProvider) {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err := o.HttpHandler().Shutdown(ctx)
|
||||
if err != nil {
|
||||
logrus.Error("graceful shutdown of oidc server failed")
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := o.HttpHandler().ListenAndServe()
|
||||
if err != nil {
|
||||
logrus.Panicf("oidc server serve failed: %v", err)
|
||||
}
|
||||
}()
|
||||
logrus.Infof("oidc server is listening on %s", o.Port())
|
||||
}
|
13
pkg/op/session.go
Normal file
13
pkg/op/session.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package op
|
||||
|
||||
import "github.com/caos/oidc/pkg/oidc"
|
||||
|
||||
func NeedsExistingSession(authRequest *oidc.AuthRequest) bool {
|
||||
if authRequest == nil {
|
||||
return true
|
||||
}
|
||||
if authRequest.Prompt == oidc.PromptNone {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
78
pkg/op/signer.go
Normal file
78
pkg/op/signer.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
type Signer interface {
|
||||
SignIDToken(claims *oidc.IDTokenClaims) (string, error)
|
||||
SignAccessToken(claims *oidc.AccessTokenClaims) (string, error)
|
||||
SignatureAlgorithm() jose.SignatureAlgorithm
|
||||
}
|
||||
|
||||
type idTokenSigner struct {
|
||||
signer jose.Signer
|
||||
storage AuthStorage
|
||||
algorithm jose.SignatureAlgorithm
|
||||
}
|
||||
|
||||
func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) {
|
||||
s := &idTokenSigner{
|
||||
storage: storage,
|
||||
}
|
||||
if err := s.initialize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) initialize(ctx context.Context) error {
|
||||
var key *jose.SigningKey
|
||||
var err error
|
||||
key, err = s.storage.GetSigningKey(ctx)
|
||||
if err != nil {
|
||||
key, err = s.storage.SaveKeyPair(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.signer, err = jose.NewSigner(*key, &jose.SignerOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.algorithm = key.Algorithm
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) {
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Sign(payload)
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) {
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Sign(payload)
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) Sign(payload []byte) (string, error) {
|
||||
result, err := s.signer.Sign(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result.CompactSerialize()
|
||||
}
|
||||
|
||||
func (s *idTokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return s.algorithm
|
||||
}
|
95
pkg/op/signer_test.go
Normal file
95
pkg/op/signer_test.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// func TestNewDefaultSigner(t *testing.T) {
|
||||
// type args struct {
|
||||
// storage Storage
|
||||
// }
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// args args
|
||||
// want Signer
|
||||
// wantErr bool
|
||||
// }{
|
||||
// {
|
||||
// "err initialize storage fails",
|
||||
// args{mock.NewMockStorageSigningKeyError(t)},
|
||||
// nil,
|
||||
// true,
|
||||
// },
|
||||
// {
|
||||
// "err initialize storage fails",
|
||||
// args{mock.NewMockStorageSigningKeyInvalid(t)},
|
||||
// nil,
|
||||
// true,
|
||||
// },
|
||||
// {
|
||||
// "initialize ok",
|
||||
// args{mock.NewMockStorageSigningKey(t)},
|
||||
// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)},
|
||||
// false,
|
||||
// },
|
||||
// }
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// got, err := op.NewDefaultSigner(tt.args.storage)
|
||||
// if (err != nil) != tt.wantErr {
|
||||
// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
// return
|
||||
// }
|
||||
// if !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
func Test_idTokenSigner_Sign(t *testing.T) {
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
type fields struct {
|
||||
signer jose.Signer
|
||||
storage Storage
|
||||
}
|
||||
type args struct {
|
||||
payload []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"ok",
|
||||
fields{signer, nil},
|
||||
args{[]byte("test")},
|
||||
"eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww",
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &idTokenSigner{
|
||||
signer: tt.fields.signer,
|
||||
storage: tt.fields.storage,
|
||||
}
|
||||
got, err := s.Sign(tt.args.payload)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
48
pkg/op/storage.go
Normal file
48
pkg/op/storage.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
type AuthStorage interface {
|
||||
CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error)
|
||||
AuthRequestByID(context.Context, string) (AuthRequest, error)
|
||||
DeleteAuthRequest(context.Context, string) error
|
||||
|
||||
GetSigningKey(context.Context) (*jose.SigningKey, error)
|
||||
GetKeySet(context.Context) (*jose.JSONWebKeySet, error)
|
||||
SaveKeyPair(context.Context) (*jose.SigningKey, error)
|
||||
}
|
||||
|
||||
type OPStorage interface {
|
||||
GetClientByClientID(context.Context, string) (Client, error)
|
||||
AuthorizeClientIDSecret(context.Context, string, string) error
|
||||
GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error)
|
||||
}
|
||||
|
||||
type Storage interface {
|
||||
AuthStorage
|
||||
OPStorage
|
||||
}
|
||||
|
||||
type AuthRequest interface {
|
||||
GetID() string
|
||||
GetACR() string
|
||||
GetAMR() []string
|
||||
GetAudience() []string
|
||||
GetAuthTime() time.Time
|
||||
GetClientID() string
|
||||
GetCodeChallenge() *oidc.CodeChallenge
|
||||
GetNonce() string
|
||||
GetRedirectURI() string
|
||||
GetResponseType() oidc.ResponseType
|
||||
GetScopes() []string
|
||||
GetState() string
|
||||
GetSubject() string
|
||||
Done() bool
|
||||
}
|
93
pkg/op/token.go
Normal file
93
pkg/op/token.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
type TokenCreator interface {
|
||||
Issuer() string
|
||||
Signer() Signer
|
||||
Storage() Storage
|
||||
Crypto() Crypto
|
||||
}
|
||||
|
||||
func CreateTokenResponse(authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) {
|
||||
var accessToken string
|
||||
if createAccessToken {
|
||||
var err error
|
||||
accessToken, err = CreateAccessToken(authReq, client, creator)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
idToken, err := CreateIDToken(creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Signer())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exp := uint64(client.AccessTokenLifetime().Seconds())
|
||||
return &oidc.AccessTokenResponse{
|
||||
AccessToken: accessToken,
|
||||
IDToken: idToken,
|
||||
TokenType: oidc.BearerToken,
|
||||
ExpiresIn: exp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CreateAccessToken(authReq AuthRequest, client Client, creator TokenCreator) (string, error) {
|
||||
if client.AccessTokenType() == AccessTokenTypeJWT {
|
||||
return CreateJWT(creator.Issuer(), authReq, client, creator.Signer())
|
||||
}
|
||||
return CreateBearerToken(authReq, creator.Crypto())
|
||||
}
|
||||
|
||||
func CreateBearerToken(authReq AuthRequest, crypto Crypto) (string, error) {
|
||||
return crypto.Encrypt(authReq.GetID())
|
||||
}
|
||||
|
||||
func CreateJWT(issuer string, authReq AuthRequest, client Client, signer Signer) (string, error) {
|
||||
now := time.Now().UTC()
|
||||
nbf := now
|
||||
exp := now.Add(client.AccessTokenLifetime())
|
||||
claims := &oidc.AccessTokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: authReq.GetSubject(),
|
||||
Audiences: authReq.GetAudience(),
|
||||
Expiration: exp,
|
||||
IssuedAt: now,
|
||||
NotBefore: nbf,
|
||||
}
|
||||
return signer.SignAccessToken(claims)
|
||||
}
|
||||
|
||||
func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) {
|
||||
var err error
|
||||
exp := time.Now().UTC().Add(validity)
|
||||
claims := &oidc.IDTokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: authReq.GetSubject(),
|
||||
Audiences: authReq.GetAudience(),
|
||||
Expiration: exp,
|
||||
IssuedAt: time.Now().UTC(),
|
||||
AuthTime: authReq.GetAuthTime(),
|
||||
Nonce: authReq.GetNonce(),
|
||||
AuthenticationContextClassReference: authReq.GetACR(),
|
||||
AuthenticationMethodsReferences: authReq.GetAMR(),
|
||||
AuthorizedParty: authReq.GetClientID(),
|
||||
}
|
||||
if accessToken != "" {
|
||||
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if code != "" {
|
||||
claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return signer.SignIDToken(claims)
|
||||
}
|
151
pkg/op/tokenrequest.go
Normal file
151
pkg/op/tokenrequest.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type Exchanger interface {
|
||||
Issuer() string
|
||||
Storage() Storage
|
||||
Decoder() *schema.Decoder
|
||||
Signer() Signer
|
||||
Crypto() Crypto
|
||||
AuthMethodPostSupported() bool
|
||||
}
|
||||
|
||||
func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
}
|
||||
if tokenReq.Code == "" {
|
||||
ExchangeRequestError(w, r, ErrInvalidRequest("code missing"))
|
||||
return
|
||||
}
|
||||
authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
resp, err := CreateTokenResponse(authReq, client, exchanger, true, tokenReq.Code)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
utils.MarshalJSON(w, resp)
|
||||
}
|
||||
|
||||
func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return nil, ErrInvalidRequest("error parsing form")
|
||||
}
|
||||
tokenReq := new(oidc.AccessTokenRequest)
|
||||
err = decoder.Decode(tokenReq, r.Form)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidRequest("error decoding form")
|
||||
}
|
||||
clientID, clientSecret, ok := r.BasicAuth()
|
||||
if ok {
|
||||
tokenReq.ClientID = clientID
|
||||
tokenReq.ClientSecret = clientSecret
|
||||
|
||||
}
|
||||
return tokenReq, nil
|
||||
}
|
||||
|
||||
func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if client.GetID() != authReq.GetClientID() {
|
||||
return nil, nil, ErrInvalidRequest("invalid auth code")
|
||||
}
|
||||
if tokenReq.RedirectURI != authReq.GetRedirectURI() {
|
||||
return nil, nil, ErrInvalidRequest("redirect_uri does no correspond")
|
||||
}
|
||||
return authReq, client, nil
|
||||
}
|
||||
|
||||
func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) {
|
||||
client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if client.GetAuthMethod() == AuthMethodNone {
|
||||
authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger)
|
||||
return authReq, client, err
|
||||
}
|
||||
if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() {
|
||||
return nil, nil, errors.New("basic not supported")
|
||||
}
|
||||
err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
return authReq, client, nil
|
||||
}
|
||||
|
||||
func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error {
|
||||
return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret)
|
||||
}
|
||||
|
||||
func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) {
|
||||
if tokenReq.CodeVerifier == "" {
|
||||
return nil, ErrInvalidRequest("code_challenge required")
|
||||
}
|
||||
authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage())
|
||||
if err != nil {
|
||||
return nil, ErrInvalidRequest("invalid code")
|
||||
}
|
||||
if !oidc.VerifyCodeChallenge(authReq.GetCodeChallenge(), tokenReq.CodeVerifier) {
|
||||
return nil, ErrInvalidRequest("code_challenge invalid")
|
||||
}
|
||||
return authReq, nil
|
||||
}
|
||||
|
||||
func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) {
|
||||
id, err := crypto.Decrypt(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return storage.AuthRequestByID(ctx, id)
|
||||
}
|
||||
|
||||
func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
|
||||
tokenRequest, err := ParseTokenExchangeRequest(w, r)
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage())
|
||||
if err != nil {
|
||||
ExchangeRequestError(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) {
|
||||
return nil, errors.New("Unimplemented") //TODO: impl
|
||||
}
|
||||
|
||||
func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error {
|
||||
return errors.New("Unimplemented") //TODO: impl
|
||||
}
|
28
pkg/op/userinfo.go
Normal file
28
pkg/op/userinfo.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package op
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
type UserinfoProvider interface {
|
||||
Storage() Storage
|
||||
}
|
||||
|
||||
func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoProvider) {
|
||||
scopes, err := ScopesFromAccessToken(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
info, err := userinfoProvider.Storage().GetUserinfoFromScopes(r.Context(), scopes)
|
||||
if err != nil {
|
||||
utils.MarshalJSON(w, err)
|
||||
return
|
||||
}
|
||||
utils.MarshalJSON(w, info)
|
||||
}
|
||||
|
||||
func ScopesFromAccessToken(w http.ResponseWriter, r *http.Request) ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
287
pkg/rp/default_rp.go
Normal file
287
pkg/rp/default_rp.go
Normal file
|
@ -0,0 +1,287 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc/grants"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
grants_tx "github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
idTokenKey = "id_token"
|
||||
stateParam = "state"
|
||||
pkceCode = "pkce"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) {
|
||||
http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError)
|
||||
}
|
||||
)
|
||||
|
||||
//DefaultRP impements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface
|
||||
type DefaultRP struct {
|
||||
endpoints Endpoints
|
||||
|
||||
oauthConfig oauth2.Config
|
||||
config *Config
|
||||
pkce bool
|
||||
|
||||
httpClient *http.Client
|
||||
cookieHandler *utils.CookieHandler
|
||||
|
||||
errorHandler func(http.ResponseWriter, *http.Request, string, string, string)
|
||||
|
||||
verifier Verifier
|
||||
}
|
||||
|
||||
//NewDefaultRP creates `DefaultRP` with the given
|
||||
//Config and possible configOptions
|
||||
//it will run discovery on the provided issuer
|
||||
//if no verifier is provided using the options the `DefaultVerifier` is set
|
||||
func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExchangeRP, error) {
|
||||
p := &DefaultRP{
|
||||
config: rpConfig,
|
||||
httpClient: utils.DefaultHTTPClient,
|
||||
}
|
||||
|
||||
for _, optFunc := range rpOpts {
|
||||
optFunc(p)
|
||||
}
|
||||
|
||||
if err := p.discover(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.errorHandler == nil {
|
||||
p.errorHandler = DefaultErrorHandler
|
||||
}
|
||||
|
||||
if p.verifier == nil {
|
||||
p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL))
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
//DefaultRPOpts is the type for providing dynamic options to the DefaultRP
|
||||
type DefaultRPOpts func(p *DefaultRP)
|
||||
|
||||
//WithCookieHandler set a `CookieHandler` for securing the various redirects
|
||||
func WithCookieHandler(cookieHandler *utils.CookieHandler) DefaultRPOpts {
|
||||
return func(p *DefaultRP) {
|
||||
p.cookieHandler = cookieHandler
|
||||
}
|
||||
}
|
||||
|
||||
//WithPKCE sets the RP to use PKCE (oauth2 code challenge)
|
||||
//it also sets a `CookieHandler` for securing the various redirects
|
||||
//and exchanging the code challenge
|
||||
func WithPKCE(cookieHandler *utils.CookieHandler) DefaultRPOpts {
|
||||
return func(p *DefaultRP) {
|
||||
p.pkce = true
|
||||
p.cookieHandler = cookieHandler
|
||||
}
|
||||
}
|
||||
|
||||
//WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier
|
||||
func WithHTTPClient(client *http.Client) DefaultRPOpts {
|
||||
return func(p *DefaultRP) {
|
||||
p.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//wrapping the oauth2 `AuthCodeURL`
|
||||
//returning the url of the auth request
|
||||
func (p *DefaultRP) AuthURL(state string, opts ...AuthURLOpt) string {
|
||||
authOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
authOpts = append(authOpts, opt()...)
|
||||
}
|
||||
return p.oauthConfig.AuthCodeURL(state, authOpts...)
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//extending the `AuthURL` method with a http redirect handler
|
||||
func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
opts := make([]AuthURLOpt, 0)
|
||||
if err := p.trySetStateCookie(w, state); err != nil {
|
||||
http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if p.pkce {
|
||||
codeChallenge, err := p.generateAndStoreCodeChallenge(w)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
opts = append(opts, WithCodeChallenge(codeChallenge))
|
||||
}
|
||||
http.Redirect(w, r, p.AuthURL(state, opts...), http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DefaultRP) generateAndStoreCodeChallenge(w http.ResponseWriter) (string, error) {
|
||||
var codeVerifier string
|
||||
codeVerifier = "s"
|
||||
if err := p.cookieHandler.SetCookie(w, pkceCode, codeVerifier); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return oidc.NewSHACodeChallenge(codeVerifier), nil
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//handling the oauth2 code exchange, extracting and validating the id_token
|
||||
//returning it paresed together with the oauth2 tokens (access, refresh)
|
||||
func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient)
|
||||
codeOpts := make([]oauth2.AuthCodeOption, 0)
|
||||
for _, opt := range opts {
|
||||
codeOpts = append(codeOpts, opt()...)
|
||||
}
|
||||
|
||||
token, err := p.oauthConfig.Exchange(ctx, code, codeOpts...)
|
||||
if err != nil {
|
||||
return nil, err //TODO: our error
|
||||
}
|
||||
idTokenString, ok := token.Extra(idTokenKey).(string)
|
||||
if !ok {
|
||||
//TODO: implement
|
||||
}
|
||||
|
||||
idToken, err := p.verifier.Verify(ctx, token.AccessToken, idTokenString)
|
||||
if err != nil {
|
||||
return nil, err //TODO: err
|
||||
}
|
||||
|
||||
return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil
|
||||
}
|
||||
|
||||
//AuthURL is the `RelayingParty` interface implementation
|
||||
//extending the `CodeExchange` method with callback function
|
||||
func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := p.tryReadStateCookie(w, r)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
params := r.URL.Query()
|
||||
if params.Get("error") != "" {
|
||||
p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state)
|
||||
return
|
||||
}
|
||||
codeOpts := make([]CodeExchangeOpt, 0)
|
||||
if p.pkce {
|
||||
codeVerifier, err := p.cookieHandler.CheckCookie(r, pkceCode)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier))
|
||||
}
|
||||
tokens, err := p.CodeExchange(r.Context(), params.Get("code"), codeOpts...)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
callback(w, r, tokens, state)
|
||||
}
|
||||
}
|
||||
|
||||
// func (p *DefaultRP) Introspect(ctx context.Context, accessToken string) (oidc.TokenIntrospectResponse, error) {
|
||||
// // req := &http.Request{}
|
||||
// // resp, err := p.httpClient.Do(req)
|
||||
// // if err != nil {
|
||||
|
||||
// // }
|
||||
// // p.endpoints.IntrospectURL
|
||||
// return nil, nil
|
||||
// }
|
||||
|
||||
func (p *DefaultRP) Userinfo() {}
|
||||
|
||||
//ClientCredentials is the `RelayingParty` interface implementation
|
||||
//handling the oauth2 client credentials grant
|
||||
func (p *DefaultRP) ClientCredentials(ctx context.Context, scopes ...string) (newToken *oauth2.Token, err error) {
|
||||
return p.callTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...))
|
||||
}
|
||||
|
||||
//TokenExchange is the `TokenExchangeRP` interface implementation
|
||||
//handling the oauth2 token exchange (draft)
|
||||
func (p *DefaultRP) TokenExchange(ctx context.Context, request *grants_tx.TokenExchangeRequest) (newToken *oauth2.Token, err error) {
|
||||
return p.callTokenEndpoint(request)
|
||||
}
|
||||
|
||||
//DelegationTokenExchange is the `TokenExchangeRP` interface implementation
|
||||
//handling the oauth2 token exchange for a delegation token (draft)
|
||||
func (p *DefaultRP) DelegationTokenExchange(ctx context.Context, subjectToken string, reqOpts ...grants_tx.TokenExchangeOption) (newToken *oauth2.Token, err error) {
|
||||
return p.TokenExchange(ctx, DelegationTokenRequest(subjectToken, reqOpts...))
|
||||
}
|
||||
|
||||
func (p *DefaultRP) discover() error {
|
||||
wellKnown := strings.TrimSuffix(p.config.Issuer, "/") + oidc.DiscoveryEndpoint
|
||||
req, err := http.NewRequest("GET", wellKnown, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
discoveryConfig := new(oidc.DiscoveryConfiguration)
|
||||
err = utils.HttpRequest(p.httpClient, req, &discoveryConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.endpoints = GetEndpoints(discoveryConfig)
|
||||
p.oauthConfig = oauth2.Config{
|
||||
ClientID: p.config.ClientID,
|
||||
ClientSecret: p.config.ClientSecret,
|
||||
Endpoint: p.endpoints.Endpoint,
|
||||
RedirectURL: p.config.CallbackURL,
|
||||
Scopes: p.config.Scopes,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DefaultRP) callTokenEndpoint(request interface{}) (newToken *oauth2.Token, err error) {
|
||||
req, err := utils.FormRequest(p.endpoints.TokenURL, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(p.config.ClientID + ":" + p.config.ClientSecret))
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
token := new(oauth2.Token)
|
||||
if err := utils.HttpRequest(p.httpClient, req, token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error {
|
||||
if p.cookieHandler != nil {
|
||||
if err := p.cookieHandler.SetCookie(w, stateParam, state); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *DefaultRP) tryReadStateCookie(w http.ResponseWriter, r *http.Request) (state string, err error) {
|
||||
if p.cookieHandler == nil {
|
||||
return r.FormValue(stateParam), nil
|
||||
}
|
||||
state, err = p.cookieHandler.CheckQueryCookie(r, stateParam)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
p.cookieHandler.DeleteCookie(w, stateParam)
|
||||
return state, nil
|
||||
}
|
363
pkg/rp/default_verifier.go
Normal file
363
pkg/rp/default_verifier.go
Normal file
|
@ -0,0 +1,363 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
)
|
||||
|
||||
//DefaultVerifier implements the `Verifier` interface
|
||||
type DefaultVerifier struct {
|
||||
config *verifierConfig
|
||||
keySet oidc.KeySet
|
||||
}
|
||||
|
||||
//ConfFunc is the type for providing dynamic options to the DefaultVerfifier
|
||||
type ConfFunc func(*verifierConfig)
|
||||
|
||||
//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim
|
||||
type ACRVerifier func(string) error
|
||||
|
||||
//NewDefaultVerifier creates `DefaultVerifier` with the given
|
||||
//issuer, clientID, keyset and possible configOptions
|
||||
func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ...ConfFunc) Verifier {
|
||||
conf := &verifierConfig{
|
||||
issuer: issuer,
|
||||
clientID: clientID,
|
||||
iat: &iatConfig{
|
||||
// offset: time.Duration(500 * time.Millisecond),
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range confOpts {
|
||||
if opt != nil {
|
||||
opt(conf)
|
||||
}
|
||||
}
|
||||
return &DefaultVerifier{config: conf, keySet: keySet}
|
||||
}
|
||||
|
||||
//WithIgnoreIssuedAt will turn off iat claim verification
|
||||
func WithIgnoreIssuedAt() func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.iat.ignore = true
|
||||
}
|
||||
}
|
||||
|
||||
//WithIssuedAtOffset mitigates the risk of iat to be in the future
|
||||
//because of clock skews with the ability to add an offset to the current time
|
||||
func WithIssuedAtOffset(offset time.Duration) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.iat.offset = offset
|
||||
}
|
||||
}
|
||||
|
||||
//WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now
|
||||
func WithIssuedAtMaxAge(maxAge time.Duration) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.iat.maxAge = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
//WithNonce TODO: ?
|
||||
func WithNonce(nonce string) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.nonce = nonce
|
||||
}
|
||||
}
|
||||
|
||||
//WithACRVerifier sets the verifier for the acr claim
|
||||
func WithACRVerifier(verifier ACRVerifier) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.acr = verifier
|
||||
}
|
||||
}
|
||||
|
||||
//WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now
|
||||
func WithAuthTimeMaxAge(maxAge time.Duration) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.maxAge = maxAge
|
||||
}
|
||||
}
|
||||
|
||||
//WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm
|
||||
func WithSupportedSigningAlgorithms(algs ...string) func(*verifierConfig) {
|
||||
return func(conf *verifierConfig) {
|
||||
conf.supportedSignAlgs = algs
|
||||
}
|
||||
}
|
||||
|
||||
type verifierConfig struct {
|
||||
issuer string
|
||||
clientID string
|
||||
nonce string
|
||||
iat *iatConfig
|
||||
acr ACRVerifier
|
||||
maxAge time.Duration
|
||||
supportedSignAlgs []string
|
||||
|
||||
// httpClient *http.Client
|
||||
|
||||
now time.Time
|
||||
}
|
||||
|
||||
type iatConfig struct {
|
||||
ignore bool
|
||||
offset time.Duration
|
||||
maxAge time.Duration
|
||||
}
|
||||
|
||||
//DefaultACRVerifier implements `ACRVerifier` returning an error
|
||||
//if non of the provided values matches the acr claim
|
||||
func DefaultACRVerifier(possibleValues []string) ACRVerifier {
|
||||
return func(acr string) error {
|
||||
if !utils.Contains(possibleValues, acr) {
|
||||
return ErrAcrInvalid(possibleValues, acr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
//Verify implements the `Verify` method of the `Verifier` interface
|
||||
//according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
//and https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation
|
||||
func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) {
|
||||
v.config.now = time.Now().UTC()
|
||||
idToken, err := v.VerifyIDToken(ctx, idTokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := v.verifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { //TODO: sig from token
|
||||
return nil, err
|
||||
}
|
||||
return idToken, nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) now() time.Time {
|
||||
if v.config.now.IsZero() {
|
||||
v.config.now = time.Now().UTC().Round(time.Second)
|
||||
}
|
||||
return v.config.now
|
||||
}
|
||||
|
||||
//VerifyIDToken: https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
|
||||
func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) {
|
||||
//1. if encrypted --> decrypt
|
||||
decrypted, err := v.decryptToken(idTokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, payload, err := v.parseToken(decrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// token, err := jwt.ParseWithClaims(decrypted, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
//2, check issuer (exact match)
|
||||
if err := v.checkIssuer(claims.Issuer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//3. check aud (aud must contain client_id, all aud strings must be allowed)
|
||||
if err = v.checkAudience(claims.Audiences); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = v.checkAuthorizedParty(claims.Audiences, claims.AuthorizedParty); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//6. check signature by keys
|
||||
//7. check alg default is rs256
|
||||
//8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret
|
||||
claims.Signature, err = v.checkSignature(ctx, decrypted, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//9. check exp before now
|
||||
if err = v.checkExpiration(claims.Expiration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//10. check iat duration is optional (can be checked)
|
||||
if err = v.checkIssuedAt(claims.IssuedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//11. check nonce (check if optional possible) id_token.nonce == sentNonce
|
||||
if err = v.checkNonce(claims.Nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//12. if acr requested check acr
|
||||
if err = v.checkAuthorizationContextClassReference(claims.AuthenticationContextClassReference); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//13. if auth_time requested check if auth_time is less than max age
|
||||
if err = v.checkAuthTime(claims.AuthTime); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) parseToken(tokenString string) (*oidc.IDTokenClaims, []byte, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, nil, ValidationError("token contains an invalid number of segments") //TODO: err NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed)
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
|
||||
}
|
||||
idToken := new(oidc.IDTokenClaims)
|
||||
err = json.Unmarshal(payload, idToken)
|
||||
return idToken, payload, err
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkIssuer(issuer string) error {
|
||||
if v.config.issuer != issuer {
|
||||
return ErrIssuerInvalid(v.config.issuer, issuer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkAudience(audiences []string) error {
|
||||
if !utils.Contains(audiences, v.config.clientID) {
|
||||
return ErrAudienceMissingClientID(v.config.clientID)
|
||||
}
|
||||
|
||||
//TODO: check aud trusted
|
||||
return nil
|
||||
}
|
||||
|
||||
//4. if multiple aud strings --> check if azp
|
||||
//5. if azp --> check azp == client_id
|
||||
func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error {
|
||||
if len(audiences) > 1 {
|
||||
if authorizedParty == "" {
|
||||
return ErrAzpMissing()
|
||||
}
|
||||
}
|
||||
if authorizedParty != "" && authorizedParty != v.config.clientID {
|
||||
return ErrAzpInvalid(authorizedParty, v.config.clientID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString string, payload []byte) (jose.SignatureAlgorithm, error) {
|
||||
jws, err := jose.ParseSigned(idTokenString)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
return "", nil //TODO: error
|
||||
}
|
||||
if len(jws.Signatures) > 1 {
|
||||
return "", nil //TODO: error
|
||||
}
|
||||
sig := jws.Signatures[0]
|
||||
supportedSigAlgs := v.config.supportedSignAlgs
|
||||
if len(supportedSigAlgs) == 0 {
|
||||
supportedSigAlgs = []string{"RS256"}
|
||||
}
|
||||
if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) {
|
||||
return "", fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
|
||||
}
|
||||
|
||||
signedPayload, err := v.keySet.VerifySignature(ctx, jws)
|
||||
if err != nil {
|
||||
return "", err
|
||||
//TODO:
|
||||
}
|
||||
|
||||
if !bytes.Equal(signedPayload, payload) {
|
||||
return "", ErrSignatureInvalidPayload() //TODO: err
|
||||
}
|
||||
return jose.SignatureAlgorithm(sig.Header.Algorithm), nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkExpiration(expiration time.Time) error {
|
||||
expiration = expiration.Round(time.Second)
|
||||
if !v.now().Before(expiration) {
|
||||
return ErrExpInvalid(expiration)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) checkIssuedAt(issuedAt time.Time) error {
|
||||
if v.config.iat.ignore {
|
||||
return nil
|
||||
}
|
||||
issuedAt = issuedAt.Round(time.Second)
|
||||
offset := v.now().Add(v.config.iat.offset).Round(time.Second)
|
||||
if issuedAt.After(offset) {
|
||||
return ErrIatInFuture(issuedAt, offset)
|
||||
}
|
||||
if v.config.iat.maxAge == 0 {
|
||||
return nil
|
||||
}
|
||||
maxAge := v.now().Add(-v.config.iat.maxAge).Round(time.Second)
|
||||
if issuedAt.Before(maxAge) {
|
||||
return ErrIatToOld(maxAge, issuedAt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (v *DefaultVerifier) checkNonce(nonce string) error {
|
||||
if v.config.nonce == "" {
|
||||
return nil
|
||||
}
|
||||
if v.config.nonce != nonce {
|
||||
return ErrNonceInvalid(v.config.nonce, nonce)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (v *DefaultVerifier) checkAuthorizationContextClassReference(acr string) error {
|
||||
if v.config.acr != nil {
|
||||
return v.config.acr(acr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (v *DefaultVerifier) checkAuthTime(authTime time.Time) error {
|
||||
if v.config.maxAge == 0 {
|
||||
return nil
|
||||
}
|
||||
if authTime.IsZero() {
|
||||
return ErrAuthTimeNotPresent()
|
||||
}
|
||||
authTime = authTime.Round(time.Second)
|
||||
maxAge := v.now().Add(-v.config.maxAge).Round(time.Second)
|
||||
if authTime.Before(maxAge) {
|
||||
return ErrAuthTimeToOld(maxAge, authTime)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) {
|
||||
return tokenString, nil //TODO: impl
|
||||
}
|
||||
|
||||
func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error {
|
||||
if atHash == "" {
|
||||
return nil //TODO: return error
|
||||
}
|
||||
|
||||
actual, err := oidc.ClaimHash(accessToken, sigAlgorithm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if actual != atHash {
|
||||
return nil //TODO: error
|
||||
}
|
||||
return nil
|
||||
}
|
13
pkg/rp/delegation.go
Normal file
13
pkg/rp/delegation.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
//DelegationTokenRequest is an implementation of TokenExchangeRequest
|
||||
//it exchanges a "urn:ietf:params:oauth:token-type:access_token" with an optional
|
||||
//"urn:ietf:params:oauth:token-type:access_token" actor token for a
|
||||
//"urn:ietf:params:oauth:token-type:access_token" delegation token
|
||||
func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest {
|
||||
return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...)
|
||||
}
|
58
pkg/rp/error.go
Normal file
58
pkg/rp/error.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIssuerInvalid = func(expected, actual string) *validationError {
|
||||
return ValidationError("Issuer does not match. Expected: %s, got: %s", expected, actual)
|
||||
}
|
||||
ErrAudienceMissingClientID = func(clientID string) *validationError {
|
||||
return ValidationError("Audience is not valid. Audience must contain client_id (%s)", clientID)
|
||||
}
|
||||
ErrAzpMissing = func() *validationError {
|
||||
return ValidationError("Authorized Party is not set. If Token is valid for multiple audiences, azp must not be empty")
|
||||
}
|
||||
ErrAzpInvalid = func(azp, clientID string) *validationError {
|
||||
return ValidationError("Authorized Party is not valid. azp (%s) must be equal to client_id (%s)", azp, clientID)
|
||||
}
|
||||
ErrExpInvalid = func(exp time.Time) *validationError {
|
||||
return ValidationError("Token has expired %v", exp)
|
||||
}
|
||||
ErrIatInFuture = func(exp, now time.Time) *validationError {
|
||||
return ValidationError("IssuedAt of token is in the future (%v, now with offset: %v)", exp, now)
|
||||
}
|
||||
ErrIatToOld = func(maxAge, iat time.Time) *validationError {
|
||||
return ValidationError("IssuedAt of token must not be older than %v, but was %v (%v to old)", maxAge, iat, maxAge.Sub(iat))
|
||||
}
|
||||
ErrNonceInvalid = func(expected, actual string) *validationError {
|
||||
return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual)
|
||||
}
|
||||
ErrAcrInvalid = func(expected []string, actual string) *validationError {
|
||||
return ValidationError("acr is invalid. Expected one of: %v, got: %s", expected, actual)
|
||||
}
|
||||
|
||||
ErrAuthTimeNotPresent = func() *validationError {
|
||||
return ValidationError("claim `auth_time` of token is missing")
|
||||
}
|
||||
ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError {
|
||||
return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime))
|
||||
}
|
||||
ErrSignatureInvalidPayload = func() *validationError {
|
||||
return ValidationError("Signature does not match Payload")
|
||||
}
|
||||
)
|
||||
|
||||
func ValidationError(message string, args ...interface{}) *validationError {
|
||||
return &validationError{fmt.Sprintf(message, args...)} //TODO: impl
|
||||
}
|
||||
|
||||
type validationError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (v *validationError) Error() string {
|
||||
return v.message
|
||||
}
|
166
pkg/rp/jwks.go
Normal file
166
pkg/rp/jwks.go
Normal file
|
@ -0,0 +1,166 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/caos/oidc/pkg/utils"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet {
|
||||
return &remoteKeySet{httpClient: client, jwksURL: jwksURL}
|
||||
}
|
||||
|
||||
type remoteKeySet struct {
|
||||
jwksURL string
|
||||
httpClient *http.Client
|
||||
|
||||
// guard all other fields
|
||||
mu sync.Mutex
|
||||
|
||||
// inflight suppresses parallel execution of updateKeys and allows
|
||||
// multiple goroutines to wait for its result.
|
||||
inflight *inflight
|
||||
|
||||
// A set of cached keys and their expiry.
|
||||
cachedKeys []jose.JSONWebKey
|
||||
}
|
||||
|
||||
// inflight is used to wait on some in-flight request from multiple goroutines.
|
||||
type inflight struct {
|
||||
doneCh chan struct{}
|
||||
|
||||
keys []jose.JSONWebKey
|
||||
err error
|
||||
}
|
||||
|
||||
func newInflight() *inflight {
|
||||
return &inflight{doneCh: make(chan struct{})}
|
||||
}
|
||||
|
||||
// wait returns a channel that multiple goroutines can receive on. Once it returns
|
||||
// a value, the inflight request is done and result() can be inspected.
|
||||
func (i *inflight) wait() <-chan struct{} {
|
||||
return i.doneCh
|
||||
}
|
||||
|
||||
// done can only be called by a single goroutine. It records the result of the
|
||||
// inflight request and signals other goroutines that the result is safe to
|
||||
// inspect.
|
||||
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
|
||||
i.keys = keys
|
||||
i.err = err
|
||||
close(i.doneCh)
|
||||
}
|
||||
|
||||
// result cannot be called until the wait() channel has returned a value.
|
||||
func (i *inflight) result() ([]jose.JSONWebKey, error) {
|
||||
return i.keys, i.err
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
|
||||
// We don't support JWTs signed with multiple signatures.
|
||||
keyID := ""
|
||||
for _, sig := range jws.Signatures {
|
||||
keyID = sig.Header.KeyID
|
||||
break
|
||||
}
|
||||
|
||||
keys := r.keysFromCache()
|
||||
payload, err, ok := checkKey(keyID, keys, jws)
|
||||
if ok {
|
||||
return payload, err
|
||||
}
|
||||
|
||||
keys, err = r.keysFromRemote(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching keys %v", err)
|
||||
}
|
||||
|
||||
payload, err, ok = checkKey(keyID, keys, jws)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid kid")
|
||||
}
|
||||
return payload, err
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.cachedKeys
|
||||
}
|
||||
|
||||
// keysFromRemote syncs the key set from the remote set, records the values in the
|
||||
// cache, and returns the key set.
|
||||
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
|
||||
// Need to lock to inspect the inflight request field.
|
||||
r.mu.Lock()
|
||||
// If there's not a current inflight request, create one.
|
||||
if r.inflight == nil {
|
||||
r.inflight = newInflight()
|
||||
|
||||
// This goroutine has exclusive ownership over the current inflight
|
||||
// request. It releases the resource by nil'ing the inflight field
|
||||
// once the goroutine is done.
|
||||
go r.updateKeys(ctx)
|
||||
}
|
||||
inflight := r.inflight
|
||||
r.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-inflight.wait():
|
||||
return inflight.result()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) updateKeys(ctx context.Context) {
|
||||
// Sync keys and finish inflight when that's done.
|
||||
keys, err := r.fetchRemoteKeys(ctx)
|
||||
|
||||
r.inflight.done(keys, err)
|
||||
|
||||
// Lock to update the keys and indicate that there is no longer an
|
||||
// inflight request.
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if err == nil {
|
||||
r.cachedKeys = keys
|
||||
}
|
||||
|
||||
// Free inflight so a different request can run.
|
||||
r.inflight = nil
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
|
||||
req, err := http.NewRequest("GET", r.jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: can't create request: %v", err)
|
||||
}
|
||||
|
||||
keySet := new(jose.JSONWebKeySet)
|
||||
if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil {
|
||||
return nil, fmt.Errorf("oidc: failed to get keys: %v", err)
|
||||
}
|
||||
|
||||
return keySet.Keys, nil
|
||||
}
|
||||
|
||||
func checkKey(keyID string, keys []jose.JSONWebKey, jws *jose.JSONWebSignature) ([]byte, error, bool) {
|
||||
for _, key := range keys {
|
||||
if keyID == "" || key.KeyID == keyID {
|
||||
payload, err := jws.Verify(&key)
|
||||
return payload, err, true
|
||||
}
|
||||
}
|
||||
return nil, nil, false
|
||||
}
|
105
pkg/rp/relaying_party.go
Normal file
105
pkg/rp/relaying_party.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
//RelayingParty declares the minimal interface for oidc clients
|
||||
type RelayingParty interface {
|
||||
|
||||
//AuthURL returns the authorization endpoint with a given state
|
||||
AuthURL(state string, opts ...AuthURLOpt) string
|
||||
|
||||
//AuthURLHandler should implement the AuthURL func as http.HandlerFunc
|
||||
//(redirecting to the auth endpoint)
|
||||
AuthURLHandler(state string) http.HandlerFunc
|
||||
|
||||
//CodeExchange implements the OIDC Token Request (oauth2 Authorization Code Grant)
|
||||
//returning an `Access Token` and `ID Token Claims`
|
||||
CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (*oidc.Tokens, error)
|
||||
|
||||
//CodeExchangeHandler extends the CodeExchange func,
|
||||
//calling the provided callback func on success with additional returned `state`
|
||||
CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string)) http.HandlerFunc
|
||||
|
||||
//ClientCredentials implements the oauth2 Client Credentials Grant
|
||||
//requesting an `Access Token` for the client itself, without user context
|
||||
ClientCredentials(ctx context.Context, scopes ...string) (*oauth2.Token, error)
|
||||
|
||||
//Introspects calls the Introspect Endpoint
|
||||
//for validating an (access) token
|
||||
// Introspect(ctx context.Context, token string) (TokenIntrospectResponse, error)
|
||||
|
||||
//Userinfo implements the OIDC Userinfo call
|
||||
//returning the info of the user for the requested scopes of an access token
|
||||
Userinfo()
|
||||
}
|
||||
|
||||
//PasswortGrantRP extends the `RelayingParty` interface with the oauth2 `Password Grant`
|
||||
//
|
||||
//This interface is separated from the standard `RelayingParty` interface as the `password grant`
|
||||
//is part of the oauth2 and therefore OIDC specification, but should only be used when there's no
|
||||
//other possibility, so IMHO never ever. Ever.
|
||||
type PasswortGrantRP interface {
|
||||
RelayingParty
|
||||
|
||||
//PasswordGrant implements the oauth2 `Password Grant`,
|
||||
//requesting an access token with the users `username` and `password`
|
||||
PasswordGrant(context.Context, string, string) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
CallbackURL string
|
||||
Issuer string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type OptionFunc func(RelayingParty)
|
||||
|
||||
type Endpoints struct {
|
||||
oauth2.Endpoint
|
||||
IntrospectURL string
|
||||
UserinfoURL string
|
||||
JKWsURL string
|
||||
}
|
||||
|
||||
func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints {
|
||||
return Endpoints{
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: discoveryConfig.AuthorizationEndpoint,
|
||||
AuthStyle: oauth2.AuthStyleAutoDetect,
|
||||
TokenURL: discoveryConfig.TokenEndpoint,
|
||||
},
|
||||
IntrospectURL: discoveryConfig.IntrospectionEndpoint,
|
||||
UserinfoURL: discoveryConfig.UserinfoEndpoint,
|
||||
JKWsURL: discoveryConfig.JwksURI,
|
||||
}
|
||||
}
|
||||
|
||||
type AuthURLOpt func() []oauth2.AuthCodeOption
|
||||
|
||||
//WithCodeChallenge sets the `code_challenge` params in the auth request
|
||||
func WithCodeChallenge(codeChallenge string) AuthURLOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type CodeExchangeOpt func() []oauth2.AuthCodeOption
|
||||
|
||||
//WithCodeVerifier sets the `code_verifier` param in the token request
|
||||
func WithCodeVerifier(codeVerifier string) CodeExchangeOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("code_verifier", codeVerifier)}
|
||||
}
|
||||
}
|
27
pkg/rp/tockenexchange.go
Normal file
27
pkg/rp/tockenexchange.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc/grants/tokenexchange"
|
||||
)
|
||||
|
||||
//TokenExchangeRP extends the `RelayingParty` interface for the *draft* oauth2 `Token Exchange`
|
||||
type TokenExchangeRP interface {
|
||||
RelayingParty
|
||||
|
||||
//TokenExchange implement the `Token Echange Grant` exchanging some token for an other
|
||||
TokenExchange(context.Context, *tokenexchange.TokenExchangeRequest) (*oauth2.Token, error)
|
||||
}
|
||||
|
||||
//DelegationTokenExchangeRP extends the `TokenExchangeRP` interface
|
||||
//for the specific `delegation token` request
|
||||
type DelegationTokenExchangeRP interface {
|
||||
TokenExchangeRP
|
||||
|
||||
//DelegationTokenExchange implement the `Token Exchange Grant`
|
||||
//providing an access token in request for a `delegation` token for a given resource / audience
|
||||
DelegationTokenExchange(context.Context, string, ...tokenexchange.TokenExchangeOption) (*oauth2.Token, error)
|
||||
}
|
15
pkg/rp/verifier.go
Normal file
15
pkg/rp/verifier.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package rp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/caos/oidc/pkg/oidc"
|
||||
)
|
||||
|
||||
//Verifier implement the Token Response Validation as defined in OIDC specification
|
||||
//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation
|
||||
type Verifier interface {
|
||||
|
||||
//Verify checks the access_token and id_token and returns the `id token claims`
|
||||
Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error)
|
||||
}
|
110
pkg/utils/cookie.go
Normal file
110
pkg/utils/cookie.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
)
|
||||
|
||||
type CookieHandler struct {
|
||||
securecookie *securecookie.SecureCookie
|
||||
secureOnly bool
|
||||
sameSite http.SameSite
|
||||
maxAge int
|
||||
domain string
|
||||
}
|
||||
|
||||
func NewCookieHandler(hashKey, encryptKey []byte, opts ...CookieHandlerOpt) *CookieHandler {
|
||||
c := &CookieHandler{
|
||||
securecookie: securecookie.New(hashKey, encryptKey),
|
||||
secureOnly: true,
|
||||
sameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type CookieHandlerOpt func(*CookieHandler)
|
||||
|
||||
func WithUnsecure() CookieHandlerOpt {
|
||||
return func(c *CookieHandler) {
|
||||
c.secureOnly = false
|
||||
}
|
||||
}
|
||||
|
||||
func WithSameSite(sameSite http.SameSite) CookieHandlerOpt {
|
||||
return func(c *CookieHandler) {
|
||||
c.sameSite = sameSite
|
||||
}
|
||||
}
|
||||
|
||||
func WithMaxAge(maxAge int) CookieHandlerOpt {
|
||||
return func(c *CookieHandler) {
|
||||
c.maxAge = maxAge
|
||||
c.securecookie.MaxAge(maxAge)
|
||||
}
|
||||
}
|
||||
|
||||
func WithDomain(domain string) CookieHandlerOpt {
|
||||
return func(c *CookieHandler) {
|
||||
c.domain = domain
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CookieHandler) CheckCookie(r *http.Request, name string) (string, error) {
|
||||
cookie, err := r.Cookie(name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var value string
|
||||
if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string, error) {
|
||||
value, err := c.CheckCookie(r, name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if value != r.FormValue(name) {
|
||||
return "", errors.New(name + " does not compare")
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error {
|
||||
encoded, err := c.securecookie.Encode(name, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: encoded,
|
||||
Domain: c.domain,
|
||||
Path: "/",
|
||||
MaxAge: c.maxAge,
|
||||
HttpOnly: true,
|
||||
Secure: c.secureOnly,
|
||||
SameSite: c.sameSite,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Domain: c.domain,
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: c.secureOnly,
|
||||
SameSite: c.sameSite,
|
||||
})
|
||||
}
|
70
pkg/utils/crypto.go
Normal file
70
pkg/utils/crypto.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
func EncryptAES(data string, key string) (string, error) {
|
||||
encrypted, err := EncryptBytesAES([]byte(data), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
func EncryptBytesAES(plainText []byte, key string) ([]byte, error) {
|
||||
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cipherText := make([]byte, aes.BlockSize+len(plainText))
|
||||
iv := cipherText[:aes.BlockSize]
|
||||
if _, err = io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream := cipher.NewCFBEncrypter(block, iv)
|
||||
stream.XORKeyStream(cipherText[aes.BlockSize:], plainText)
|
||||
|
||||
return cipherText, nil
|
||||
}
|
||||
|
||||
func DecryptAES(data string, key string) (string, error) {
|
||||
text, err := base64.URLEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
decrypted, err := DecryptBytesAES(text, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(decrypted), nil
|
||||
}
|
||||
|
||||
func DecryptBytesAES(cipherText []byte, key string) ([]byte, error) {
|
||||
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(cipherText) < aes.BlockSize {
|
||||
err = errors.New("Ciphertext block size is too short!")
|
||||
return nil, err
|
||||
}
|
||||
iv := cipherText[:aes.BlockSize]
|
||||
cipherText = cipherText[aes.BlockSize:]
|
||||
|
||||
stream := cipher.NewCFBDecrypter(block, iv)
|
||||
stream.XORKeyStream(cipherText, cipherText)
|
||||
|
||||
return cipherText, err
|
||||
}
|
30
pkg/utils/hash.go
Normal file
30
pkg/utils/hash.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) {
|
||||
switch sigAlgorithm {
|
||||
case jose.RS256, jose.ES256, jose.PS256:
|
||||
return sha256.New(), nil
|
||||
case jose.RS384, jose.ES384, jose.PS384:
|
||||
return sha512.New384(), nil
|
||||
case jose.RS512, jose.ES512, jose.PS512:
|
||||
return sha512.New(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm)
|
||||
}
|
||||
}
|
||||
|
||||
func HashString(hash hash.Hash, s string) string {
|
||||
hash.Write([]byte(s)) // hash documents that Write will never return an error
|
||||
sum := hash.Sum(nil)[:hash.Size()/2]
|
||||
return base64.RawURLEncoding.EncodeToString(sum)
|
||||
}
|
67
pkg/utils/http.go
Normal file
67
pkg/utils/http.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultHTTPClient = &http.Client{
|
||||
Timeout: time.Duration(30 * time.Second),
|
||||
}
|
||||
)
|
||||
|
||||
func FormRequest(endpoint string, request interface{}) (*http.Request, error) {
|
||||
form := make(map[string][]string)
|
||||
encoder := schema.NewEncoder()
|
||||
if err := encoder.Encode(request, form); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body := strings.NewReader(url.Values(form).Encode())
|
||||
req, err := http.NewRequest("POST", endpoint, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func HttpRequest(client *http.Client, req *http.Request, response interface{}) error {
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read response body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("http status not ok: %s %s", resp.Status, body)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal response: %v %s", err, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) {
|
||||
values := make(map[string][]string)
|
||||
err := encoder.Encode(resp, values)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
v := url.Values(values)
|
||||
return v.Encode(), nil
|
||||
}
|
21
pkg/utils/marshal.go
Normal file
21
pkg/utils/marshal.go
Normal file
|
@ -0,0 +1,21 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func MarshalJSON(w http.ResponseWriter, i interface{}) {
|
||||
b, err := json.Marshal(i)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("content-type", "application/json")
|
||||
_, err = w.Write(b)
|
||||
if err != nil {
|
||||
logrus.Error("error writing response")
|
||||
}
|
||||
}
|
10
pkg/utils/strings.go
Normal file
10
pkg/utils/strings.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package utils
|
||||
|
||||
func Contains(list []string, needle string) bool {
|
||||
for _, item := range list {
|
||||
if item == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
48
pkg/utils/strings_test.go
Normal file
48
pkg/utils/strings_test.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestContains(t *testing.T) {
|
||||
type args struct {
|
||||
list []string
|
||||
needle string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
"empty list false",
|
||||
args{[]string{}, "needle"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"list not containing false",
|
||||
args{[]string{"list"}, "needle"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"list not containing empty needle false",
|
||||
args{[]string{"list", "needle"}, ""},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"list containing true",
|
||||
args{[]string{"list", "needle"}, "needle"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"list containing empty needle true",
|
||||
args{[]string{"list", "needle", ""}, ""},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := Contains(tt.args.list, tt.args.needle); got != tt.want {
|
||||
t.Errorf("Contains() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue