commit 6d0890e28013cab63010df3a68ad60616ba5fb69 Author: Livio Amstutz Date: Fri Jan 31 15:22:16 2020 +0100 initial commit diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..934fed6 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,29 @@ +name: Release +on: push +jobs: + test: + runs-on: ubuntu-18.04 + strategy: + matrix: + go: ['1.11', '1.12', '1.13'] + name: Go ${{ matrix.go }} test + steps: + - uses: actions/checkout@master + - name: Setup go + uses: actions/setup-go@v1 + with: + go-version: ${{ matrix.go }} + - run: go test -race ./pkg/... + release: + runs-on: ubuntu-18.04 + needs: [test] + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - name: Source checkout + uses: actions/checkout@v1 + with: + fetch-depth: 1 + - name: Create Version + uses: caos/semantic-release@v0.2.4 + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f94a21c --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +**/__debug_bin +.vscode +.DS_Store diff --git a/.releaserc.js b/.releaserc.js new file mode 100644 index 0000000..cf2f499 --- /dev/null +++ b/.releaserc.js @@ -0,0 +1,7 @@ +module.exports = { + branch: 'master', + plugins: [ + "@semantic-release/commit-analyzer", + "@semantic-release/release-notes-generator" + ] + }; \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..cbe2479 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +# oidc + +![semantic-release](https://img.shields.io/badge/%20%20%F0%9F%93%A6%F0%9F%9A%80-semantic--release-e10079.svg) +![Github Release Badge](https://github.com/caos/oidc/workflows/Release/badge.svg) +[![GitHub release](https://img.shields.io/github/release/caos/oidc)](https://GitHub.com/caos/oidc/releases/) +[![GitHub license](https://img.shields.io/github/license/caos/oidc)](https://github.com/caos/oidc/blob/master/LICENSE) + +OpenID Connect SDK (client and server) for Go diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..f7ecc88 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,42 @@ +# Security Policy + +At caos we are extremely grateful for security aware people that disclose vulnerabilities to us and the open source community. All reports will be investigated by our team. + +## Supported Versions + +After the initial Release the following version support will apply + +| Version | Supported | +| ------- | ------------------ | +| 1.x.x | :white_check_mark: (note yet available) | +| 0.x.x | :x: | + +## Reporting a vulnerability + +To file a incident, please disclose by email to security@caos.ch with the security details. + +At the moment GPG encryption is no yet supported, however you may sign your message at will. + +### When should I report a vulnerability + +* You think you discovered a ... + * ... potential security vulnerability in the SDK + * ... vulnerability in another project that this SDK bases on +* For projects with their own vulnerability reporting and disclosure process, please report it directly there + +### When should I NOT report a vulnerability + +* You need help applying security related updates +* Your issue is not security related + +## Security Vulnerability Response + +TBD + +## Public Disclosure + +All accepted and mitigated vulnerabilitys will be published on the [Github Security Page](https://github.com/caos/oidc/security/advisories) + +### Timing + +We think it is crucial to publish advisories `ASAP` as mitigations are ready. But due to the unknown nature of the discloures the time frame can range from 7 to 90 days. diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..1df929b --- /dev/null +++ b/doc.go @@ -0,0 +1 @@ +package oidc diff --git a/example/client/api/api.go b/example/client/api/api.go new file mode 100644 index 0000000..6e1b0bd --- /dev/null +++ b/example/client/api/api.go @@ -0,0 +1,90 @@ +package main + +// import ( +// "encoding/json" +// "fmt" +// "log" +// "net/http" +// "os" + +// "github.com/caos/oidc/pkg/oidc" +// "github.com/caos/oidc/pkg/oidc/rp" +// "github.com/caos/utils/logging" +// ) + +// const ( +// publicURL string = "/public" +// protectedURL string = "/protected" +// protectedExchangeURL string = "/protected/exchange" +// ) + +func main() { + // clientID := os.Getenv("CLIENT_ID") + // clientSecret := os.Getenv("CLIENT_SECRET") + // issuer := os.Getenv("ISSUER") + // port := os.Getenv("PORT") + + // // ctx := context.Background() + + // providerConfig := &oidc.ProviderConfig{ + // ClientID: clientID, + // ClientSecret: clientSecret, + // Issuer: issuer, + // } + // provider, err := rp.NewDefaultProvider(providerConfig) + // logging.Log("APP-nx6PeF").OnError(err).Panic("error creating provider") + + // http.HandleFunc(publicURL, func(w http.ResponseWriter, r *http.Request) { + // w.Write([]byte("OK")) + // }) + + // http.HandleFunc(protectedURL, func(w http.ResponseWriter, r *http.Request) { + // ok, token := checkToken(w, r) + // if !ok { + // return + // } + // resp, err := provider.Introspect(r.Context(), token) + // if err != nil { + // http.Error(w, err.Error(), http.StatusForbidden) + // return + // } + // data, err := json.Marshal(resp) + // if err != nil { + // http.Error(w, err.Error(), http.StatusInternalServerError) + // return + // } + // w.Write(data) + // }) + + // http.HandleFunc(protectedExchangeURL, func(w http.ResponseWriter, r *http.Request) { + // ok, token := checkToken(w, r) + // if !ok { + // return + // } + // tokens, err := provider.DelegationTokenExchange(r.Context(), token, oidc.WithResource([]string{"Test"})) + // if err != nil { + // http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) + // return + // } + + // data, err := json.Marshal(tokens) + // if err != nil { + // http.Error(w, err.Error(), http.StatusInternalServerError) + // return + // } + // w.Write(data) + // }) + + // lis := fmt.Sprintf("127.0.0.1:%s", port) + // log.Printf("listening on http://%s/", lis) + // log.Fatal(http.ListenAndServe(lis, nil)) + // } + + // func checkToken(w http.ResponseWriter, r *http.Request) (bool, string) { + // token := r.Header.Get("authorization") + // if token == "" { + // http.Error(w, "Auth header missing", http.StatusUnauthorized) + // return false, "" + // } + // return true, token +} diff --git a/example/client/app/app.go b/example/client/app/app.go new file mode 100644 index 0000000..f1b99d7 --- /dev/null +++ b/example/client/app/app.go @@ -0,0 +1,96 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "os" + + "github.com/sirupsen/logrus" + + "github.com/google/uuid" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/rp" + "github.com/caos/oidc/pkg/utils" +) + +var ( + callbackPath string = "/auth/callback" + key []byte = []byte("test1234test1234") +) + +func main() { + clientID := os.Getenv("CLIENT_ID") + clientSecret := os.Getenv("CLIENT_SECRET") + issuer := os.Getenv("ISSUER") + port := os.Getenv("PORT") + + ctx := context.Background() + + rpConfig := &rp.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Issuer: issuer, + CallbackURL: fmt.Sprintf("http://localhost:%v%v", port, callbackPath), + Scopes: []string{"openid", "profile", "email"}, + } + cookieHandler := utils.NewCookieHandler(key, key, utils.WithUnsecure()) + provider, err := rp.NewDefaultRP(rpConfig, rp.WithCookieHandler(cookieHandler)) //rp.WithPKCE(cookieHandler)) //, + if err != nil { + logrus.Fatalf("error creating provider %s", err.Error()) + } + + // state := "foobar" + state := uuid.New().String() + + http.Handle("/login", provider.AuthURLHandler(state)) + // http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { + // http.Redirect(w, r, provider.AuthURL(state), http.StatusFound) + // }) + + // http.HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) { + // tokens, err := provider.CodeExchange(ctx, r.URL.Query().Get("code")) + // if err != nil { + // http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) + // return + // } + // data, err := json.Marshal(tokens) + // if err != nil { + // http.Error(w, err.Error(), http.StatusInternalServerError) + // return + // } + // w.Write(data) + // }) + + marshal := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string) { + _ = state + data, err := json.Marshal(tokens) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Write(data) + } + + http.Handle(callbackPath, provider.CodeExchangeHandler(marshal)) + + http.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + tokens, err := provider.ClientCredentials(ctx, "scope") + if err != nil { + http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) + return + } + + data, err := json.Marshal(tokens) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Write(data) + }) + lis := fmt.Sprintf("127.0.0.1:%s", port) + logrus.Infof("listening on http://%s/", lis) + logrus.Fatal(http.ListenAndServe("127.0.0.1:5556", nil)) +} diff --git a/example/doc.go b/example/doc.go new file mode 100644 index 0000000..f7ec372 --- /dev/null +++ b/example/doc.go @@ -0,0 +1 @@ +package example diff --git a/example/internal/mock/storage.go b/example/internal/mock/storage.go new file mode 100644 index 0000000..37f47b1 --- /dev/null +++ b/example/internal/mock/storage.go @@ -0,0 +1,250 @@ +package mock + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "time" + + "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/op" +) + +type AuthStorage struct { + key *rsa.PrivateKey +} + +func NewAuthStorage() op.Storage { + reader := rand.Reader + bitSize := 2048 + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + panic(err) + } + return &AuthStorage{ + key: key, + } +} + +type AuthRequest struct { + ID string + ResponseType oidc.ResponseType + RedirectURI string + Nonce string + ClientID string + CodeChallenge *oidc.CodeChallenge +} + +func (a *AuthRequest) GetACR() string { + return "" +} + +func (a *AuthRequest) GetAMR() []string { + return []string{ + "password", + } +} + +func (a *AuthRequest) GetAudience() []string { + return []string{ + a.ClientID, + } +} + +func (a *AuthRequest) GetAuthTime() time.Time { + return time.Now().UTC() +} + +func (a *AuthRequest) GetClientID() string { + return a.ClientID +} + +func (a *AuthRequest) GetCode() string { + return "code" +} + +func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge { + return a.CodeChallenge +} + +func (a *AuthRequest) GetID() string { + return a.ID +} + +func (a *AuthRequest) GetNonce() string { + return a.Nonce +} + +func (a *AuthRequest) GetRedirectURI() string { + return a.RedirectURI + // return "http://localhost:5556/auth/callback" +} + +func (a *AuthRequest) GetResponseType() oidc.ResponseType { + return a.ResponseType +} + +func (a *AuthRequest) GetScopes() []string { + return []string{ + "openid", + "profile", + "email", + } +} + +func (a *AuthRequest) GetState() string { + return "" +} + +func (a *AuthRequest) GetSubject() string { + return "sub" +} + +func (a *AuthRequest) Done() bool { + return true +} + +var ( + a = &AuthRequest{} + t bool +) + +func (s *AuthStorage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest) (op.AuthRequest, error) { + a = &AuthRequest{ID: "id", ClientID: authReq.ClientID, ResponseType: authReq.ResponseType, Nonce: authReq.Nonce, RedirectURI: authReq.RedirectURI} + if authReq.CodeChallenge != "" { + a.CodeChallenge = &oidc.CodeChallenge{ + Challenge: authReq.CodeChallenge, + Method: authReq.CodeChallengeMethod, + } + } + t = false + return a, nil +} +func (s *AuthStorage) AuthRequestByCode(context.Context, string) (op.AuthRequest, error) { + return a, nil +} +func (s *AuthStorage) DeleteAuthRequest(context.Context, string) error { + t = true + return nil +} +func (s *AuthStorage) AuthRequestByID(_ context.Context, id string) (op.AuthRequest, error) { + if id != "id" || t { + return nil, errors.New("not found") + } + return a, nil +} +func (s *AuthStorage) GetSigningKey(_ context.Context) (*jose.SigningKey, error) { + return &jose.SigningKey{Algorithm: jose.RS256, Key: s.key}, nil +} +func (s *AuthStorage) GetKey(_ context.Context) (*rsa.PrivateKey, error) { + return s.key, nil +} +func (s *AuthStorage) SaveKeyPair(ctx context.Context) (*jose.SigningKey, error) { + return s.GetSigningKey(ctx) +} +func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) { + pubkey := s.key.Public() + return &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + jose.JSONWebKey{Key: pubkey, Use: "sig", Algorithm: "RS256", KeyID: "1"}, + }, + }, nil +} + +func (s *AuthStorage) GetClientByClientID(_ context.Context, id string) (op.Client, error) { + if id == "none" { + return nil, errors.New("not found") + } + var appType op.ApplicationType + var authMethod op.AuthMethod + var accessTokenType op.AccessTokenType + if id == "web" { + appType = op.ApplicationTypeWeb + authMethod = op.AuthMethodBasic + accessTokenType = op.AccessTokenTypeBearer + } else if id == "native" { + appType = op.ApplicationTypeNative + authMethod = op.AuthMethodNone + accessTokenType = op.AccessTokenTypeBearer + } else { + appType = op.ApplicationTypeUserAgent + authMethod = op.AuthMethodNone + accessTokenType = op.AccessTokenTypeJWT + } + return &ConfClient{ID: id, applicationType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil +} + +func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, id string, _ string) error { + return nil +} + +func (s *AuthStorage) GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error) { + return &oidc.Userinfo{ + Subject: a.GetSubject(), + Address: &oidc.UserinfoAddress{ + StreetAddress: "Hjkhkj 789\ndsf", + }, + UserinfoEmail: oidc.UserinfoEmail{ + Email: "test", + EmailVerified: true, + }, + UserinfoPhone: oidc.UserinfoPhone{ + PhoneNumber: "sadsa", + PhoneNumberVerified: true, + }, + UserinfoProfile: oidc.UserinfoProfile{ + UpdatedAt: time.Now(), + }, + // Claims: map[string]interface{}{ + // "test": "test", + // "hkjh": "", + // }, + }, nil +} + +type ConfClient struct { + applicationType op.ApplicationType + authMethod op.AuthMethod + ID string + accessTokenType op.AccessTokenType +} + +func (c *ConfClient) GetID() string { + return c.ID +} +func (c *ConfClient) RedirectURIs() []string { + return []string{ + "https://registered.com/callback", + "http://localhost:9999/callback", + "http://localhost:5556/auth/callback", + "custom://callback", + "https://localhost:8443/test/a/instructions-example/callback", + "https://op.certification.openid.net:62064/authz_cb", + "https://op.certification.openid.net:62064/authz_post", + } +} + +func (c *ConfClient) LoginURL(id string) string { + return "login?id=" + id +} + +func (c *ConfClient) ApplicationType() op.ApplicationType { + return c.applicationType +} + +func (c *ConfClient) GetAuthMethod() op.AuthMethod { + return c.authMethod +} + +func (c *ConfClient) AccessTokenLifetime() time.Duration { + return time.Duration(5 * time.Minute) +} +func (c *ConfClient) IDTokenLifetime() time.Duration { + return time.Duration(5 * time.Minute) +} +func (c *ConfClient) AccessTokenType() op.AccessTokenType { + return c.accessTokenType +} diff --git a/example/server/default/default.go b/example/server/default/default.go new file mode 100644 index 0000000..0b0bb8e --- /dev/null +++ b/example/server/default/default.go @@ -0,0 +1,65 @@ +package main + +import ( + "context" + "crypto/sha256" + "html/template" + "log" + "net/http" + + "github.com/gorilla/mux" + + "github.com/caos/oidc/example/internal/mock" + "github.com/caos/oidc/pkg/op" +) + +func main() { + ctx := context.Background() + config := &op.Config{ + Issuer: "http://localhost:9998/", + CryptoKey: sha256.Sum256([]byte("test")), + Port: "9998", + } + storage := mock.NewAuthStorage() + handler, err := op.NewDefaultOP(ctx, config, storage, op.WithCustomTokenEndpoint("test")) + if err != nil { + log.Fatal(err) + } + router := handler.HttpHandler().Handler.(*mux.Router) + router.Methods("GET").Path("/login").HandlerFunc(HandleLogin) + router.Methods("POST").Path("/login").HandlerFunc(HandleCallback) + op.Start(ctx, handler) + <-ctx.Done() +} + +func HandleLogin(w http.ResponseWriter, r *http.Request) { + tpl := ` + + + + + Login + + +
+ + +
+ + ` + t, err := template.New("login").Parse(tpl) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + err = t.Execute(w, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func HandleCallback(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + client := r.FormValue("client") + http.Redirect(w, r, "/authorize/"+client, http.StatusFound) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..da7059a --- /dev/null +++ b/go.mod @@ -0,0 +1,25 @@ +module github.com/caos/oidc + +go 1.13 + +require ( + github.com/golang/mock v1.3.1 + github.com/golang/protobuf v1.3.2 // indirect + github.com/google/uuid v1.1.1 + github.com/gorilla/mux v1.7.3 + github.com/gorilla/schema v1.1.0 + github.com/gorilla/securecookie v1.1.1 + github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect + github.com/kr/pretty v0.1.0 // indirect + github.com/sirupsen/logrus v1.4.2 + github.com/stretchr/testify v1.4.0 + golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 // indirect + golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 + golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c + golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 // indirect + golang.org/x/text v0.3.2 + google.golang.org/appengine v1.6.5 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/square/go-jose.v2 v2.4.0 + gopkg.in/yaml.v2 v2.2.3 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..54a2ca8 --- /dev/null +++ b/go.sum @@ -0,0 +1,74 @@ +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= +github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= +github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/schema v1.1.0 h1:CamqUDOFUBqzrvxuz2vEwo8+SUdwsluFh7IlzJh30LY= +github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1 h1:anGSYQpPhQwXlwsu5wmfq0nWkCNaMEMUwAv13Y92hd8= +golang.org/x/crypto v0.0.0-20191128160524-b544559bb6d1/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933 h1:e6HwijUxhDe+hPNjZQQn9bA5PW3vNmnN64U2ZW759Lk= +golang.org/x/net v0.0.0-20191126235420-ef20fe5d7933/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c h1:HjRaKPaiWks0f5tA6ELVF7ZfqSppfPwOEEAvsrKUTO4= +golang.org/x/oauth2 v0.0.0-20191122200657-5d9234df094c/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9 h1:ZBzSG/7F4eNKz2L3GE9o300RX0Az1Bw5HF7PDraD+qU= +golang.org/x/sys v0.0.0-20191128015809-6d18c012aee9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A= +gopkg.in/square/go-jose.v2 v2.4.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3 h1:fvjTMHxHEw/mxHbtzPi3JCcKXQRAnQTBRo6YCJSVHKI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/pkg/oidc/authorization.go b/pkg/oidc/authorization.go new file mode 100644 index 0000000..02c5603 --- /dev/null +++ b/pkg/oidc/authorization.go @@ -0,0 +1,151 @@ +package oidc + +import ( + "errors" + "strings" + + "golang.org/x/text/language" +) + +const ( + ScopeOpenID = "openid" + + ResponseTypeCode ResponseType = "code" + ResponseTypeIDToken ResponseType = "id_token token" + ResponseTypeIDTokenOnly ResponseType = "id_token" + + DisplayPage Display = "page" + DisplayPopup Display = "popup" + DisplayTouch Display = "touch" + DisplayWAP Display = "wap" + + PromptNone Prompt = "none" + PromptLogin Prompt = "login" + PromptConsent Prompt = "consent" + PromptSelectAccount Prompt = "select_account" + + GrantTypeCode GrantType = "authorization_code" + + BearerToken = "Bearer" +) + +var displayValues = map[string]Display{ + "page": DisplayPage, + "popup": DisplayPopup, + "touch": DisplayTouch, + "wap": DisplayWAP, +} + +//AuthRequest according to: +//https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest +// +type AuthRequest struct { + ID string + Scopes Scopes `schema:"scope"` + ResponseType ResponseType `schema:"response_type"` + ClientID string `schema:"client_id"` + RedirectURI string `schema:"redirect_uri"` //TODO: type + + State string `schema:"state"` + + // ResponseMode TODO: ? + + Nonce string `schema:"nonce"` + Display Display `schema:"display"` + Prompt Prompt `schema:"prompt"` + MaxAge uint32 `schema:"max_age"` + UILocales Locales `schema:"ui_locales"` + IDTokenHint string `schema:"id_token_hint"` + LoginHint string `schema:"login_hint"` + ACRValues []string `schema:"acr_values"` + + CodeChallenge string `schema:"code_challenge"` + CodeChallengeMethod CodeChallengeMethod `schema:"code_challenge_method"` +} + +func (a *AuthRequest) GetRedirectURI() string { + return a.RedirectURI +} +func (a *AuthRequest) GetResponseType() ResponseType { + return a.ResponseType +} +func (a *AuthRequest) GetState() string { + return a.State +} + +type TokenRequest interface { + // GrantType GrantType `schema:"grant_type"` + GrantType() GrantType +} + +type TokenRequestType GrantType + +type AccessTokenRequest struct { + Code string `schema:"code"` + RedirectURI string `schema:"redirect_uri"` + ClientID string `schema:"client_id"` + ClientSecret string `schema:"client_secret"` + CodeVerifier string `schema:"code_verifier"` +} + +func (a *AccessTokenRequest) GrantType() GrantType { + return GrantTypeCode +} + +type AccessTokenResponse struct { + AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"` + TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"` + ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"` + IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"` +} + +type TokenExchangeRequest struct { + subjectToken string `schema:"subject_token"` + subjectTokenType string `schema:"subject_token_type"` + actorToken string `schema:"actor_token"` + actorTokenType string `schema:"actor_token_type"` + resource []string `schema:"resource"` + audience []string `schema:"audience"` + Scope []string `schema:"scope"` + requestedTokenType string `schema:"requested_token_type"` +} + +type Scopes []string + +func (s *Scopes) UnmarshalText(text []byte) error { + scopes := strings.Split(string(text), " ") + *s = Scopes(scopes) + return nil +} + +type ResponseType string + +type Display string + +func (d *Display) UnmarshalText(text []byte) error { + var ok bool + display := string(text) + *d, ok = displayValues[display] + if !ok { + return errors.New("") + } + return nil +} + +type Prompt string + +type Locales []language.Tag + +func (l *Locales) UnmarshalText(text []byte) error { + locales := strings.Split(string(text), " ") + for _, locale := range locales { + tag, err := language.Parse(locale) + if err == nil && !tag.IsRoot() { + *l = append(*l, tag) + } + } + return nil +} + +type GrantType string diff --git a/pkg/oidc/code_challenge.go b/pkg/oidc/code_challenge.go new file mode 100644 index 0000000..e3035c2 --- /dev/null +++ b/pkg/oidc/code_challenge.go @@ -0,0 +1,33 @@ +package oidc + +import ( + "crypto/sha256" + + "github.com/caos/oidc/pkg/utils" +) + +const ( + CodeChallengeMethodPlain CodeChallengeMethod = "plain" + CodeChallengeMethodS256 CodeChallengeMethod = "S256" +) + +type CodeChallengeMethod string + +type CodeChallenge struct { + Challenge string + Method CodeChallengeMethod +} + +func NewSHACodeChallenge(code string) string { + return utils.HashString(sha256.New(), code) +} + +func VerifyCodeChallenge(c *CodeChallenge, codeVerifier string) bool { + if c == nil { + return false //TODO: ? + } + if c.Method == CodeChallengeMethodS256 { + codeVerifier = NewSHACodeChallenge(codeVerifier) + } + return codeVerifier == c.Challenge +} diff --git a/pkg/oidc/discovery.go b/pkg/oidc/discovery.go new file mode 100644 index 0000000..5d2875e --- /dev/null +++ b/pkg/oidc/discovery.go @@ -0,0 +1,24 @@ +package oidc + +const ( + DiscoveryEndpoint = "/.well-known/openid-configuration" +) + +type DiscoveryConfiguration struct { + Issuer string `json:"issuer,omitempty"` + AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"` + TokenEndpoint string `json:"token_endpoint,omitempty"` + IntrospectionEndpoint string `json:"introspection_endpoint,omitempty"` + UserinfoEndpoint string `json:"userinfo_endpoint,omitempty"` + EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` + CheckSessionIframe string `json:"check_session_iframe,omitempty"` + JwksURI string `json:"jwks_uri,omitempty"` + ScopesSupported []string `json:"scopes_supported,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported,omitempty"` + ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + SubjectTypesSupported []string `json:"subject_types_supported,omitempty"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` + ClaimsSupported []string `json:"claims_supported,omitempty"` +} diff --git a/pkg/oidc/grants/client_credentials.go b/pkg/oidc/grants/client_credentials.go new file mode 100644 index 0000000..998dda1 --- /dev/null +++ b/pkg/oidc/grants/client_credentials.go @@ -0,0 +1,33 @@ +package grants + +import "strings" + +type clientCredentialsGrantBasic struct { + grantType string `schema:"grant_type"` + scope string `schema:"scope"` +} + +type clientCredentialsGrant struct { + *clientCredentialsGrantBasic + clientID string `schema:"client_id"` + clientSecret string `schema:"client_secret"` +} + +//ClientCredentialsGrantBasic creates an oauth2 `Client Credentials` Grant +//sneding client_id and client_secret as basic auth header +func ClientCredentialsGrantBasic(scopes ...string) *clientCredentialsGrantBasic { + return &clientCredentialsGrantBasic{ + grantType: "client_credentials", + scope: strings.Join(scopes, " "), + } +} + +//ClientCredentialsGrantValues creates an oauth2 `Client Credentials` Grant +//sneding client_id and client_secret as form values +func ClientCredentialsGrantValues(clientID, clientSecret string, scopes ...string) *clientCredentialsGrant { + return &clientCredentialsGrant{ + clientCredentialsGrantBasic: ClientCredentialsGrantBasic(scopes...), + clientID: clientID, + clientSecret: clientSecret, + } +} diff --git a/pkg/oidc/grants/tokenexchange/tokenexchange.go b/pkg/oidc/grants/tokenexchange/tokenexchange.go new file mode 100644 index 0000000..02a9808 --- /dev/null +++ b/pkg/oidc/grants/tokenexchange/tokenexchange.go @@ -0,0 +1,75 @@ +package tokenexchange + +const ( + AccessTokenType = "urn:ietf:params:oauth:token-type:access_token" + RefreshTokenType = "urn:ietf:params:oauth:token-type:refresh_token" + IDTokenType = "urn:ietf:params:oauth:token-type:id_token" + JWTTokenType = "urn:ietf:params:oauth:token-type:jwt" + DelegationTokenType = AccessTokenType + + TokenExchangeGrantType = "urn:ietf:params:oauth:grant-type:token-exchange" +) + +type TokenExchangeRequest struct { + grantType string `schema:"grant_type"` + subjectToken string `schema:"subject_token"` + subjectTokenType string `schema:"subject_token_type"` + actorToken string `schema:"actor_token"` + actorTokenType string `schema:"actor_token_type"` + resource []string `schema:"resource"` + audience []string `schema:"audience"` + scope []string `schema:"scope"` + requestedTokenType string `schema:"requested_token_type"` +} + +func NewTokenExchangeRequest(subjectToken, subjectTokenType string, opts ...TokenExchangeOption) *TokenExchangeRequest { + t := &TokenExchangeRequest{ + grantType: TokenExchangeGrantType, + subjectToken: subjectToken, + subjectTokenType: subjectTokenType, + requestedTokenType: AccessTokenType, + } + for _, opt := range opts { + opt(t) + } + return t +} + +type TokenExchangeOption func(*TokenExchangeRequest) + +func WithActorToken(token, tokenType string) func(*TokenExchangeRequest) { + return func(req *TokenExchangeRequest) { + req.actorToken = token + req.actorTokenType = tokenType + } +} + +func WithAudience(audience []string) func(*TokenExchangeRequest) { + return func(req *TokenExchangeRequest) { + req.audience = audience + } +} + +func WithGrantType(grantType string) TokenExchangeOption { + return func(req *TokenExchangeRequest) { + req.grantType = grantType + } +} + +func WithRequestedTokenType(tokenType string) func(*TokenExchangeRequest) { + return func(req *TokenExchangeRequest) { + req.requestedTokenType = tokenType + } +} + +func WithResource(resource []string) func(*TokenExchangeRequest) { + return func(req *TokenExchangeRequest) { + req.resource = resource + } +} + +func WithScope(scope []string) func(*TokenExchangeRequest) { + return func(req *TokenExchangeRequest) { + req.scope = scope + } +} diff --git a/pkg/oidc/keyset.go b/pkg/oidc/keyset.go new file mode 100644 index 0000000..f9bed2f --- /dev/null +++ b/pkg/oidc/keyset.go @@ -0,0 +1,22 @@ +package oidc + +import ( + "context" + + "gopkg.in/square/go-jose.v2" +) + +// KeySet is a set of publc JSON Web Keys that can be used to validate the signature +// of JSON web tokens. This is expected to be backed by a remote key set through +// provider metadata discovery or an in-memory set of keys delivered out-of-band. +type KeySet interface { + // VerifySignature parses the JSON web token, verifies the signature, and returns + // the raw payload. Header and claim fields are validated by other parts of the + // package. For example, the KeySet does not need to check values such as signature + // algorithm, issuer, and audience since the IDTokenVerifier validates these values + // independently. + // + // If VerifySignature makes HTTP requests to verify the token, it's expected to + // use any HTTP client associated with the context through ClientContext. + VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (payload []byte, err error) +} diff --git a/pkg/oidc/token.go b/pkg/oidc/token.go new file mode 100644 index 0000000..6f1496f --- /dev/null +++ b/pkg/oidc/token.go @@ -0,0 +1,196 @@ +package oidc + +import ( + "encoding/json" + "strings" + "time" + + "github.com/caos/oidc/pkg/utils" + "golang.org/x/oauth2" + "gopkg.in/square/go-jose.v2" +) + +type Tokens struct { + *oauth2.Token + IDTokenClaims *IDTokenClaims + IDToken string +} + +type AccessTokenClaims struct { + Issuer string + Subject string + Audiences []string + Expiration time.Time + IssuedAt time.Time + NotBefore time.Time + JWTID string + AuthorizedParty string + Nonce string + AuthTime time.Time + CodeHash string + AuthenticationContextClassReference string + AuthenticationMethodsReferences []string + SessionID string + Scopes []string + ClientID string + AccessTokenUseNumber int +} + +type IDTokenClaims struct { + Issuer string + Subject string + Audiences []string + Expiration time.Time + NotBefore time.Time + IssuedAt time.Time + JWTID string + UpdatedAt time.Time + AuthorizedParty string + Nonce string + AuthTime time.Time + AccessTokenHash string + CodeHash string + AuthenticationContextClassReference string + AuthenticationMethodsReferences []string + ClientID string + + Signature jose.SignatureAlgorithm //TODO: ??? +} + +type jsonToken struct { + Issuer string `json:"iss,omitempty"` + Subject string `json:"sub,omitempty"` + Audiences []string `json:"aud,omitempty"` + Expiration int64 `json:"exp,omitempty"` + NotBefore int64 `json:"nbf,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` + JWTID string `json:"jti,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + AuthorizedParty string `json:"azp,omitempty"` + Nonce string `json:"nonce,omitempty"` + AuthTime int64 `json:"auth_time,omitempty"` + AccessTokenHash string `json:"at_hash,omitempty"` + CodeHash string `json:"c_hash,omitempty"` + AuthenticationContextClassReference string `json:"acr,omitempty"` + AuthenticationMethodsReferences []string `json:"amr,omitempty"` + SessionID string `json:"sid,omitempty"` + Actor interface{} `json:"act,omitempty"` //TODO: impl + Scopes string `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + AuthorizedActor interface{} `json:"may_act,omitempty"` //TODO: impl + AccessTokenUseNumber int `json:"at_use_nbr,omitempty"` +} + +func (t *AccessTokenClaims) MarshalJSON() ([]byte, error) { + j := jsonToken{ + Issuer: t.Issuer, + Subject: t.Subject, + Audiences: t.Audiences, + Expiration: timeToJSON(t.Expiration), + NotBefore: timeToJSON(t.NotBefore), + IssuedAt: timeToJSON(t.IssuedAt), + JWTID: t.JWTID, + AuthorizedParty: t.AuthorizedParty, + Nonce: t.Nonce, + AuthTime: timeToJSON(t.AuthTime), + CodeHash: t.CodeHash, + AuthenticationContextClassReference: t.AuthenticationContextClassReference, + AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, + SessionID: t.SessionID, + Scopes: strings.Join(t.Scopes, " "), + ClientID: t.ClientID, + AccessTokenUseNumber: t.AccessTokenUseNumber, + } + return json.Marshal(j) +} + +func (t *AccessTokenClaims) UnmarshalJSON(b []byte) error { + var j jsonToken + if err := json.Unmarshal(b, &j); err != nil { + return err + } + audience := j.Audiences + if len(audience) == 1 { + audience = strings.Split(audience[0], " ") + } + t.Issuer = j.Issuer + t.Subject = j.Subject + t.Audiences = audience + t.Expiration = time.Unix(j.Expiration, 0).UTC() + t.NotBefore = time.Unix(j.NotBefore, 0).UTC() + t.IssuedAt = time.Unix(j.IssuedAt, 0).UTC() + t.JWTID = j.JWTID + t.AuthorizedParty = j.AuthorizedParty + t.Nonce = j.Nonce + t.AuthTime = time.Unix(j.AuthTime, 0).UTC() + t.CodeHash = j.CodeHash + t.AuthenticationContextClassReference = j.AuthenticationContextClassReference + t.AuthenticationMethodsReferences = j.AuthenticationMethodsReferences + t.SessionID = j.SessionID + t.Scopes = strings.Split(j.Scopes, " ") + t.ClientID = j.ClientID + t.AccessTokenUseNumber = j.AccessTokenUseNumber + return nil +} + +func (t *IDTokenClaims) MarshalJSON() ([]byte, error) { + j := jsonToken{ + Issuer: t.Issuer, + Subject: t.Subject, + Audiences: t.Audiences, + Expiration: timeToJSON(t.Expiration), + NotBefore: timeToJSON(t.NotBefore), + IssuedAt: timeToJSON(t.IssuedAt), + JWTID: t.JWTID, + UpdatedAt: timeToJSON(t.UpdatedAt), + AuthorizedParty: t.AuthorizedParty, + Nonce: t.Nonce, + AuthTime: timeToJSON(t.AuthTime), + AccessTokenHash: t.AccessTokenHash, + CodeHash: t.CodeHash, + AuthenticationContextClassReference: t.AuthenticationContextClassReference, + AuthenticationMethodsReferences: t.AuthenticationMethodsReferences, + ClientID: t.ClientID, + } + return json.Marshal(j) +} + +func (t *IDTokenClaims) UnmarshalJSON(b []byte) error { + var i jsonToken + if err := json.Unmarshal(b, &i); err != nil { + return err + } + audience := i.Audiences + if len(audience) == 1 { + audience = strings.Split(audience[0], " ") + } + t.Issuer = i.Issuer + t.Subject = i.Subject + t.Audiences = audience + t.Expiration = time.Unix(i.Expiration, 0).UTC() + t.IssuedAt = time.Unix(i.IssuedAt, 0).UTC() + t.AuthTime = time.Unix(i.AuthTime, 0).UTC() + t.Nonce = i.Nonce + t.AuthenticationContextClassReference = i.AuthenticationContextClassReference + t.AuthenticationMethodsReferences = i.AuthenticationMethodsReferences + t.AuthorizedParty = i.AuthorizedParty + t.AccessTokenHash = i.AccessTokenHash + t.CodeHash = i.CodeHash + return nil +} + +func ClaimHash(claim string, sigAlgorithm jose.SignatureAlgorithm) (string, error) { + hash, err := utils.GetHashAlgorithm(sigAlgorithm) + if err != nil { + return "", err + } + + return utils.HashString(hash, claim), nil +} + +func timeToJSON(t time.Time) int64 { + if t.IsZero() { + return 0 + } + return t.Unix() +} diff --git a/pkg/oidc/userinfo.go b/pkg/oidc/userinfo.go new file mode 100644 index 0000000..5e99d09 --- /dev/null +++ b/pkg/oidc/userinfo.go @@ -0,0 +1,120 @@ +package oidc + +import ( + "encoding/json" + "time" + + "golang.org/x/text/language" +) + +type Userinfo struct { + Subject string + Address *UserinfoAddress + UserinfoProfile + UserinfoEmail + UserinfoPhone + + claims map[string]interface{} +} + +type UserinfoPhone struct { + PhoneNumber string + PhoneNumberVerified bool +} +type UserinfoProfile struct { + Name string + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender Gender + Birthdate string + Zoneinfo string + Locale language.Tag + UpdatedAt time.Time + PreferredUsername string +} + +type Gender string + +type UserinfoAddress struct { + Formatted string + StreetAddress string + Locality string + Region string + PostalCode string + Country string +} + +type UserinfoEmail struct { + Email string + EmailVerified bool +} + +func marshalUserinfoProfile(i UserinfoProfile, claims map[string]interface{}) { + claims["name"] = i.Name + claims["given_name"] = i.GivenName + claims["family_name"] = i.FamilyName + claims["middle_name"] = i.MiddleName + claims["nickname"] = i.Nickname + claims["profile"] = i.Profile + claims["picture"] = i.Picture + claims["website"] = i.Website + claims["gender"] = i.Gender + claims["birthdate"] = i.Birthdate + claims["Zoneinfo"] = i.Zoneinfo + claims["locale"] = i.Locale.String() + claims["updated_at"] = i.UpdatedAt.UTC().Unix() + claims["preferred_username"] = i.PreferredUsername +} + +func marshalUserinfoEmail(i UserinfoEmail, claims map[string]interface{}) { + if i.Email != "" { + claims["email"] = i.Email + } + if i.EmailVerified { + claims["email_verified"] = i.EmailVerified + } +} + +func marshalUserinfoAddress(i *UserinfoAddress, claims map[string]interface{}) { + if i == nil { + return + } + address := make(map[string]interface{}) + if i.Formatted != "" { + address["formatted"] = i.Formatted + } + if i.StreetAddress != "" { + address["street_address"] = i.StreetAddress + } + claims["address"] = address +} + +func marshalUserinfoPhone(i UserinfoPhone, claims map[string]interface{}) { + claims["phone_number"] = i.PhoneNumber + claims["phone_number_verified"] = i.PhoneNumberVerified +} + +func (i *Userinfo) MarshalJSON() ([]byte, error) { + claims := i.claims + if claims == nil { + claims = make(map[string]interface{}) + } + claims["sub"] = i.Subject + marshalUserinfoAddress(i.Address, claims) + marshalUserinfoEmail(i.UserinfoEmail, claims) + marshalUserinfoPhone(i.UserinfoPhone, claims) + marshalUserinfoProfile(i.UserinfoProfile, claims) + return json.Marshal(claims) +} + +func (i *Userinfo) UnmmarshalJSON(data []byte) error { + if err := json.Unmarshal(data, i); err != nil { + return err + } + return json.Unmarshal(data, i.claims) +} diff --git a/pkg/op/authrequest.go b/pkg/op/authrequest.go new file mode 100644 index 0000000..9f9505d --- /dev/null +++ b/pkg/op/authrequest.go @@ -0,0 +1,198 @@ +package op + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/gorilla/mux" + "github.com/gorilla/schema" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/utils" +) + +type Authorizer interface { + Storage() Storage + Decoder() *schema.Decoder + Encoder() *schema.Encoder + Signer() Signer + Crypto() Crypto + Issuer() string +} + +type ValidationAuthorizer interface { + Authorizer + ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage) error +} + +func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { + err := r.ParseForm() + if err != nil { + AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder()) + return + } + authReq := new(oidc.AuthRequest) + err = authorizer.Decoder().Decode(authReq, r.Form) + if err != nil { + AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder()) + return + } + validation := ValidateAuthRequest + if validater, ok := authorizer.(ValidationAuthorizer); ok { + validation = validater.ValidateAuthRequest + } + if err := validation(r.Context(), authReq, authorizer.Storage()); err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + req, err := authorizer.Storage().CreateAuthRequest(r.Context(), authReq) + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + client, err := authorizer.Storage().GetClientByClientID(r.Context(), req.GetClientID()) + if err != nil { + AuthRequestError(w, r, req, err, authorizer.Encoder()) + return + } + RedirectToLogin(req.GetID(), client, w, r) +} + +func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage) error { + if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { + return err + } + if err := ValidateAuthReqRedirectURI(ctx, authReq.RedirectURI, authReq.ClientID, authReq.ResponseType, storage); err != nil { + return err + } + if err := ValidateAuthReqResponseType(authReq.ResponseType); err != nil { + return err + } + // if NeedsExistingSession(authReq) { + // session, err := storage.CheckSession(authReq.IDTokenHint) + // if err != nil { + // return err + // } + // } + return nil +} + +func ValidateAuthReqScopes(scopes []string) error { + if len(scopes) == 0 { + return ErrInvalidRequest("scope missing") + } + if !utils.Contains(scopes, oidc.ScopeOpenID) { + return ErrInvalidRequest("scope openid missing") + } + return nil +} + +func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, responseType oidc.ResponseType, storage OPStorage) error { + if uri == "" { + return ErrInvalidRequestRedirectURI("redirect_uri must not be empty") + } + client, err := storage.GetClientByClientID(ctx, client_id) + if err != nil { + return ErrServerError(err.Error()) + } + if !utils.Contains(client.RedirectURIs(), uri) { + return ErrInvalidRequestRedirectURI("redirect_uri not allowed") + } + if strings.HasPrefix(uri, "https://") { + return nil + } + if responseType == oidc.ResponseTypeCode { + if strings.HasPrefix(uri, "http://") && IsConfidentialType(client) { + return nil + } + if client.ApplicationType() == ApplicationTypeNative { + return nil + } + return ErrInvalidRequest("redirect_uri not allowed") + } else { + if client.ApplicationType() != ApplicationTypeNative { + return ErrInvalidRequestRedirectURI("redirect_uri not allowed") + } + if !(strings.HasPrefix(uri, "http://localhost:") || strings.HasPrefix(uri, "http://localhost/")) { + return ErrInvalidRequestRedirectURI("redirect_uri not allowed") + } + } + return nil +} + +func ValidateAuthReqResponseType(responseType oidc.ResponseType) error { + if responseType == "" { + return ErrInvalidRequest("response_type empty") + } + return nil +} + +func RedirectToLogin(authReqID string, client Client, w http.ResponseWriter, r *http.Request) { + login := client.LoginURL(authReqID) + http.Redirect(w, r, login, http.StatusFound) +} + +func AuthorizeCallback(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { + params := mux.Vars(r) + id := params["id"] + + authReq, err := authorizer.Storage().AuthRequestByID(r.Context(), id) + if err != nil { + AuthRequestError(w, r, nil, err, authorizer.Encoder()) + return + } + if !authReq.Done() { + AuthRequestError(w, r, authReq, errors.New("user not logged in"), authorizer.Encoder()) + return + } + AuthResponse(authReq, authorizer, w, r) +} + +func AuthResponse(authReq AuthRequest, authorizer Authorizer, w http.ResponseWriter, r *http.Request) { + client, err := authorizer.Storage().GetClientByClientID(r.Context(), authReq.GetClientID()) + if err != nil { + + } + if authReq.GetResponseType() == oidc.ResponseTypeCode { + AuthResponseCode(w, r, authReq, authorizer) + return + } + AuthResponseToken(w, r, authReq, authorizer, client) + return +} + +func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer) { + code, err := BuildAuthRequestCode(authReq, authorizer.Crypto()) + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + callback := fmt.Sprintf("%s?code=%s", authReq.GetRedirectURI(), code) + if authReq.GetState() != "" { + callback = callback + "&state=" + authReq.GetState() + } + http.Redirect(w, r, callback, http.StatusFound) +} + +func AuthResponseToken(w http.ResponseWriter, r *http.Request, authReq AuthRequest, authorizer Authorizer, client Client) { + createAccessToken := authReq.GetResponseType() != oidc.ResponseTypeIDTokenOnly + resp, err := CreateTokenResponse(authReq, client, authorizer, createAccessToken, "") + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + params, err := utils.URLEncodeResponse(resp, authorizer.Encoder()) + if err != nil { + AuthRequestError(w, r, authReq, err, authorizer.Encoder()) + return + } + callback := fmt.Sprintf("%s#%s", authReq.GetRedirectURI(), params) + http.Redirect(w, r, callback, http.StatusFound) +} + +func BuildAuthRequestCode(authReq AuthRequest, crypto Crypto) (string, error) { + return crypto.Encrypt(authReq.GetID()) +} diff --git a/pkg/op/authrequest_test.go b/pkg/op/authrequest_test.go new file mode 100644 index 0000000..b0599c3 --- /dev/null +++ b/pkg/op/authrequest_test.go @@ -0,0 +1,296 @@ +package op_test + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/op" + "github.com/caos/oidc/pkg/op/mock" +) + +func TestAuthorize(t *testing.T) { + // testCallback := func(t *testing.T, clienID string) callbackHandler { + // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { + // // require.Equal(t, clientID, client.) + // } + // } + // testErr := func(t *testing.T, expected error) errorHandler { + // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { + // require.Equal(t, expected, err) + // } + // } + type args struct { + w http.ResponseWriter + r *http.Request + authorizer op.Authorizer + } + tests := []struct { + name string + args args + }{ + { + "parsing fails", + args{ + httptest.NewRecorder(), + &http.Request{Method: "POST", Body: nil}, + mock.NewAuthorizerExpectValid(t, true), + // testCallback(t, ""), + // testErr(t, ErrInvalidRequest("cannot parse form")), + }, + }, + { + "decoding fails", + args{ + httptest.NewRecorder(), + func() *http.Request { + r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo")) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return r + }(), + mock.NewAuthorizerExpectValid(t, true), + // testCallback(t, ""), + // testErr(t, ErrInvalidRequest("cannot parse auth request")), + }, + }, + // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer) + }) + } +} + +func TestValidateAuthRequest(t *testing.T) { + type args struct { + authRequest *oidc.AuthRequest + storage op.Storage + } + tests := []struct { + name string + args args + wantErr bool + }{ + //TODO: + // { + // "oauth2 spec" + // } + { + "scope missing fails", + args{&oidc.AuthRequest{}, nil}, + true, + }, + { + "scope openid missing fails", + args{&oidc.AuthRequest{Scopes: []string{"profile"}}, nil}, + true, + }, + { + "response_type missing fails", + args{&oidc.AuthRequest{Scopes: []string{"openid"}}, nil}, + true, + }, + { + "client_id missing fails", + args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, nil}, + true, + }, + { + "redirect_uri missing fails", + args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, nil}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := op.ValidateAuthRequest(nil, tt.args.authRequest, tt.args.storage); (err != nil) != tt.wantErr { + t.Errorf("ValidateAuthRequest() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateAuthReqScopes(t *testing.T) { + type args struct { + scopes []string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "scopes missing fails", args{}, true, + }, + { + "scope openid missing fails", args{[]string{"email"}}, true, + }, + { + "scope ok", args{[]string{"openid"}}, false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := op.ValidateAuthReqScopes(tt.args.scopes); (err != nil) != tt.wantErr { + t.Errorf("ValidateAuthReqScopes() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateAuthReqRedirectURI(t *testing.T) { + type args struct { + uri string + clientID string + responseType oidc.ResponseType + storage op.OPStorage + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "empty fails", + args{"", "", oidc.ResponseTypeCode, nil}, + true, + }, + { + "unregistered fails", + args{"https://unregistered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + { + "storage error fails", + args{"https://registered.com/callback", "non_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectInvalidClientID(t)}, + true, + }, + { + "code flow registered http not confidential fails", + args{"http://registered.com/callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + { + "code flow registered http confidential ok", + args{"http://registered.com/callback", "web_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, + false, + }, + { + "code flow registered custom not native fails", + args{"custom://callback", "useragent_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + { + "code flow registered custom native ok", + args{"http://registered.com/callback", "native_client", oidc.ResponseTypeCode, mock.NewMockStorageExpectValidClientID(t)}, + false, + }, + { + "implicit flow registered ok", + args{"https://registered.com/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, + false, + }, + { + "implicit flow registered http localhost native ok", + args{"http://localhost:9999/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, + false, + }, + { + "implicit flow registered http localhost user agent fails", + args{"http://localhost:9999/callback", "useragent_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + { + "implicit flow http non localhost fails", + args{"http://registered.com/callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + { + "implicit flow custom fails", + args{"custom://callback", "native_client", oidc.ResponseTypeIDToken, mock.NewMockStorageExpectValidClientID(t)}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := op.ValidateAuthReqRedirectURI(nil, tt.args.uri, tt.args.clientID, tt.args.responseType, tt.args.storage); (err != nil) != tt.wantErr { + t.Errorf("ValidateRedirectURI() error = %v, wantErr %v", err.Error(), tt.wantErr) + } + }) + } +} + +func TestRedirectToLogin(t *testing.T) { + type args struct { + authReqID string + client op.Client + w http.ResponseWriter + r *http.Request + } + tests := []struct { + name string + args args + }{ + { + "redirect ok", + args{ + "id", + mock.NewClientExpectAny(t, op.ApplicationTypeNative), + httptest.NewRecorder(), + httptest.NewRequest("GET", "/authorize", nil), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op.RedirectToLogin(tt.args.authReqID, tt.args.client, tt.args.w, tt.args.r) + rec := tt.args.w.(*httptest.ResponseRecorder) + require.Equal(t, http.StatusFound, rec.Code) + require.Equal(t, "/login?id=id", rec.Header().Get("location")) + }) + } +} + +func TestAuthorizeCallback(t *testing.T) { + type args struct { + w http.ResponseWriter + r *http.Request + authorizer op.Authorizer + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op.AuthorizeCallback(tt.args.w, tt.args.r, tt.args.authorizer) + }) + } +} + +func TestAuthResponse(t *testing.T) { + type args struct { + authReq op.AuthRequest + authorizer op.Authorizer + w http.ResponseWriter + r *http.Request + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op.AuthResponse(tt.args.authReq, tt.args.authorizer, tt.args.w, tt.args.r) + }) + } +} diff --git a/pkg/op/client.go b/pkg/op/client.go new file mode 100644 index 0000000..33c30a4 --- /dev/null +++ b/pkg/op/client.go @@ -0,0 +1,33 @@ +package op + +import "time" + +const ( + ApplicationTypeWeb ApplicationType = iota + ApplicationTypeUserAgent + ApplicationTypeNative + + AccessTokenTypeBearer AccessTokenType = iota + AccessTokenTypeJWT +) + +type Client interface { + GetID() string + RedirectURIs() []string + ApplicationType() ApplicationType + GetAuthMethod() AuthMethod + LoginURL(string) string + AccessTokenType() AccessTokenType + AccessTokenLifetime() time.Duration + IDTokenLifetime() time.Duration +} + +func IsConfidentialType(c Client) bool { + return c.ApplicationType() == ApplicationTypeWeb +} + +type ApplicationType int + +type AuthMethod string + +type AccessTokenType int diff --git a/pkg/op/config.go b/pkg/op/config.go new file mode 100644 index 0000000..9333a5c --- /dev/null +++ b/pkg/op/config.go @@ -0,0 +1,54 @@ +package op + +import ( + "errors" + "net/url" + "os" + "strings" +) + +type Configuration interface { + Issuer() string + AuthorizationEndpoint() Endpoint + TokenEndpoint() Endpoint + UserinfoEndpoint() Endpoint + KeysEndpoint() Endpoint + + AuthMethodPostSupported() bool + + Port() string +} + +func ValidateIssuer(issuer string) error { + if issuer == "" { + return errors.New("missing issuer") + } + u, err := url.Parse(issuer) + if err != nil { + return errors.New("invalid url for issuer") + } + if u.Host == "" { + return errors.New("host for issuer missing") + } + if u.Scheme != "https" { + if !devLocalAllowed(u) { + return errors.New("scheme for issuer must be `https`") + } + } + if u.Fragment != "" || len(u.Query()) > 0 { + return errors.New("no fragments or query allowed for issuer") + } + return nil +} + +func devLocalAllowed(url *url.URL) bool { + _, b := os.LookupEnv("CAOS_OIDC_DEV") + if !b { + return b + } + return url.Scheme == "http" && + url.Host == "localhost" || + url.Host == "127.0.0.1" || + url.Host == "::1" || + strings.HasPrefix(url.Host, "localhost:") +} diff --git a/pkg/op/config_test.go b/pkg/op/config_test.go new file mode 100644 index 0000000..56cf2eb --- /dev/null +++ b/pkg/op/config_test.go @@ -0,0 +1,94 @@ +package op + +import "testing" + +import "os" + +func TestValidateIssuer(t *testing.T) { + type args struct { + issuer string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "missing issuer fails", + args{""}, + true, + }, + { + "invalid url for issuer fails", + args{":issuer"}, + true, + }, + { + "invalid url for issuer fails", + args{":issuer"}, + true, + }, + { + "host for issuer missing fails", + args{"https:///issuer"}, + true, + }, + { + "host for not https fails", + args{"http://issuer.com"}, + true, + }, + { + "host with fragment fails", + args{"https://issuer.com/#issuer"}, + true, + }, + { + "host with query fails", + args{"https://issuer.com?issuer=me"}, + true, + }, + { + "host with https ok", + args{"https://issuer.com"}, + false, + }, + { + "localhost with http ok", + args{"http://localhost:9999"}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { + t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateIssuerDevLocalAllowed(t *testing.T) { + type args struct { + issuer string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "localhost with http ok", + args{"http://localhost:9999"}, + false, + }, + } + os.Setenv("CAOS_OIDC_DEV", "") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateIssuer(tt.args.issuer); (err != nil) != tt.wantErr { + t.Errorf("ValidateIssuer() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/op/crypto.go b/pkg/op/crypto.go new file mode 100644 index 0000000..e95157d --- /dev/null +++ b/pkg/op/crypto.go @@ -0,0 +1,26 @@ +package op + +import ( + "github.com/caos/oidc/pkg/utils" +) + +type Crypto interface { + Encrypt(string) (string, error) + Decrypt(string) (string, error) +} + +type aesCrypto struct { + key string +} + +func NewAESCrypto(key [32]byte) Crypto { + return &aesCrypto{key: string(key[:32])} +} + +func (c *aesCrypto) Encrypt(s string) (string, error) { + return utils.EncryptAES(s, c.key) +} + +func (c *aesCrypto) Decrypt(s string) (string, error) { + return utils.DecryptAES(s, c.key) +} diff --git a/pkg/op/default_op.go b/pkg/op/default_op.go new file mode 100644 index 0000000..8d140b5 --- /dev/null +++ b/pkg/op/default_op.go @@ -0,0 +1,224 @@ +package op + +import ( + "context" + "net/http" + + "github.com/gorilla/schema" + + "github.com/caos/oidc/pkg/oidc" +) + +const ( + defaultAuthorizationEndpoint = "authorize" + defaulTokenEndpoint = "oauth/token" + defaultIntrospectEndpoint = "introspect" + defaultUserinfoEndpoint = "userinfo" + defaultKeysEndpoint = "keys" + + AuthMethodBasic AuthMethod = "client_secret_basic" + AuthMethodPost = "client_secret_post" + AuthMethodNone = "none" +) + +var ( + DefaultEndpoints = &endpoints{ + Authorization: defaultAuthorizationEndpoint, + Token: defaulTokenEndpoint, + IntrospectionEndpoint: defaultIntrospectEndpoint, + Userinfo: defaultUserinfoEndpoint, + JwksURI: defaultKeysEndpoint, + } +) + +type DefaultOP struct { + config *Config + endpoints *endpoints + discoveryConfig *oidc.DiscoveryConfiguration + storage Storage + signer Signer + crypto Crypto + http *http.Server + decoder *schema.Decoder + encoder *schema.Encoder +} + +type Config struct { + Issuer string + CryptoKey [32]byte + // ScopesSupported: oidc.SupportedScopes, + // ResponseTypesSupported: responseTypes, + // GrantTypesSupported: oidc.SupportedGrantTypes, + // ClaimsSupported: oidc.SupportedClaims, + // IdTokenSigningAlgValuesSupported: []string{keys.SigningAlgorithm}, + // SubjectTypesSupported: []string{"public"}, + // TokenEndpointAuthMethodsSupported: + Port string +} + +type endpoints struct { + Authorization Endpoint + Token Endpoint + IntrospectionEndpoint Endpoint + Userinfo Endpoint + EndSessionEndpoint Endpoint + CheckSessionIframe Endpoint + JwksURI Endpoint +} + +type DefaultOPOpts func(o *DefaultOP) error + +func WithCustomAuthEndpoint(endpoint Endpoint) DefaultOPOpts { + return func(o *DefaultOP) error { + if err := endpoint.Validate(); err != nil { + return err + } + o.endpoints.Authorization = endpoint + return nil + } +} + +func WithCustomTokenEndpoint(endpoint Endpoint) DefaultOPOpts { + return func(o *DefaultOP) error { + if err := endpoint.Validate(); err != nil { + return err + } + o.endpoints.Token = endpoint + return nil + } +} + +func WithCustomUserinfoEndpoint(endpoint Endpoint) DefaultOPOpts { + return func(o *DefaultOP) error { + if err := endpoint.Validate(); err != nil { + return err + } + o.endpoints.Userinfo = endpoint + return nil + } +} + +func NewDefaultOP(ctx context.Context, config *Config, storage Storage, opOpts ...DefaultOPOpts) (OpenIDProvider, error) { + err := ValidateIssuer(config.Issuer) + if err != nil { + return nil, err + } + + p := &DefaultOP{ + config: config, + storage: storage, + endpoints: DefaultEndpoints, + } + + p.signer, err = NewDefaultSigner(ctx, storage) + if err != nil { + return nil, err + } + + for _, optFunc := range opOpts { + if err := optFunc(p); err != nil { + return nil, err + } + } + + p.discoveryConfig = CreateDiscoveryConfig(p, p.signer) + + router := CreateRouter(p) + p.http = &http.Server{ + Addr: ":" + config.Port, + Handler: router, + } + p.decoder = schema.NewDecoder() + p.decoder.IgnoreUnknownKeys(true) + + p.encoder = schema.NewEncoder() + + p.crypto = NewAESCrypto(config.CryptoKey) + + return p, nil +} + +func (p *DefaultOP) Issuer() string { + return p.config.Issuer +} + +func (p *DefaultOP) AuthorizationEndpoint() Endpoint { + return p.endpoints.Authorization +} + +func (p *DefaultOP) TokenEndpoint() Endpoint { + return Endpoint(p.endpoints.Token) +} + +func (p *DefaultOP) UserinfoEndpoint() Endpoint { + return Endpoint(p.endpoints.Userinfo) +} + +func (p *DefaultOP) KeysEndpoint() Endpoint { + return Endpoint(p.endpoints.JwksURI) +} + +func (p *DefaultOP) AuthMethodPostSupported() bool { + return true //TODO: config +} + +func (p *DefaultOP) Port() string { + return p.config.Port +} + +func (p *DefaultOP) HttpHandler() *http.Server { + return p.http +} + +func (p *DefaultOP) HandleDiscovery(w http.ResponseWriter, r *http.Request) { + Discover(w, p.discoveryConfig) +} + +func (p *DefaultOP) Decoder() *schema.Decoder { + return p.decoder +} + +func (p *DefaultOP) Encoder() *schema.Encoder { + return p.encoder +} + +func (p *DefaultOP) Storage() Storage { + return p.storage +} + +func (p *DefaultOP) Signer() Signer { + return p.signer +} + +func (p *DefaultOP) Crypto() Crypto { + return p.crypto +} + +func (p *DefaultOP) HandleKeys(w http.ResponseWriter, r *http.Request) { + Keys(w, r, p) +} + +func (p *DefaultOP) HandleAuthorize(w http.ResponseWriter, r *http.Request) { + Authorize(w, r, p) +} + +func (p *DefaultOP) HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) { + AuthorizeCallback(w, r, p) +} + +func (p *DefaultOP) HandleExchange(w http.ResponseWriter, r *http.Request) { + reqType := r.FormValue("grant_type") + if reqType == "" { + ExchangeRequestError(w, r, ErrInvalidRequest("grant_type missing")) + return + } + if reqType == string(oidc.GrantTypeCode) { + CodeExchange(w, r, p) + return + } + TokenExchange(w, r, p) +} + +func (p *DefaultOP) HandleUserinfo(w http.ResponseWriter, r *http.Request) { + Userinfo(w, r, p) +} diff --git a/pkg/op/default_op_test.go b/pkg/op/default_op_test.go new file mode 100644 index 0000000..ed359a5 --- /dev/null +++ b/pkg/op/default_op_test.go @@ -0,0 +1,49 @@ +package op + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/caos/oidc/pkg/oidc" +) + +func TestDefaultOP_HandleDiscovery(t *testing.T) { + type fields struct { + config *Config + endpoints *endpoints + discoveryConfig *oidc.DiscoveryConfiguration + storage Storage + http *http.Server + } + type args struct { + w http.ResponseWriter + r *http.Request + } + tests := []struct { + name string + fields fields + args args + want string + wantCode int + }{ + {"OK", fields{config: nil, endpoints: nil, discoveryConfig: &oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"}}, args{httptest.NewRecorder(), nil}, `{"issuer":"https://issuer.com"}`, 200}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &DefaultOP{ + config: tt.fields.config, + endpoints: tt.fields.endpoints, + discoveryConfig: tt.fields.discoveryConfig, + storage: tt.fields.storage, + http: tt.fields.http, + } + p.HandleDiscovery(tt.args.w, tt.args.r) + rec := tt.args.w.(*httptest.ResponseRecorder) + require.Equal(t, tt.want, rec.Body.String()) + require.Equal(t, tt.wantCode, rec.Code) + }) + } +} diff --git a/pkg/op/discovery.go b/pkg/op/discovery.go new file mode 100644 index 0000000..3d4ea98 --- /dev/null +++ b/pkg/op/discovery.go @@ -0,0 +1,119 @@ +package op + +import ( + "net/http" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/utils" +) + +func Discover(w http.ResponseWriter, config *oidc.DiscoveryConfiguration) { + utils.MarshalJSON(w, config) +} + +func CreateDiscoveryConfig(c Configuration, s Signer) *oidc.DiscoveryConfiguration { + return &oidc.DiscoveryConfiguration{ + Issuer: c.Issuer(), + AuthorizationEndpoint: c.AuthorizationEndpoint().Absolute(c.Issuer()), + TokenEndpoint: c.TokenEndpoint().Absolute(c.Issuer()), + // IntrospectionEndpoint: c.Intro().Absolute(c.Issuer()), + UserinfoEndpoint: c.UserinfoEndpoint().Absolute(c.Issuer()), + // EndSessionEndpoint: c.TokenEndpoint().Absolute(c.Issuer())(c.EndSessionEndpoint), + // CheckSessionIframe: c.TokenEndpoint().Absolute(c.Issuer())(c.CheckSessionIframe), + JwksURI: c.KeysEndpoint().Absolute(c.Issuer()), + ScopesSupported: Scopes(c), + ResponseTypesSupported: ResponseTypes(c), + GrantTypesSupported: GrantTypes(c), + ClaimsSupported: SupportedClaims(c), + IDTokenSigningAlgValuesSupported: SigAlgorithms(s), + SubjectTypesSupported: SubjectTypes(c), + TokenEndpointAuthMethodsSupported: AuthMethods(c), + } +} + +const ( + ScopeOpenID = "openid" + ScopeProfile = "profile" + ScopeEmail = "email" + ScopePhone = "phone" + ScopeAddress = "address" +) + +var DefaultSupportedScopes = []string{ + ScopeOpenID, + ScopeProfile, + ScopeEmail, + ScopePhone, + ScopeAddress, +} + +func Scopes(c Configuration) []string { + return DefaultSupportedScopes //TODO: config +} + +func ResponseTypes(c Configuration) []string { + return []string{ + "code", + "id_token", + // "code token", + // "code id_token", + "id_token token", + // "code id_token token" + } +} + +func GrantTypes(c Configuration) []string { + return []string{ + "client_credentials", + "authorization_code", + // "password", + "urn:ietf:params:oauth:grant-type:token-exchange", + } +} + +func SupportedClaims(c Configuration) []string { + return []string{ //TODO: config + "sub", + "aud", + "exp", + "iat", + "iss", + "auth_time", + "nonce", + "acr", + "amr", + "c_hash", + "at_hash", + "act", + "scopes", + "client_id", + "azp", + "preferred_username", + "name", + "family_name", + "given_name", + "locale", + "email", + "email_verified", + "phone_number", + "phone_number_verified", + } +} + +func SigAlgorithms(s Signer) []string { + return []string{string(s.SignatureAlgorithm())} +} + +func SubjectTypes(c Configuration) []string { + return []string{"public"} //TODO: config +} + +func AuthMethods(c Configuration) []string { + authMethods := []string{ + string(AuthMethodBasic), + } + if c.AuthMethodPostSupported() { + authMethods = append(authMethods, string(AuthMethodPost)) + } + return authMethods +} diff --git a/pkg/op/discovery_test.go b/pkg/op/discovery_test.go new file mode 100644 index 0000000..39b39bc --- /dev/null +++ b/pkg/op/discovery_test.go @@ -0,0 +1,235 @@ +package op_test + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/op" + "github.com/caos/oidc/pkg/op/mock" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" +) + +func TestDiscover(t *testing.T) { + type args struct { + w http.ResponseWriter + config *oidc.DiscoveryConfiguration + } + tests := []struct { + name string + args args + }{ + { + "OK", + args{ + httptest.NewRecorder(), + &oidc.DiscoveryConfiguration{Issuer: "https://issuer.com"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op.Discover(tt.args.w, tt.args.config) + rec := tt.args.w.(*httptest.ResponseRecorder) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, `{"issuer":"https://issuer.com"}`, rec.Body.String()) + }) + } +} + +func TestCreateDiscoveryConfig(t *testing.T) { + type args struct { + c op.Configuration + s op.Signer + } + tests := []struct { + name string + args args + want *oidc.DiscoveryConfiguration + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.CreateDiscoveryConfig(tt.args.c, tt.args.s); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CreateDiscoveryConfig() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_scopes(t *testing.T) { + type args struct { + c op.Configuration + } + tests := []struct { + name string + args args + want []string + }{ + { + "default Scopes", + args{}, + op.DefaultSupportedScopes, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.Scopes(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("scopes() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_ResponseTypes(t *testing.T) { + type args struct { + c op.Configuration + } + tests := []struct { + name string + args args + want []string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.ResponseTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("responseTypes() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_GrantTypes(t *testing.T) { + type args struct { + c op.Configuration + } + tests := []struct { + name string + args args + want []string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.GrantTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("grantTypes() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSupportedClaims(t *testing.T) { + type args struct { + c op.Configuration + } + tests := []struct { + name string + args args + want []string + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.SupportedClaims(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("SupportedClaims() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_SigAlgorithms(t *testing.T) { + m := mock.NewMockSigner(gomock.NewController((t))) + type args struct { + s op.Signer + } + tests := []struct { + name string + args args + want []string + }{ + { + "", + args{func() op.Signer { + m.EXPECT().SignatureAlgorithm().Return(jose.RS256) + return m + }()}, + []string{"RS256"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.SigAlgorithms(tt.args.s); !reflect.DeepEqual(got, tt.want) { + t.Errorf("sigAlgorithms() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_SubjectTypes(t *testing.T) { + type args struct { + c op.Configuration + } + tests := []struct { + name string + args args + want []string + }{ + { + "none", + args{}, + []string{"public"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.SubjectTypes(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("subjectTypes() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_AuthMethods(t *testing.T) { + m := mock.NewMockConfiguration(gomock.NewController((t))) + type args struct { + c op.Configuration + } + tests := []struct { + name string + args args + want []string + }{ + { + "imlicit basic", + args{func() op.Configuration { + m.EXPECT().AuthMethodPostSupported().Return(false) + return m + }()}, + []string{string(op.AuthMethodBasic)}, + }, + { + "basic and post", + args{func() op.Configuration { + m.EXPECT().AuthMethodPostSupported().Return(true) + return m + }()}, + []string{string(op.AuthMethodBasic), string(op.AuthMethodPost)}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := op.AuthMethods(tt.args.c); !reflect.DeepEqual(got, tt.want) { + t.Errorf("authMethods() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/op/endpoint.go b/pkg/op/endpoint.go new file mode 100644 index 0000000..cc0419c --- /dev/null +++ b/pkg/op/endpoint.go @@ -0,0 +1,25 @@ +package op + +import "strings" + +type Endpoint string + +func (e Endpoint) Relative() string { + return relativeEndpoint(string(e)) +} + +func (e Endpoint) Absolute(host string) string { + return absoluteEndpoint(host, string(e)) +} + +func (e Endpoint) Validate() error { + return nil //TODO: +} + +func absoluteEndpoint(host, endpoint string) string { + return strings.TrimSuffix(host, "/") + relativeEndpoint(endpoint) +} + +func relativeEndpoint(endpoint string) string { + return "/" + strings.TrimPrefix(endpoint, "/") +} diff --git a/pkg/op/endpoint_test.go b/pkg/op/endpoint_test.go new file mode 100644 index 0000000..227bf9d --- /dev/null +++ b/pkg/op/endpoint_test.go @@ -0,0 +1,95 @@ +package op_test + +import ( + "testing" + + "github.com/caos/oidc/pkg/op" +) + +func TestEndpoint_Relative(t *testing.T) { + tests := []struct { + name string + e op.Endpoint + want string + }{ + { + "without starting /", + op.Endpoint("test"), + "/test", + }, + { + "with starting /", + op.Endpoint("/test"), + "/test", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.e.Relative(); got != tt.want { + t.Errorf("Endpoint.Relative() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEndpoint_Absolute(t *testing.T) { + type args struct { + host string + } + tests := []struct { + name string + e op.Endpoint + args args + want string + }{ + { + "no /", + op.Endpoint("test"), + args{"https://host"}, + "https://host/test", + }, + { + "endpoint without /", + op.Endpoint("test"), + args{"https://host/"}, + "https://host/test", + }, + { + "host without /", + op.Endpoint("/test"), + args{"https://host"}, + "https://host/test", + }, + { + "both /", + op.Endpoint("/test"), + args{"https://host/"}, + "https://host/test", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.e.Absolute(tt.args.host); got != tt.want { + t.Errorf("Endpoint.Absolute() = %v, want %v", got, tt.want) + } + }) + } +} + +//TODO: impl test +func TestEndpoint_Validate(t *testing.T) { + tests := []struct { + name string + e op.Endpoint + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.e.Validate(); (err != nil) != tt.wantErr { + t.Errorf("Endpoint.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/op/error.go b/pkg/op/error.go new file mode 100644 index 0000000..1e84c1a --- /dev/null +++ b/pkg/op/error.go @@ -0,0 +1,99 @@ +package op + +import ( + "fmt" + "net/http" + + "github.com/gorilla/schema" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/utils" +) + +const ( + InvalidRequest errorType = "invalid_request" + ServerError errorType = "server_error" +) + +var ( + ErrInvalidRequest = func(description string) *OAuthError { + return &OAuthError{ + ErrorType: InvalidRequest, + Description: description, + } + } + ErrInvalidRequestRedirectURI = func(description string) *OAuthError { + return &OAuthError{ + ErrorType: InvalidRequest, + Description: description, + redirectDisabled: true, + } + } + ErrServerError = func(description string) *OAuthError { + return &OAuthError{ + ErrorType: ServerError, + Description: description, + } + } +) + +type errorType string + +type ErrAuthRequest interface { + GetRedirectURI() string + GetResponseType() oidc.ResponseType + GetState() string +} + +func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) { + if authReq == nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + e, ok := err.(*OAuthError) + if !ok { + e = new(OAuthError) + e.ErrorType = ServerError + e.Description = err.Error() + } + e.state = authReq.GetState() + if authReq.GetRedirectURI() == "" || e.redirectDisabled { + http.Error(w, e.Description, http.StatusBadRequest) + return + } + params, err := utils.URLEncodeResponse(e, encoder) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + url := authReq.GetRedirectURI() + responseType := authReq.GetResponseType() + if responseType == "" || responseType == oidc.ResponseTypeCode { + url += "?" + params + } else { + url += "#" + params + } + http.Redirect(w, r, url, http.StatusFound) +} + +func ExchangeRequestError(w http.ResponseWriter, r *http.Request, err error) { + e, ok := err.(*OAuthError) + if !ok { + e = new(OAuthError) + e.ErrorType = ServerError + e.Description = err.Error() + } + w.WriteHeader(http.StatusBadRequest) + utils.MarshalJSON(w, e) +} + +type OAuthError struct { + ErrorType errorType `json:"error" schema:"error"` + Description string `json:"description" schema:"description"` + state string `json:"state" schema:"state"` + redirectDisabled bool +} + +func (e *OAuthError) Error() string { + return fmt.Sprintf("%s: %s", e.ErrorType, e.Description) +} diff --git a/pkg/op/keys.go b/pkg/op/keys.go new file mode 100644 index 0000000..8e2052b --- /dev/null +++ b/pkg/op/keys.go @@ -0,0 +1,19 @@ +package op + +import ( + "net/http" + + "github.com/caos/oidc/pkg/utils" +) + +type KeyProvider interface { + Storage() Storage +} + +func Keys(w http.ResponseWriter, r *http.Request, k KeyProvider) { + keySet, err := k.Storage().GetKeySet(r.Context()) + if err != nil { + + } + utils.MarshalJSON(w, keySet) +} diff --git a/pkg/op/mock/authorizer.mock.go b/pkg/op/mock/authorizer.mock.go new file mode 100644 index 0000000..48f9aed --- /dev/null +++ b/pkg/op/mock/authorizer.mock.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/caos/oidc/pkg/op (interfaces: Authorizer) + +// Package mock is a generated GoMock package. +package mock + +import ( + op "github.com/caos/oidc/pkg/op" + gomock "github.com/golang/mock/gomock" + schema "github.com/gorilla/schema" + reflect "reflect" +) + +// MockAuthorizer is a mock of Authorizer interface +type MockAuthorizer struct { + ctrl *gomock.Controller + recorder *MockAuthorizerMockRecorder +} + +// MockAuthorizerMockRecorder is the mock recorder for MockAuthorizer +type MockAuthorizerMockRecorder struct { + mock *MockAuthorizer +} + +// NewMockAuthorizer creates a new mock instance +func NewMockAuthorizer(ctrl *gomock.Controller) *MockAuthorizer { + mock := &MockAuthorizer{ctrl: ctrl} + mock.recorder = &MockAuthorizerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockAuthorizer) EXPECT() *MockAuthorizerMockRecorder { + return m.recorder +} + +// Crypto mocks base method +func (m *MockAuthorizer) Crypto() op.Crypto { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Crypto") + ret0, _ := ret[0].(op.Crypto) + return ret0 +} + +// Crypto indicates an expected call of Crypto +func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Crypto", reflect.TypeOf((*MockAuthorizer)(nil).Crypto)) +} + +// Decoder mocks base method +func (m *MockAuthorizer) Decoder() *schema.Decoder { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decoder") + ret0, _ := ret[0].(*schema.Decoder) + return ret0 +} + +// Decoder indicates an expected call of Decoder +func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decoder", reflect.TypeOf((*MockAuthorizer)(nil).Decoder)) +} + +// Encoder mocks base method +func (m *MockAuthorizer) Encoder() *schema.Encoder { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encoder") + ret0, _ := ret[0].(*schema.Encoder) + return ret0 +} + +// Encoder indicates an expected call of Encoder +func (mr *MockAuthorizerMockRecorder) Encoder() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encoder", reflect.TypeOf((*MockAuthorizer)(nil).Encoder)) +} + +// Issuer mocks base method +func (m *MockAuthorizer) Issuer() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Issuer") + ret0, _ := ret[0].(string) + return ret0 +} + +// Issuer indicates an expected call of Issuer +func (mr *MockAuthorizerMockRecorder) Issuer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockAuthorizer)(nil).Issuer)) +} + +// Signer mocks base method +func (m *MockAuthorizer) Signer() op.Signer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Signer") + ret0, _ := ret[0].(op.Signer) + return ret0 +} + +// Signer indicates an expected call of Signer +func (mr *MockAuthorizerMockRecorder) Signer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Signer", reflect.TypeOf((*MockAuthorizer)(nil).Signer)) +} + +// Storage mocks base method +func (m *MockAuthorizer) Storage() op.Storage { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Storage") + ret0, _ := ret[0].(op.Storage) + return ret0 +} + +// Storage indicates an expected call of Storage +func (mr *MockAuthorizerMockRecorder) Storage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Storage", reflect.TypeOf((*MockAuthorizer)(nil).Storage)) +} diff --git a/pkg/op/mock/authorizer.mock.impl.go b/pkg/op/mock/authorizer.mock.impl.go new file mode 100644 index 0000000..0091877 --- /dev/null +++ b/pkg/op/mock/authorizer.mock.impl.go @@ -0,0 +1,89 @@ +package mock + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/gorilla/schema" + "gopkg.in/square/go-jose.v2" + + oidc "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/op" +) + +func NewAuthorizer(t *testing.T) op.Authorizer { + return NewMockAuthorizer(gomock.NewController(t)) +} + +func NewAuthorizerExpectValid(t *testing.T, wantErr bool) op.Authorizer { + m := NewAuthorizer(t) + ExpectDecoder(m) + ExpectEncoder(m) + ExpectSigner(m, t) + ExpectStorage(m, t) + // ExpectErrorHandler(m, t, wantErr) + return m +} + +// func NewAuthorizerExpectDecoderFails(t *testing.T) op.Authorizer { +// m := NewAuthorizer(t) +// ExpectDecoderFails(m) +// ExpectEncoder(m) +// ExpectSigner(m, t) +// ExpectStorage(m, t) +// ExpectErrorHandler(m, t) +// return m +// } + +func ExpectDecoder(a op.Authorizer) { + mockA := a.(*MockAuthorizer) + mockA.EXPECT().Decoder().AnyTimes().Return(schema.NewDecoder()) +} + +func ExpectEncoder(a op.Authorizer) { + mockA := a.(*MockAuthorizer) + mockA.EXPECT().Encoder().AnyTimes().Return(schema.NewEncoder()) +} + +func ExpectSigner(a op.Authorizer, t *testing.T) { + mockA := a.(*MockAuthorizer) + mockA.EXPECT().Signer().DoAndReturn( + func() op.Signer { + return &Sig{} + }) +} + +// func ExpectErrorHandler(a op.Authorizer, t *testing.T, wantErr bool) { +// mockA := a.(*MockAuthorizer) +// mockA.EXPECT().ErrorHandler().AnyTimes(). +// Return(func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { +// if wantErr { +// require.Error(t, err) +// return +// } +// require.NoError(t, err) +// }) +// } + +type Sig struct{} + +func (s *Sig) SignIDToken(*oidc.IDTokenClaims) (string, error) { + return "", nil +} +func (s *Sig) SignAccessToken(*oidc.AccessTokenClaims) (string, error) { + return "", nil +} +func (s *Sig) SignatureAlgorithm() jose.SignatureAlgorithm { + return jose.HS256 +} + +func ExpectStorage(a op.Authorizer, t *testing.T) { + mockA := a.(*MockAuthorizer) + mockA.EXPECT().Storage().AnyTimes().Return(NewMockStorageAny(t)) +} + +// func NewMockSignerAny(t *testing.T) op.Signer { +// m := NewMockSigner(gomock.NewController(t)) +// m.EXPECT().Sign(gomock.Any()).AnyTimes().Return("", nil) +// return m +// } diff --git a/pkg/op/mock/client.go b/pkg/op/mock/client.go new file mode 100644 index 0000000..242eb13 --- /dev/null +++ b/pkg/op/mock/client.go @@ -0,0 +1,29 @@ +package mock + +import ( + "testing" + + gomock "github.com/golang/mock/gomock" + + op "github.com/caos/oidc/pkg/op" +) + +func NewClient(t *testing.T) op.Client { + return NewMockClient(gomock.NewController(t)) +} + +func NewClientExpectAny(t *testing.T, appType op.ApplicationType) op.Client { + c := NewClient(t) + m := c.(*MockClient) + m.EXPECT().RedirectURIs().AnyTimes().Return([]string{ + "https://registered.com/callback", + "http://registered.com/callback", + "http://localhost:9999/callback", + "custom://callback"}) + m.EXPECT().ApplicationType().AnyTimes().Return(appType) + m.EXPECT().LoginURL(gomock.Any()).AnyTimes().DoAndReturn( + func(id string) string { + return "login?id=" + id + }) + return c +} diff --git a/pkg/op/mock/client.mock.go b/pkg/op/mock/client.mock.go new file mode 100644 index 0000000..9ae2201 --- /dev/null +++ b/pkg/op/mock/client.mock.go @@ -0,0 +1,147 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/caos/oidc/pkg/op (interfaces: Client) + +// Package mock is a generated GoMock package. +package mock + +import ( + op "github.com/caos/oidc/pkg/op" + gomock "github.com/golang/mock/gomock" + reflect "reflect" + time "time" +) + +// MockClient is a mock of Client interface +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// AccessTokenLifetime mocks base method +func (m *MockClient) AccessTokenLifetime() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AccessTokenLifetime") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// AccessTokenLifetime indicates an expected call of AccessTokenLifetime +func (mr *MockClientMockRecorder) AccessTokenLifetime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenLifetime", reflect.TypeOf((*MockClient)(nil).AccessTokenLifetime)) +} + +// AccessTokenType mocks base method +func (m *MockClient) AccessTokenType() op.AccessTokenType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AccessTokenType") + ret0, _ := ret[0].(op.AccessTokenType) + return ret0 +} + +// AccessTokenType indicates an expected call of AccessTokenType +func (mr *MockClientMockRecorder) AccessTokenType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessTokenType", reflect.TypeOf((*MockClient)(nil).AccessTokenType)) +} + +// ApplicationType mocks base method +func (m *MockClient) ApplicationType() op.ApplicationType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplicationType") + ret0, _ := ret[0].(op.ApplicationType) + return ret0 +} + +// ApplicationType indicates an expected call of ApplicationType +func (mr *MockClientMockRecorder) ApplicationType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplicationType", reflect.TypeOf((*MockClient)(nil).ApplicationType)) +} + +// GetAuthMethod mocks base method +func (m *MockClient) GetAuthMethod() op.AuthMethod { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAuthMethod") + ret0, _ := ret[0].(op.AuthMethod) + return ret0 +} + +// GetAuthMethod indicates an expected call of GetAuthMethod +func (mr *MockClientMockRecorder) GetAuthMethod() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthMethod", reflect.TypeOf((*MockClient)(nil).GetAuthMethod)) +} + +// GetID mocks base method +func (m *MockClient) GetID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetID indicates an expected call of GetID +func (mr *MockClientMockRecorder) GetID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetID", reflect.TypeOf((*MockClient)(nil).GetID)) +} + +// IDTokenLifetime mocks base method +func (m *MockClient) IDTokenLifetime() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IDTokenLifetime") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// IDTokenLifetime indicates an expected call of IDTokenLifetime +func (mr *MockClientMockRecorder) IDTokenLifetime() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IDTokenLifetime", reflect.TypeOf((*MockClient)(nil).IDTokenLifetime)) +} + +// LoginURL mocks base method +func (m *MockClient) LoginURL(arg0 string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoginURL", arg0) + ret0, _ := ret[0].(string) + return ret0 +} + +// LoginURL indicates an expected call of LoginURL +func (mr *MockClientMockRecorder) LoginURL(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginURL", reflect.TypeOf((*MockClient)(nil).LoginURL), arg0) +} + +// RedirectURIs mocks base method +func (m *MockClient) RedirectURIs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RedirectURIs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// RedirectURIs indicates an expected call of RedirectURIs +func (mr *MockClientMockRecorder) RedirectURIs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RedirectURIs", reflect.TypeOf((*MockClient)(nil).RedirectURIs)) +} diff --git a/pkg/op/mock/configuration.mock.go b/pkg/op/mock/configuration.mock.go new file mode 100644 index 0000000..7148c6d --- /dev/null +++ b/pkg/op/mock/configuration.mock.go @@ -0,0 +1,132 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/caos/oidc/pkg/op (interfaces: Configuration) + +// Package mock is a generated GoMock package. +package mock + +import ( + op "github.com/caos/oidc/pkg/op" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockConfiguration is a mock of Configuration interface +type MockConfiguration struct { + ctrl *gomock.Controller + recorder *MockConfigurationMockRecorder +} + +// MockConfigurationMockRecorder is the mock recorder for MockConfiguration +type MockConfigurationMockRecorder struct { + mock *MockConfiguration +} + +// NewMockConfiguration creates a new mock instance +func NewMockConfiguration(ctrl *gomock.Controller) *MockConfiguration { + mock := &MockConfiguration{ctrl: ctrl} + mock.recorder = &MockConfigurationMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConfiguration) EXPECT() *MockConfigurationMockRecorder { + return m.recorder +} + +// AuthMethodPostSupported mocks base method +func (m *MockConfiguration) AuthMethodPostSupported() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AuthMethodPostSupported") + ret0, _ := ret[0].(bool) + return ret0 +} + +// AuthMethodPostSupported indicates an expected call of AuthMethodPostSupported +func (mr *MockConfigurationMockRecorder) AuthMethodPostSupported() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthMethodPostSupported", reflect.TypeOf((*MockConfiguration)(nil).AuthMethodPostSupported)) +} + +// AuthorizationEndpoint mocks base method +func (m *MockConfiguration) AuthorizationEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AuthorizationEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// AuthorizationEndpoint indicates an expected call of AuthorizationEndpoint +func (mr *MockConfigurationMockRecorder) AuthorizationEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizationEndpoint", reflect.TypeOf((*MockConfiguration)(nil).AuthorizationEndpoint)) +} + +// Issuer mocks base method +func (m *MockConfiguration) Issuer() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Issuer") + ret0, _ := ret[0].(string) + return ret0 +} + +// Issuer indicates an expected call of Issuer +func (mr *MockConfigurationMockRecorder) Issuer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Issuer", reflect.TypeOf((*MockConfiguration)(nil).Issuer)) +} + +// KeysEndpoint mocks base method +func (m *MockConfiguration) KeysEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeysEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// KeysEndpoint indicates an expected call of KeysEndpoint +func (mr *MockConfigurationMockRecorder) KeysEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeysEndpoint", reflect.TypeOf((*MockConfiguration)(nil).KeysEndpoint)) +} + +// Port mocks base method +func (m *MockConfiguration) Port() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Port") + ret0, _ := ret[0].(string) + return ret0 +} + +// Port indicates an expected call of Port +func (mr *MockConfigurationMockRecorder) Port() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Port", reflect.TypeOf((*MockConfiguration)(nil).Port)) +} + +// TokenEndpoint mocks base method +func (m *MockConfiguration) TokenEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TokenEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// TokenEndpoint indicates an expected call of TokenEndpoint +func (mr *MockConfigurationMockRecorder) TokenEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TokenEndpoint", reflect.TypeOf((*MockConfiguration)(nil).TokenEndpoint)) +} + +// UserinfoEndpoint mocks base method +func (m *MockConfiguration) UserinfoEndpoint() op.Endpoint { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UserinfoEndpoint") + ret0, _ := ret[0].(op.Endpoint) + return ret0 +} + +// UserinfoEndpoint indicates an expected call of UserinfoEndpoint +func (mr *MockConfigurationMockRecorder) UserinfoEndpoint() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserinfoEndpoint", reflect.TypeOf((*MockConfiguration)(nil).UserinfoEndpoint)) +} diff --git a/pkg/op/mock/generate.go b/pkg/op/mock/generate.go new file mode 100644 index 0000000..beb3132 --- /dev/null +++ b/pkg/op/mock/generate.go @@ -0,0 +1,7 @@ +package mock + +//go:generate mockgen -package mock -destination ./storage.mock.go github.com/caos/oidc/pkg/op Storage +//go:generate mockgen -package mock -destination ./authorizer.mock.go github.com/caos/oidc/pkg/op Authorizer +//go:generate mockgen -package mock -destination ./client.mock.go github.com/caos/oidc/pkg/op Client +//go:generate mockgen -package mock -destination ./configuration.mock.go github.com/caos/oidc/pkg/op Configuration +//go:generate mockgen -package mock -destination ./signer.mock.go github.com/caos/oidc/pkg/op Signer diff --git a/pkg/op/mock/signer.mock.go b/pkg/op/mock/signer.mock.go new file mode 100644 index 0000000..5c7b669 --- /dev/null +++ b/pkg/op/mock/signer.mock.go @@ -0,0 +1,79 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/caos/oidc/pkg/op (interfaces: Signer) + +// Package mock is a generated GoMock package. +package mock + +import ( + oidc "github.com/caos/oidc/pkg/oidc" + gomock "github.com/golang/mock/gomock" + go_jose_v2 "gopkg.in/square/go-jose.v2" + reflect "reflect" +) + +// MockSigner is a mock of Signer interface +type MockSigner struct { + ctrl *gomock.Controller + recorder *MockSignerMockRecorder +} + +// MockSignerMockRecorder is the mock recorder for MockSigner +type MockSignerMockRecorder struct { + mock *MockSigner +} + +// NewMockSigner creates a new mock instance +func NewMockSigner(ctrl *gomock.Controller) *MockSigner { + mock := &MockSigner{ctrl: ctrl} + mock.recorder = &MockSignerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSigner) EXPECT() *MockSignerMockRecorder { + return m.recorder +} + +// SignAccessToken mocks base method +func (m *MockSigner) SignAccessToken(arg0 *oidc.AccessTokenClaims) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SignAccessToken", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SignAccessToken indicates an expected call of SignAccessToken +func (mr *MockSignerMockRecorder) SignAccessToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignAccessToken", reflect.TypeOf((*MockSigner)(nil).SignAccessToken), arg0) +} + +// SignIDToken mocks base method +func (m *MockSigner) SignIDToken(arg0 *oidc.IDTokenClaims) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SignIDToken", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SignIDToken indicates an expected call of SignIDToken +func (mr *MockSignerMockRecorder) SignIDToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignIDToken", reflect.TypeOf((*MockSigner)(nil).SignIDToken), arg0) +} + +// SignatureAlgorithm mocks base method +func (m *MockSigner) SignatureAlgorithm() go_jose_v2.SignatureAlgorithm { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SignatureAlgorithm") + ret0, _ := ret[0].(go_jose_v2.SignatureAlgorithm) + return ret0 +} + +// SignatureAlgorithm indicates an expected call of SignatureAlgorithm +func (mr *MockSignerMockRecorder) SignatureAlgorithm() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SignatureAlgorithm", reflect.TypeOf((*MockSigner)(nil).SignatureAlgorithm)) +} diff --git a/pkg/op/mock/storage.mock.go b/pkg/op/mock/storage.mock.go new file mode 100644 index 0000000..3a36417 --- /dev/null +++ b/pkg/op/mock/storage.mock.go @@ -0,0 +1,170 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/caos/oidc/pkg/op (interfaces: Storage) + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + oidc "github.com/caos/oidc/pkg/oidc" + op "github.com/caos/oidc/pkg/op" + gomock "github.com/golang/mock/gomock" + go_jose_v2 "gopkg.in/square/go-jose.v2" + reflect "reflect" +) + +// MockStorage is a mock of Storage interface +type MockStorage struct { + ctrl *gomock.Controller + recorder *MockStorageMockRecorder +} + +// MockStorageMockRecorder is the mock recorder for MockStorage +type MockStorageMockRecorder struct { + mock *MockStorage +} + +// NewMockStorage creates a new mock instance +func NewMockStorage(ctrl *gomock.Controller) *MockStorage { + mock := &MockStorage{ctrl: ctrl} + mock.recorder = &MockStorageMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockStorage) EXPECT() *MockStorageMockRecorder { + return m.recorder +} + +// AuthRequestByID mocks base method +func (m *MockStorage) AuthRequestByID(arg0 context.Context, arg1 string) (op.AuthRequest, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AuthRequestByID", arg0, arg1) + ret0, _ := ret[0].(op.AuthRequest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AuthRequestByID indicates an expected call of AuthRequestByID +func (mr *MockStorageMockRecorder) AuthRequestByID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthRequestByID", reflect.TypeOf((*MockStorage)(nil).AuthRequestByID), arg0, arg1) +} + +// AuthorizeClientIDSecret mocks base method +func (m *MockStorage) AuthorizeClientIDSecret(arg0 context.Context, arg1, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AuthorizeClientIDSecret", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// AuthorizeClientIDSecret indicates an expected call of AuthorizeClientIDSecret +func (mr *MockStorageMockRecorder) AuthorizeClientIDSecret(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuthorizeClientIDSecret", reflect.TypeOf((*MockStorage)(nil).AuthorizeClientIDSecret), arg0, arg1, arg2) +} + +// CreateAuthRequest mocks base method +func (m *MockStorage) CreateAuthRequest(arg0 context.Context, arg1 *oidc.AuthRequest) (op.AuthRequest, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateAuthRequest", arg0, arg1) + ret0, _ := ret[0].(op.AuthRequest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateAuthRequest indicates an expected call of CreateAuthRequest +func (mr *MockStorageMockRecorder) CreateAuthRequest(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockStorage)(nil).CreateAuthRequest), arg0, arg1) +} + +// DeleteAuthRequest mocks base method +func (m *MockStorage) DeleteAuthRequest(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAuthRequest indicates an expected call of DeleteAuthRequest +func (mr *MockStorageMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockStorage)(nil).DeleteAuthRequest), arg0, arg1) +} + +// GetClientByClientID mocks base method +func (m *MockStorage) GetClientByClientID(arg0 context.Context, arg1 string) (op.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClientByClientID", arg0, arg1) + ret0, _ := ret[0].(op.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetClientByClientID indicates an expected call of GetClientByClientID +func (mr *MockStorageMockRecorder) GetClientByClientID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientByClientID", reflect.TypeOf((*MockStorage)(nil).GetClientByClientID), arg0, arg1) +} + +// GetKeySet mocks base method +func (m *MockStorage) GetKeySet(arg0 context.Context) (*go_jose_v2.JSONWebKeySet, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKeySet", arg0) + ret0, _ := ret[0].(*go_jose_v2.JSONWebKeySet) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKeySet indicates an expected call of GetKeySet +func (mr *MockStorageMockRecorder) GetKeySet(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockStorage)(nil).GetKeySet), arg0) +} + +// GetSigningKey mocks base method +func (m *MockStorage) GetSigningKey(arg0 context.Context) (*go_jose_v2.SigningKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSigningKey", arg0) + ret0, _ := ret[0].(*go_jose_v2.SigningKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSigningKey indicates an expected call of GetSigningKey +func (mr *MockStorageMockRecorder) GetSigningKey(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSigningKey", reflect.TypeOf((*MockStorage)(nil).GetSigningKey), arg0) +} + +// GetUserinfoFromScopes mocks base method +func (m *MockStorage) GetUserinfoFromScopes(arg0 context.Context, arg1 []string) (*oidc.Userinfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserinfoFromScopes", arg0, arg1) + ret0, _ := ret[0].(*oidc.Userinfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserinfoFromScopes indicates an expected call of GetUserinfoFromScopes +func (mr *MockStorageMockRecorder) GetUserinfoFromScopes(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserinfoFromScopes", reflect.TypeOf((*MockStorage)(nil).GetUserinfoFromScopes), arg0, arg1) +} + +// SaveKeyPair mocks base method +func (m *MockStorage) SaveKeyPair(arg0 context.Context) (*go_jose_v2.SigningKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveKeyPair", arg0) + ret0, _ := ret[0].(*go_jose_v2.SigningKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SaveKeyPair indicates an expected call of SaveKeyPair +func (mr *MockStorageMockRecorder) SaveKeyPair(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveKeyPair", reflect.TypeOf((*MockStorage)(nil).SaveKeyPair), arg0) +} diff --git a/pkg/op/mock/storage.mock.impl.go b/pkg/op/mock/storage.mock.impl.go new file mode 100644 index 0000000..7cd62b9 --- /dev/null +++ b/pkg/op/mock/storage.mock.impl.go @@ -0,0 +1,142 @@ +package mock + +import ( + "context" + "errors" + "testing" + "time" + + "gopkg.in/square/go-jose.v2" + + "github.com/golang/mock/gomock" + + "github.com/caos/oidc/pkg/op" +) + +func NewStorage(t *testing.T) op.Storage { + return NewMockStorage(gomock.NewController(t)) +} + +func NewMockStorageExpectValidClientID(t *testing.T) op.Storage { + m := NewStorage(t) + ExpectValidClientID(m) + return m +} + +func NewMockStorageExpectInvalidClientID(t *testing.T) op.Storage { + m := NewStorage(t) + ExpectInvalidClientID(m) + return m +} + +func NewMockStorageAny(t *testing.T) op.Storage { + m := NewStorage(t) + mockS := m.(*MockStorage) + mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).AnyTimes().Return(&ConfClient{}, nil) + mockS.EXPECT().AuthorizeClientIDSecret(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(nil) + return m +} + +func NewMockStorageSigningKeyError(t *testing.T) op.Storage { + m := NewStorage(t) + ExpectSigningKeyError(m) + return m +} + +func NewMockStorageSigningKeyInvalid(t *testing.T) op.Storage { + m := NewStorage(t) + ExpectSigningKeyInvalid(m) + return m +} +func NewMockStorageSigningKey(t *testing.T) op.Storage { + m := NewStorage(t) + ExpectSigningKey(m) + return m +} + +func ExpectInvalidClientID(s op.Storage) { + mockS := s.(*MockStorage) + mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).Return(nil, errors.New("client not found")) +} + +func ExpectValidClientID(s op.Storage) { + mockS := s.(*MockStorage) + mockS.EXPECT().GetClientByClientID(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, id string) (op.Client, error) { + var appType op.ApplicationType + var authMethod op.AuthMethod + var accessTokenType op.AccessTokenType + switch id { + case "web_client": + appType = op.ApplicationTypeWeb + authMethod = op.AuthMethodBasic + accessTokenType = op.AccessTokenTypeBearer + case "native_client": + appType = op.ApplicationTypeNative + authMethod = op.AuthMethodNone + accessTokenType = op.AccessTokenTypeBearer + case "useragent_client": + appType = op.ApplicationTypeUserAgent + authMethod = op.AuthMethodBasic + accessTokenType = op.AccessTokenTypeJWT + } + return &ConfClient{id: id, appType: appType, authMethod: authMethod, accessTokenType: accessTokenType}, nil + }) +} + +func ExpectSigningKeyError(s op.Storage) { + mockS := s.(*MockStorage) + mockS.EXPECT().GetSigningKey(gomock.Any()).Return(nil, errors.New("error")) +} + +func ExpectSigningKeyInvalid(s op.Storage) { + mockS := s.(*MockStorage) + mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{}, nil) +} + +func ExpectSigningKey(s op.Storage) { + mockS := s.(*MockStorage) + mockS.EXPECT().GetSigningKey(gomock.Any()).Return(&jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, nil) +} + +type ConfClient struct { + id string + appType op.ApplicationType + authMethod op.AuthMethod + accessTokenType op.AccessTokenType +} + +func (c *ConfClient) RedirectURIs() []string { + return []string{ + "https://registered.com/callback", + "http://registered.com/callback", + "http://localhost:9999/callback", + "custom://callback", + } +} + +func (c *ConfClient) LoginURL(id string) string { + return "login?id=" + id +} + +func (c *ConfClient) ApplicationType() op.ApplicationType { + return c.appType +} + +func (c *ConfClient) GetAuthMethod() op.AuthMethod { + return c.authMethod +} + +func (c *ConfClient) GetID() string { + return c.id +} + +func (c *ConfClient) AccessTokenLifetime() time.Duration { + return time.Duration(5 * time.Minute) +} +func (c *ConfClient) IDTokenLifetime() time.Duration { + return time.Duration(5 * time.Minute) +} +func (c *ConfClient) AccessTokenType() op.AccessTokenType { + return c.accessTokenType +} diff --git a/pkg/op/op.go b/pkg/op/op.go new file mode 100644 index 0000000..7db2ff4 --- /dev/null +++ b/pkg/op/op.go @@ -0,0 +1,51 @@ +package op + +import ( + "context" + "net/http" + + "github.com/gorilla/mux" + "github.com/sirupsen/logrus" + + "github.com/caos/oidc/pkg/oidc" +) + +type OpenIDProvider interface { + Configuration + HandleDiscovery(w http.ResponseWriter, r *http.Request) + HandleAuthorize(w http.ResponseWriter, r *http.Request) + HandleAuthorizeCallback(w http.ResponseWriter, r *http.Request) + HandleExchange(w http.ResponseWriter, r *http.Request) + HandleUserinfo(w http.ResponseWriter, r *http.Request) + HandleKeys(w http.ResponseWriter, r *http.Request) + HttpHandler() *http.Server +} + +func CreateRouter(o OpenIDProvider) *mux.Router { + router := mux.NewRouter() + router.HandleFunc(oidc.DiscoveryEndpoint, o.HandleDiscovery) + router.HandleFunc(o.AuthorizationEndpoint().Relative(), o.HandleAuthorize) + router.HandleFunc(o.AuthorizationEndpoint().Relative()+"/{id}", o.HandleAuthorizeCallback) + router.HandleFunc(o.TokenEndpoint().Relative(), o.HandleExchange) + router.HandleFunc(o.UserinfoEndpoint().Relative(), o.HandleUserinfo) + router.HandleFunc(o.KeysEndpoint().Relative(), o.HandleKeys) + return router +} + +func Start(ctx context.Context, o OpenIDProvider) { + go func() { + <-ctx.Done() + err := o.HttpHandler().Shutdown(ctx) + if err != nil { + logrus.Error("graceful shutdown of oidc server failed") + } + }() + + go func() { + err := o.HttpHandler().ListenAndServe() + if err != nil { + logrus.Panicf("oidc server serve failed: %v", err) + } + }() + logrus.Infof("oidc server is listening on %s", o.Port()) +} diff --git a/pkg/op/session.go b/pkg/op/session.go new file mode 100644 index 0000000..5e19040 --- /dev/null +++ b/pkg/op/session.go @@ -0,0 +1,13 @@ +package op + +import "github.com/caos/oidc/pkg/oidc" + +func NeedsExistingSession(authRequest *oidc.AuthRequest) bool { + if authRequest == nil { + return true + } + if authRequest.Prompt == oidc.PromptNone { + return true + } + return false +} diff --git a/pkg/op/signer.go b/pkg/op/signer.go new file mode 100644 index 0000000..6235931 --- /dev/null +++ b/pkg/op/signer.go @@ -0,0 +1,78 @@ +package op + +import ( + "encoding/json" + + "golang.org/x/net/context" + "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/oidc" +) + +type Signer interface { + SignIDToken(claims *oidc.IDTokenClaims) (string, error) + SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) + SignatureAlgorithm() jose.SignatureAlgorithm +} + +type idTokenSigner struct { + signer jose.Signer + storage AuthStorage + algorithm jose.SignatureAlgorithm +} + +func NewDefaultSigner(ctx context.Context, storage AuthStorage) (Signer, error) { + s := &idTokenSigner{ + storage: storage, + } + if err := s.initialize(ctx); err != nil { + return nil, err + } + return s, nil +} + +func (s *idTokenSigner) initialize(ctx context.Context) error { + var key *jose.SigningKey + var err error + key, err = s.storage.GetSigningKey(ctx) + if err != nil { + key, err = s.storage.SaveKeyPair(ctx) + if err != nil { + return err + } + } + s.signer, err = jose.NewSigner(*key, &jose.SignerOptions{}) + if err != nil { + return err + } + s.algorithm = key.Algorithm + return nil +} + +func (s *idTokenSigner) SignIDToken(claims *oidc.IDTokenClaims) (string, error) { + payload, err := json.Marshal(claims) + if err != nil { + return "", err + } + return s.Sign(payload) +} + +func (s *idTokenSigner) SignAccessToken(claims *oidc.AccessTokenClaims) (string, error) { + payload, err := json.Marshal(claims) + if err != nil { + return "", err + } + return s.Sign(payload) +} + +func (s *idTokenSigner) Sign(payload []byte) (string, error) { + result, err := s.signer.Sign(payload) + if err != nil { + return "", err + } + return result.CompactSerialize() +} + +func (s *idTokenSigner) SignatureAlgorithm() jose.SignatureAlgorithm { + return s.algorithm +} diff --git a/pkg/op/signer_test.go b/pkg/op/signer_test.go new file mode 100644 index 0000000..21aab0d --- /dev/null +++ b/pkg/op/signer_test.go @@ -0,0 +1,95 @@ +package op + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" +) + +// func TestNewDefaultSigner(t *testing.T) { +// type args struct { +// storage Storage +// } +// tests := []struct { +// name string +// args args +// want Signer +// wantErr bool +// }{ +// { +// "err initialize storage fails", +// args{mock.NewMockStorageSigningKeyError(t)}, +// nil, +// true, +// }, +// { +// "err initialize storage fails", +// args{mock.NewMockStorageSigningKeyInvalid(t)}, +// nil, +// true, +// }, +// { +// "initialize ok", +// args{mock.NewMockStorageSigningKey(t)}, +// &idTokenSigner{Storage: mock.NewMockStorageSigningKey(t)}, +// false, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// got, err := op.NewDefaultSigner(tt.args.storage) +// if (err != nil) != tt.wantErr { +// t.Errorf("NewDefaultSigner() error = %v, wantErr %v", err, tt.wantErr) +// return +// } +// if !reflect.DeepEqual(got, tt.want) { +// t.Errorf("NewDefaultSigner() = %v, want %v", got, tt.want) +// } +// }) +// } +// } + +func Test_idTokenSigner_Sign(t *testing.T) { + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: []byte("key")}, &jose.SignerOptions{}) + require.NoError(t, err) + + type fields struct { + signer jose.Signer + storage Storage + } + type args struct { + payload []byte + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + { + "ok", + fields{signer, nil}, + args{[]byte("test")}, + "eyJhbGciOiJIUzI1NiJ9.dGVzdA.SxYZRsvB_Dr4F7SEFuYXvkMZqCCwzpsPOQXl-vLPEww", + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &idTokenSigner{ + signer: tt.fields.signer, + storage: tt.fields.storage, + } + got, err := s.Sign(tt.args.payload) + if (err != nil) != tt.wantErr { + t.Errorf("idTokenSigner.Sign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("idTokenSigner.Sign() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/op/storage.go b/pkg/op/storage.go new file mode 100644 index 0000000..83b9f3e --- /dev/null +++ b/pkg/op/storage.go @@ -0,0 +1,48 @@ +package op + +import ( + "context" + "time" + + "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/oidc" +) + +type AuthStorage interface { + CreateAuthRequest(context.Context, *oidc.AuthRequest) (AuthRequest, error) + AuthRequestByID(context.Context, string) (AuthRequest, error) + DeleteAuthRequest(context.Context, string) error + + GetSigningKey(context.Context) (*jose.SigningKey, error) + GetKeySet(context.Context) (*jose.JSONWebKeySet, error) + SaveKeyPair(context.Context) (*jose.SigningKey, error) +} + +type OPStorage interface { + GetClientByClientID(context.Context, string) (Client, error) + AuthorizeClientIDSecret(context.Context, string, string) error + GetUserinfoFromScopes(context.Context, []string) (*oidc.Userinfo, error) +} + +type Storage interface { + AuthStorage + OPStorage +} + +type AuthRequest interface { + GetID() string + GetACR() string + GetAMR() []string + GetAudience() []string + GetAuthTime() time.Time + GetClientID() string + GetCodeChallenge() *oidc.CodeChallenge + GetNonce() string + GetRedirectURI() string + GetResponseType() oidc.ResponseType + GetScopes() []string + GetState() string + GetSubject() string + Done() bool +} diff --git a/pkg/op/token.go b/pkg/op/token.go new file mode 100644 index 0000000..7c1dedc --- /dev/null +++ b/pkg/op/token.go @@ -0,0 +1,93 @@ +package op + +import ( + "time" + + "github.com/caos/oidc/pkg/oidc" +) + +type TokenCreator interface { + Issuer() string + Signer() Signer + Storage() Storage + Crypto() Crypto +} + +func CreateTokenResponse(authReq AuthRequest, client Client, creator TokenCreator, createAccessToken bool, code string) (*oidc.AccessTokenResponse, error) { + var accessToken string + if createAccessToken { + var err error + accessToken, err = CreateAccessToken(authReq, client, creator) + if err != nil { + return nil, err + } + } + idToken, err := CreateIDToken(creator.Issuer(), authReq, client.IDTokenLifetime(), accessToken, code, creator.Signer()) + if err != nil { + return nil, err + } + exp := uint64(client.AccessTokenLifetime().Seconds()) + return &oidc.AccessTokenResponse{ + AccessToken: accessToken, + IDToken: idToken, + TokenType: oidc.BearerToken, + ExpiresIn: exp, + }, nil +} + +func CreateAccessToken(authReq AuthRequest, client Client, creator TokenCreator) (string, error) { + if client.AccessTokenType() == AccessTokenTypeJWT { + return CreateJWT(creator.Issuer(), authReq, client, creator.Signer()) + } + return CreateBearerToken(authReq, creator.Crypto()) +} + +func CreateBearerToken(authReq AuthRequest, crypto Crypto) (string, error) { + return crypto.Encrypt(authReq.GetID()) +} + +func CreateJWT(issuer string, authReq AuthRequest, client Client, signer Signer) (string, error) { + now := time.Now().UTC() + nbf := now + exp := now.Add(client.AccessTokenLifetime()) + claims := &oidc.AccessTokenClaims{ + Issuer: issuer, + Subject: authReq.GetSubject(), + Audiences: authReq.GetAudience(), + Expiration: exp, + IssuedAt: now, + NotBefore: nbf, + } + return signer.SignAccessToken(claims) +} + +func CreateIDToken(issuer string, authReq AuthRequest, validity time.Duration, accessToken, code string, signer Signer) (string, error) { + var err error + exp := time.Now().UTC().Add(validity) + claims := &oidc.IDTokenClaims{ + Issuer: issuer, + Subject: authReq.GetSubject(), + Audiences: authReq.GetAudience(), + Expiration: exp, + IssuedAt: time.Now().UTC(), + AuthTime: authReq.GetAuthTime(), + Nonce: authReq.GetNonce(), + AuthenticationContextClassReference: authReq.GetACR(), + AuthenticationMethodsReferences: authReq.GetAMR(), + AuthorizedParty: authReq.GetClientID(), + } + if accessToken != "" { + claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signer.SignatureAlgorithm()) + if err != nil { + return "", err + } + } + if code != "" { + claims.CodeHash, err = oidc.ClaimHash(code, signer.SignatureAlgorithm()) + if err != nil { + return "", err + } + } + + return signer.SignIDToken(claims) +} diff --git a/pkg/op/tokenrequest.go b/pkg/op/tokenrequest.go new file mode 100644 index 0000000..c8a7fe8 --- /dev/null +++ b/pkg/op/tokenrequest.go @@ -0,0 +1,151 @@ +package op + +import ( + "context" + "errors" + "net/http" + + "github.com/gorilla/schema" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/utils" +) + +type Exchanger interface { + Issuer() string + Storage() Storage + Decoder() *schema.Decoder + Signer() Signer + Crypto() Crypto + AuthMethodPostSupported() bool +} + +func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + tokenReq, err := ParseAccessTokenRequest(r, exchanger.Decoder()) + if err != nil { + ExchangeRequestError(w, r, err) + } + if tokenReq.Code == "" { + ExchangeRequestError(w, r, ErrInvalidRequest("code missing")) + return + } + authReq, client, err := ValidateAccessTokenRequest(r.Context(), tokenReq, exchanger) + if err != nil { + ExchangeRequestError(w, r, err) + return + } + err = exchanger.Storage().DeleteAuthRequest(r.Context(), authReq.GetID()) + if err != nil { + ExchangeRequestError(w, r, err) + return + } + resp, err := CreateTokenResponse(authReq, client, exchanger, true, tokenReq.Code) + if err != nil { + ExchangeRequestError(w, r, err) + return + } + utils.MarshalJSON(w, resp) +} + +func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) { + err := r.ParseForm() + if err != nil { + return nil, ErrInvalidRequest("error parsing form") + } + tokenReq := new(oidc.AccessTokenRequest) + err = decoder.Decode(tokenReq, r.Form) + if err != nil { + return nil, ErrInvalidRequest("error decoding form") + } + clientID, clientSecret, ok := r.BasicAuth() + if ok { + tokenReq.ClientID = clientID + tokenReq.ClientSecret = clientSecret + + } + return tokenReq, nil +} + +func ValidateAccessTokenRequest(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) { + authReq, client, err := AuthorizeClient(ctx, tokenReq, exchanger) + if err != nil { + return nil, nil, err + } + if client.GetID() != authReq.GetClientID() { + return nil, nil, ErrInvalidRequest("invalid auth code") + } + if tokenReq.RedirectURI != authReq.GetRedirectURI() { + return nil, nil, ErrInvalidRequest("redirect_uri does no correspond") + } + return authReq, client, nil +} + +func AuthorizeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, Client, error) { + client, err := exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID) + if err != nil { + return nil, nil, err + } + if client.GetAuthMethod() == AuthMethodNone { + authReq, err := AuthorizeCodeChallenge(ctx, tokenReq, exchanger) + return authReq, client, err + } + if client.GetAuthMethod() == AuthMethodPost && !exchanger.AuthMethodPostSupported() { + return nil, nil, errors.New("basic not supported") + } + err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()) + if err != nil { + return nil, nil, err + } + authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) + if err != nil { + return nil, nil, ErrInvalidRequest("invalid code") + } + return authReq, client, nil +} + +func AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string, storage OPStorage) error { + return storage.AuthorizeClientIDSecret(ctx, clientID, clientSecret) +} + +func AuthorizeCodeChallenge(ctx context.Context, tokenReq *oidc.AccessTokenRequest, exchanger Exchanger) (AuthRequest, error) { + if tokenReq.CodeVerifier == "" { + return nil, ErrInvalidRequest("code_challenge required") + } + authReq, err := AuthRequestByCode(ctx, tokenReq.Code, exchanger.Crypto(), exchanger.Storage()) + if err != nil { + return nil, ErrInvalidRequest("invalid code") + } + if !oidc.VerifyCodeChallenge(authReq.GetCodeChallenge(), tokenReq.CodeVerifier) { + return nil, ErrInvalidRequest("code_challenge invalid") + } + return authReq, nil +} + +func AuthRequestByCode(ctx context.Context, code string, crypto Crypto, storage AuthStorage) (AuthRequest, error) { + id, err := crypto.Decrypt(code) + if err != nil { + return nil, err + } + return storage.AuthRequestByID(ctx, id) +} + +func TokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) { + tokenRequest, err := ParseTokenExchangeRequest(w, r) + if err != nil { + ExchangeRequestError(w, r, err) + return + } + err = ValidateTokenExchangeRequest(tokenRequest, exchanger.Storage()) + if err != nil { + ExchangeRequestError(w, r, err) + return + } +} + +func ParseTokenExchangeRequest(w http.ResponseWriter, r *http.Request) (oidc.TokenRequest, error) { + return nil, errors.New("Unimplemented") //TODO: impl +} + +func ValidateTokenExchangeRequest(tokenReq oidc.TokenRequest, storage Storage) error { + return errors.New("Unimplemented") //TODO: impl +} diff --git a/pkg/op/userinfo.go b/pkg/op/userinfo.go new file mode 100644 index 0000000..ac47e68 --- /dev/null +++ b/pkg/op/userinfo.go @@ -0,0 +1,28 @@ +package op + +import ( + "net/http" + + "github.com/caos/oidc/pkg/utils" +) + +type UserinfoProvider interface { + Storage() Storage +} + +func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoProvider) { + scopes, err := ScopesFromAccessToken(w, r) + if err != nil { + return + } + info, err := userinfoProvider.Storage().GetUserinfoFromScopes(r.Context(), scopes) + if err != nil { + utils.MarshalJSON(w, err) + return + } + utils.MarshalJSON(w, info) +} + +func ScopesFromAccessToken(w http.ResponseWriter, r *http.Request) ([]string, error) { + return []string{}, nil +} diff --git a/pkg/rp/default_rp.go b/pkg/rp/default_rp.go new file mode 100644 index 0000000..1f18984 --- /dev/null +++ b/pkg/rp/default_rp.go @@ -0,0 +1,287 @@ +package rp + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + + "github.com/caos/oidc/pkg/oidc/grants" + + "golang.org/x/oauth2" + + "github.com/caos/oidc/pkg/oidc" + grants_tx "github.com/caos/oidc/pkg/oidc/grants/tokenexchange" + "github.com/caos/oidc/pkg/utils" +) + +const ( + idTokenKey = "id_token" + stateParam = "state" + pkceCode = "pkce" +) + +var ( + DefaultErrorHandler = func(w http.ResponseWriter, r *http.Request, errorType string, errorDesc string, state string) { + http.Error(w, errorType+": "+errorDesc, http.StatusInternalServerError) + } +) + +//DefaultRP impements the `DelegationTokenExchangeRP` interface extending the `RelayingParty` interface +type DefaultRP struct { + endpoints Endpoints + + oauthConfig oauth2.Config + config *Config + pkce bool + + httpClient *http.Client + cookieHandler *utils.CookieHandler + + errorHandler func(http.ResponseWriter, *http.Request, string, string, string) + + verifier Verifier +} + +//NewDefaultRP creates `DefaultRP` with the given +//Config and possible configOptions +//it will run discovery on the provided issuer +//if no verifier is provided using the options the `DefaultVerifier` is set +func NewDefaultRP(rpConfig *Config, rpOpts ...DefaultRPOpts) (DelegationTokenExchangeRP, error) { + p := &DefaultRP{ + config: rpConfig, + httpClient: utils.DefaultHTTPClient, + } + + for _, optFunc := range rpOpts { + optFunc(p) + } + + if err := p.discover(); err != nil { + return nil, err + } + + if p.errorHandler == nil { + p.errorHandler = DefaultErrorHandler + } + + if p.verifier == nil { + p.verifier = NewDefaultVerifier(rpConfig.Issuer, rpConfig.ClientID, NewRemoteKeySet(p.httpClient, p.endpoints.JKWsURL)) + } + + return p, nil +} + +//DefaultRPOpts is the type for providing dynamic options to the DefaultRP +type DefaultRPOpts func(p *DefaultRP) + +//WithCookieHandler set a `CookieHandler` for securing the various redirects +func WithCookieHandler(cookieHandler *utils.CookieHandler) DefaultRPOpts { + return func(p *DefaultRP) { + p.cookieHandler = cookieHandler + } +} + +//WithPKCE sets the RP to use PKCE (oauth2 code challenge) +//it also sets a `CookieHandler` for securing the various redirects +//and exchanging the code challenge +func WithPKCE(cookieHandler *utils.CookieHandler) DefaultRPOpts { + return func(p *DefaultRP) { + p.pkce = true + p.cookieHandler = cookieHandler + } +} + +//WithHTTPClient provides the ability to set an http client to be used for the relaying party and verifier +func WithHTTPClient(client *http.Client) DefaultRPOpts { + return func(p *DefaultRP) { + p.httpClient = client + } +} + +//AuthURL is the `RelayingParty` interface implementation +//wrapping the oauth2 `AuthCodeURL` +//returning the url of the auth request +func (p *DefaultRP) AuthURL(state string, opts ...AuthURLOpt) string { + authOpts := make([]oauth2.AuthCodeOption, 0) + for _, opt := range opts { + authOpts = append(authOpts, opt()...) + } + return p.oauthConfig.AuthCodeURL(state, authOpts...) +} + +//AuthURL is the `RelayingParty` interface implementation +//extending the `AuthURL` method with a http redirect handler +func (p *DefaultRP) AuthURLHandler(state string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + opts := make([]AuthURLOpt, 0) + if err := p.trySetStateCookie(w, state); err != nil { + http.Error(w, "failed to create state cookie: "+err.Error(), http.StatusUnauthorized) + return + } + if p.pkce { + codeChallenge, err := p.generateAndStoreCodeChallenge(w) + if err != nil { + http.Error(w, "failed to create code challenge: "+err.Error(), http.StatusUnauthorized) + return + } + opts = append(opts, WithCodeChallenge(codeChallenge)) + } + http.Redirect(w, r, p.AuthURL(state, opts...), http.StatusFound) + } +} + +func (p *DefaultRP) generateAndStoreCodeChallenge(w http.ResponseWriter) (string, error) { + var codeVerifier string + codeVerifier = "s" + if err := p.cookieHandler.SetCookie(w, pkceCode, codeVerifier); err != nil { + return "", err + } + return oidc.NewSHACodeChallenge(codeVerifier), nil +} + +//AuthURL is the `RelayingParty` interface implementation +//handling the oauth2 code exchange, extracting and validating the id_token +//returning it paresed together with the oauth2 tokens (access, refresh) +func (p *DefaultRP) CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (tokens *oidc.Tokens, err error) { + ctx = context.WithValue(ctx, oauth2.HTTPClient, p.httpClient) + codeOpts := make([]oauth2.AuthCodeOption, 0) + for _, opt := range opts { + codeOpts = append(codeOpts, opt()...) + } + + token, err := p.oauthConfig.Exchange(ctx, code, codeOpts...) + if err != nil { + return nil, err //TODO: our error + } + idTokenString, ok := token.Extra(idTokenKey).(string) + if !ok { + //TODO: implement + } + + idToken, err := p.verifier.Verify(ctx, token.AccessToken, idTokenString) + if err != nil { + return nil, err //TODO: err + } + + return &oidc.Tokens{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil +} + +//AuthURL is the `RelayingParty` interface implementation +//extending the `CodeExchange` method with callback function +func (p *DefaultRP) CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + state, err := p.tryReadStateCookie(w, r) + if err != nil { + http.Error(w, "failed to get state: "+err.Error(), http.StatusUnauthorized) + return + } + params := r.URL.Query() + if params.Get("error") != "" { + p.errorHandler(w, r, params.Get("error"), params.Get("error_description"), state) + return + } + codeOpts := make([]CodeExchangeOpt, 0) + if p.pkce { + codeVerifier, err := p.cookieHandler.CheckCookie(r, pkceCode) + if err != nil { + http.Error(w, "failed to get code verifier: "+err.Error(), http.StatusUnauthorized) + return + } + codeOpts = append(codeOpts, WithCodeVerifier(codeVerifier)) + } + tokens, err := p.CodeExchange(r.Context(), params.Get("code"), codeOpts...) + if err != nil { + http.Error(w, "failed to exchange token: "+err.Error(), http.StatusUnauthorized) + return + } + callback(w, r, tokens, state) + } +} + +// func (p *DefaultRP) Introspect(ctx context.Context, accessToken string) (oidc.TokenIntrospectResponse, error) { +// // req := &http.Request{} +// // resp, err := p.httpClient.Do(req) +// // if err != nil { + +// // } +// // p.endpoints.IntrospectURL +// return nil, nil +// } + +func (p *DefaultRP) Userinfo() {} + +//ClientCredentials is the `RelayingParty` interface implementation +//handling the oauth2 client credentials grant +func (p *DefaultRP) ClientCredentials(ctx context.Context, scopes ...string) (newToken *oauth2.Token, err error) { + return p.callTokenEndpoint(grants.ClientCredentialsGrantBasic(scopes...)) +} + +//TokenExchange is the `TokenExchangeRP` interface implementation +//handling the oauth2 token exchange (draft) +func (p *DefaultRP) TokenExchange(ctx context.Context, request *grants_tx.TokenExchangeRequest) (newToken *oauth2.Token, err error) { + return p.callTokenEndpoint(request) +} + +//DelegationTokenExchange is the `TokenExchangeRP` interface implementation +//handling the oauth2 token exchange for a delegation token (draft) +func (p *DefaultRP) DelegationTokenExchange(ctx context.Context, subjectToken string, reqOpts ...grants_tx.TokenExchangeOption) (newToken *oauth2.Token, err error) { + return p.TokenExchange(ctx, DelegationTokenRequest(subjectToken, reqOpts...)) +} + +func (p *DefaultRP) discover() error { + wellKnown := strings.TrimSuffix(p.config.Issuer, "/") + oidc.DiscoveryEndpoint + req, err := http.NewRequest("GET", wellKnown, nil) + if err != nil { + return err + } + discoveryConfig := new(oidc.DiscoveryConfiguration) + err = utils.HttpRequest(p.httpClient, req, &discoveryConfig) + if err != nil { + return err + } + p.endpoints = GetEndpoints(discoveryConfig) + p.oauthConfig = oauth2.Config{ + ClientID: p.config.ClientID, + ClientSecret: p.config.ClientSecret, + Endpoint: p.endpoints.Endpoint, + RedirectURL: p.config.CallbackURL, + Scopes: p.config.Scopes, + } + return nil +} + +func (p *DefaultRP) callTokenEndpoint(request interface{}) (newToken *oauth2.Token, err error) { + req, err := utils.FormRequest(p.endpoints.TokenURL, request) + if err != nil { + return nil, err + } + auth := base64.StdEncoding.EncodeToString([]byte(p.config.ClientID + ":" + p.config.ClientSecret)) + req.Header.Set("Authorization", "Basic "+auth) + token := new(oauth2.Token) + if err := utils.HttpRequest(p.httpClient, req, token); err != nil { + return nil, err + } + return token, nil +} + +func (p *DefaultRP) trySetStateCookie(w http.ResponseWriter, state string) error { + if p.cookieHandler != nil { + if err := p.cookieHandler.SetCookie(w, stateParam, state); err != nil { + return err + } + } + return nil +} + +func (p *DefaultRP) tryReadStateCookie(w http.ResponseWriter, r *http.Request) (state string, err error) { + if p.cookieHandler == nil { + return r.FormValue(stateParam), nil + } + state, err = p.cookieHandler.CheckQueryCookie(r, stateParam) + if err != nil { + return "", err + } + p.cookieHandler.DeleteCookie(w, stateParam) + return state, nil +} diff --git a/pkg/rp/default_verifier.go b/pkg/rp/default_verifier.go new file mode 100644 index 0000000..58adddb --- /dev/null +++ b/pkg/rp/default_verifier.go @@ -0,0 +1,363 @@ +package rp + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "time" + + "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/oidc" + "github.com/caos/oidc/pkg/utils" +) + +//DefaultVerifier implements the `Verifier` interface +type DefaultVerifier struct { + config *verifierConfig + keySet oidc.KeySet +} + +//ConfFunc is the type for providing dynamic options to the DefaultVerfifier +type ConfFunc func(*verifierConfig) + +//ACRVerifier specifies the function to be used by the `DefaultVerifier` for validating the acr claim +type ACRVerifier func(string) error + +//NewDefaultVerifier creates `DefaultVerifier` with the given +//issuer, clientID, keyset and possible configOptions +func NewDefaultVerifier(issuer, clientID string, keySet oidc.KeySet, confOpts ...ConfFunc) Verifier { + conf := &verifierConfig{ + issuer: issuer, + clientID: clientID, + iat: &iatConfig{ + // offset: time.Duration(500 * time.Millisecond), + }, + } + + for _, opt := range confOpts { + if opt != nil { + opt(conf) + } + } + return &DefaultVerifier{config: conf, keySet: keySet} +} + +//WithIgnoreIssuedAt will turn off iat claim verification +func WithIgnoreIssuedAt() func(*verifierConfig) { + return func(conf *verifierConfig) { + conf.iat.ignore = true + } +} + +//WithIssuedAtOffset mitigates the risk of iat to be in the future +//because of clock skews with the ability to add an offset to the current time +func WithIssuedAtOffset(offset time.Duration) func(*verifierConfig) { + return func(conf *verifierConfig) { + conf.iat.offset = offset + } +} + +//WithIssuedAtMaxAge provides the ability to define the maximum duration between iat and now +func WithIssuedAtMaxAge(maxAge time.Duration) func(*verifierConfig) { + return func(conf *verifierConfig) { + conf.iat.maxAge = maxAge + } +} + +//WithNonce TODO: ? +func WithNonce(nonce string) func(*verifierConfig) { + return func(conf *verifierConfig) { + conf.nonce = nonce + } +} + +//WithACRVerifier sets the verifier for the acr claim +func WithACRVerifier(verifier ACRVerifier) func(*verifierConfig) { + return func(conf *verifierConfig) { + conf.acr = verifier + } +} + +//WithAuthTimeMaxAge provides the ability to define the maximum duration between auth_time and now +func WithAuthTimeMaxAge(maxAge time.Duration) func(*verifierConfig) { + return func(conf *verifierConfig) { + conf.maxAge = maxAge + } +} + +//WithSupportedSigningAlgorithms overwrites the default RS256 signing algorithm +func WithSupportedSigningAlgorithms(algs ...string) func(*verifierConfig) { + return func(conf *verifierConfig) { + conf.supportedSignAlgs = algs + } +} + +type verifierConfig struct { + issuer string + clientID string + nonce string + iat *iatConfig + acr ACRVerifier + maxAge time.Duration + supportedSignAlgs []string + + // httpClient *http.Client + + now time.Time +} + +type iatConfig struct { + ignore bool + offset time.Duration + maxAge time.Duration +} + +//DefaultACRVerifier implements `ACRVerifier` returning an error +//if non of the provided values matches the acr claim +func DefaultACRVerifier(possibleValues []string) ACRVerifier { + return func(acr string) error { + if !utils.Contains(possibleValues, acr) { + return ErrAcrInvalid(possibleValues, acr) + } + return nil + } +} + +//Verify implements the `Verify` method of the `Verifier` interface +//according to https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation +//and https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowTokenValidation +func (v *DefaultVerifier) Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) { + v.config.now = time.Now().UTC() + idToken, err := v.VerifyIDToken(ctx, idTokenString) + if err != nil { + return nil, err + } + if err := v.verifyAccessToken(accessToken, idToken.AccessTokenHash, idToken.Signature); err != nil { //TODO: sig from token + return nil, err + } + return idToken, nil +} + +func (v *DefaultVerifier) now() time.Time { + if v.config.now.IsZero() { + v.config.now = time.Now().UTC().Round(time.Second) + } + return v.config.now +} + +//VerifyIDToken: https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation +func (v *DefaultVerifier) VerifyIDToken(ctx context.Context, idTokenString string) (*oidc.IDTokenClaims, error) { + //1. if encrypted --> decrypt + decrypted, err := v.decryptToken(idTokenString) + if err != nil { + return nil, err + } + claims, payload, err := v.parseToken(decrypted) + if err != nil { + return nil, err + } + // token, err := jwt.ParseWithClaims(decrypted, claims, func(token *jwt.Token) (interface{}, error) { + //2, check issuer (exact match) + if err := v.checkIssuer(claims.Issuer); err != nil { + return nil, err + } + + //3. check aud (aud must contain client_id, all aud strings must be allowed) + if err = v.checkAudience(claims.Audiences); err != nil { + return nil, err + } + + if err = v.checkAuthorizedParty(claims.Audiences, claims.AuthorizedParty); err != nil { + return nil, err + } + + //6. check signature by keys + //7. check alg default is rs256 + //8. check if alg is mac based (hs...) -> audience contains client_id. for validation use utf-8 representation of your client_secret + claims.Signature, err = v.checkSignature(ctx, decrypted, payload) + if err != nil { + return nil, err + } + + //9. check exp before now + if err = v.checkExpiration(claims.Expiration); err != nil { + return nil, err + } + + //10. check iat duration is optional (can be checked) + if err = v.checkIssuedAt(claims.IssuedAt); err != nil { + return nil, err + } + + //11. check nonce (check if optional possible) id_token.nonce == sentNonce + if err = v.checkNonce(claims.Nonce); err != nil { + return nil, err + } + + //12. if acr requested check acr + if err = v.checkAuthorizationContextClassReference(claims.AuthenticationContextClassReference); err != nil { + return nil, err + } + + //13. if auth_time requested check if auth_time is less than max age + if err = v.checkAuthTime(claims.AuthTime); err != nil { + return nil, err + } + + return claims, nil +} + +func (v *DefaultVerifier) parseToken(tokenString string) (*oidc.IDTokenClaims, []byte, error) { + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, nil, ValidationError("token contains an invalid number of segments") //TODO: err NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, nil, fmt.Errorf("oidc: malformed jwt payload: %v", err) + } + idToken := new(oidc.IDTokenClaims) + err = json.Unmarshal(payload, idToken) + return idToken, payload, err +} + +func (v *DefaultVerifier) checkIssuer(issuer string) error { + if v.config.issuer != issuer { + return ErrIssuerInvalid(v.config.issuer, issuer) + } + return nil +} + +func (v *DefaultVerifier) checkAudience(audiences []string) error { + if !utils.Contains(audiences, v.config.clientID) { + return ErrAudienceMissingClientID(v.config.clientID) + } + + //TODO: check aud trusted + return nil +} + +//4. if multiple aud strings --> check if azp +//5. if azp --> check azp == client_id +func (v *DefaultVerifier) checkAuthorizedParty(audiences []string, authorizedParty string) error { + if len(audiences) > 1 { + if authorizedParty == "" { + return ErrAzpMissing() + } + } + if authorizedParty != "" && authorizedParty != v.config.clientID { + return ErrAzpInvalid(authorizedParty, v.config.clientID) + } + return nil +} + +func (v *DefaultVerifier) checkSignature(ctx context.Context, idTokenString string, payload []byte) (jose.SignatureAlgorithm, error) { + jws, err := jose.ParseSigned(idTokenString) + if err != nil { + return "", err + } + if len(jws.Signatures) == 0 { + return "", nil //TODO: error + } + if len(jws.Signatures) > 1 { + return "", nil //TODO: error + } + sig := jws.Signatures[0] + supportedSigAlgs := v.config.supportedSignAlgs + if len(supportedSigAlgs) == 0 { + supportedSigAlgs = []string{"RS256"} + } + if !utils.Contains(supportedSigAlgs, sig.Header.Algorithm) { + return "", fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm) + } + + signedPayload, err := v.keySet.VerifySignature(ctx, jws) + if err != nil { + return "", err + //TODO: + } + + if !bytes.Equal(signedPayload, payload) { + return "", ErrSignatureInvalidPayload() //TODO: err + } + return jose.SignatureAlgorithm(sig.Header.Algorithm), nil +} + +func (v *DefaultVerifier) checkExpiration(expiration time.Time) error { + expiration = expiration.Round(time.Second) + if !v.now().Before(expiration) { + return ErrExpInvalid(expiration) + } + return nil +} + +func (v *DefaultVerifier) checkIssuedAt(issuedAt time.Time) error { + if v.config.iat.ignore { + return nil + } + issuedAt = issuedAt.Round(time.Second) + offset := v.now().Add(v.config.iat.offset).Round(time.Second) + if issuedAt.After(offset) { + return ErrIatInFuture(issuedAt, offset) + } + if v.config.iat.maxAge == 0 { + return nil + } + maxAge := v.now().Add(-v.config.iat.maxAge).Round(time.Second) + if issuedAt.Before(maxAge) { + return ErrIatToOld(maxAge, issuedAt) + } + return nil +} +func (v *DefaultVerifier) checkNonce(nonce string) error { + if v.config.nonce == "" { + return nil + } + if v.config.nonce != nonce { + return ErrNonceInvalid(v.config.nonce, nonce) + } + return nil +} +func (v *DefaultVerifier) checkAuthorizationContextClassReference(acr string) error { + if v.config.acr != nil { + return v.config.acr(acr) + } + return nil +} +func (v *DefaultVerifier) checkAuthTime(authTime time.Time) error { + if v.config.maxAge == 0 { + return nil + } + if authTime.IsZero() { + return ErrAuthTimeNotPresent() + } + authTime = authTime.Round(time.Second) + maxAge := v.now().Add(-v.config.maxAge).Round(time.Second) + if authTime.Before(maxAge) { + return ErrAuthTimeToOld(maxAge, authTime) + } + return nil +} + +func (v *DefaultVerifier) decryptToken(tokenString string) (string, error) { + return tokenString, nil //TODO: impl +} + +func (v *DefaultVerifier) verifyAccessToken(accessToken, atHash string, sigAlgorithm jose.SignatureAlgorithm) error { + if atHash == "" { + return nil //TODO: return error + } + + actual, err := oidc.ClaimHash(accessToken, sigAlgorithm) + if err != nil { + return err + } + if actual != atHash { + return nil //TODO: error + } + return nil +} diff --git a/pkg/rp/delegation.go b/pkg/rp/delegation.go new file mode 100644 index 0000000..3ae6bb6 --- /dev/null +++ b/pkg/rp/delegation.go @@ -0,0 +1,13 @@ +package rp + +import ( + "github.com/caos/oidc/pkg/oidc/grants/tokenexchange" +) + +//DelegationTokenRequest is an implementation of TokenExchangeRequest +//it exchanges a "urn:ietf:params:oauth:token-type:access_token" with an optional +//"urn:ietf:params:oauth:token-type:access_token" actor token for a +//"urn:ietf:params:oauth:token-type:access_token" delegation token +func DelegationTokenRequest(subjectToken string, opts ...tokenexchange.TokenExchangeOption) *tokenexchange.TokenExchangeRequest { + return tokenexchange.NewTokenExchangeRequest(subjectToken, tokenexchange.AccessTokenType, opts...) +} diff --git a/pkg/rp/error.go b/pkg/rp/error.go new file mode 100644 index 0000000..038aa4a --- /dev/null +++ b/pkg/rp/error.go @@ -0,0 +1,58 @@ +package rp + +import ( + "fmt" + "time" +) + +var ( + ErrIssuerInvalid = func(expected, actual string) *validationError { + return ValidationError("Issuer does not match. Expected: %s, got: %s", expected, actual) + } + ErrAudienceMissingClientID = func(clientID string) *validationError { + return ValidationError("Audience is not valid. Audience must contain client_id (%s)", clientID) + } + ErrAzpMissing = func() *validationError { + return ValidationError("Authorized Party is not set. If Token is valid for multiple audiences, azp must not be empty") + } + ErrAzpInvalid = func(azp, clientID string) *validationError { + return ValidationError("Authorized Party is not valid. azp (%s) must be equal to client_id (%s)", azp, clientID) + } + ErrExpInvalid = func(exp time.Time) *validationError { + return ValidationError("Token has expired %v", exp) + } + ErrIatInFuture = func(exp, now time.Time) *validationError { + return ValidationError("IssuedAt of token is in the future (%v, now with offset: %v)", exp, now) + } + ErrIatToOld = func(maxAge, iat time.Time) *validationError { + return ValidationError("IssuedAt of token must not be older than %v, but was %v (%v to old)", maxAge, iat, maxAge.Sub(iat)) + } + ErrNonceInvalid = func(expected, actual string) *validationError { + return ValidationError("nonce does not match. Expected: %s, got: %s", expected, actual) + } + ErrAcrInvalid = func(expected []string, actual string) *validationError { + return ValidationError("acr is invalid. Expected one of: %v, got: %s", expected, actual) + } + + ErrAuthTimeNotPresent = func() *validationError { + return ValidationError("claim `auth_time` of token is missing") + } + ErrAuthTimeToOld = func(maxAge, authTime time.Time) *validationError { + return ValidationError("Auth Time of token must not be older than %v, but was %v (%v to old)", maxAge, authTime, maxAge.Sub(authTime)) + } + ErrSignatureInvalidPayload = func() *validationError { + return ValidationError("Signature does not match Payload") + } +) + +func ValidationError(message string, args ...interface{}) *validationError { + return &validationError{fmt.Sprintf(message, args...)} //TODO: impl +} + +type validationError struct { + message string +} + +func (v *validationError) Error() string { + return v.message +} diff --git a/pkg/rp/jwks.go b/pkg/rp/jwks.go new file mode 100644 index 0000000..45ab9f4 --- /dev/null +++ b/pkg/rp/jwks.go @@ -0,0 +1,166 @@ +package rp + +import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + + "github.com/caos/oidc/pkg/utils" + + "gopkg.in/square/go-jose.v2" + + "github.com/caos/oidc/pkg/oidc" +) + +func NewRemoteKeySet(client *http.Client, jwksURL string) oidc.KeySet { + return &remoteKeySet{httpClient: client, jwksURL: jwksURL} +} + +type remoteKeySet struct { + jwksURL string + httpClient *http.Client + + // guard all other fields + mu sync.Mutex + + // inflight suppresses parallel execution of updateKeys and allows + // multiple goroutines to wait for its result. + inflight *inflight + + // A set of cached keys and their expiry. + cachedKeys []jose.JSONWebKey +} + +// inflight is used to wait on some in-flight request from multiple goroutines. +type inflight struct { + doneCh chan struct{} + + keys []jose.JSONWebKey + err error +} + +func newInflight() *inflight { + return &inflight{doneCh: make(chan struct{})} +} + +// wait returns a channel that multiple goroutines can receive on. Once it returns +// a value, the inflight request is done and result() can be inspected. +func (i *inflight) wait() <-chan struct{} { + return i.doneCh +} + +// done can only be called by a single goroutine. It records the result of the +// inflight request and signals other goroutines that the result is safe to +// inspect. +func (i *inflight) done(keys []jose.JSONWebKey, err error) { + i.keys = keys + i.err = err + close(i.doneCh) +} + +// result cannot be called until the wait() channel has returned a value. +func (i *inflight) result() ([]jose.JSONWebKey, error) { + return i.keys, i.err +} + +func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) { + // We don't support JWTs signed with multiple signatures. + keyID := "" + for _, sig := range jws.Signatures { + keyID = sig.Header.KeyID + break + } + + keys := r.keysFromCache() + payload, err, ok := checkKey(keyID, keys, jws) + if ok { + return payload, err + } + + keys, err = r.keysFromRemote(ctx) + if err != nil { + return nil, fmt.Errorf("fetching keys %v", err) + } + + payload, err, ok = checkKey(keyID, keys, jws) + if !ok { + return nil, errors.New("invalid kid") + } + return payload, err +} + +func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) { + r.mu.Lock() + defer r.mu.Unlock() + return r.cachedKeys +} + +// keysFromRemote syncs the key set from the remote set, records the values in the +// cache, and returns the key set. +func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) { + // Need to lock to inspect the inflight request field. + r.mu.Lock() + // If there's not a current inflight request, create one. + if r.inflight == nil { + r.inflight = newInflight() + + // This goroutine has exclusive ownership over the current inflight + // request. It releases the resource by nil'ing the inflight field + // once the goroutine is done. + go r.updateKeys(ctx) + } + inflight := r.inflight + r.mu.Unlock() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-inflight.wait(): + return inflight.result() + } +} + +func (r *remoteKeySet) updateKeys(ctx context.Context) { + // Sync keys and finish inflight when that's done. + keys, err := r.fetchRemoteKeys(ctx) + + r.inflight.done(keys, err) + + // Lock to update the keys and indicate that there is no longer an + // inflight request. + r.mu.Lock() + defer r.mu.Unlock() + + if err == nil { + r.cachedKeys = keys + } + + // Free inflight so a different request can run. + r.inflight = nil +} + +func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) { + req, err := http.NewRequest("GET", r.jwksURL, nil) + if err != nil { + return nil, fmt.Errorf("oidc: can't create request: %v", err) + } + + keySet := new(jose.JSONWebKeySet) + if err = utils.HttpRequest(r.httpClient, req, keySet); err != nil { + return nil, fmt.Errorf("oidc: failed to get keys: %v", err) + } + + return keySet.Keys, nil +} + +func checkKey(keyID string, keys []jose.JSONWebKey, jws *jose.JSONWebSignature) ([]byte, error, bool) { + for _, key := range keys { + if keyID == "" || key.KeyID == keyID { + payload, err := jws.Verify(&key) + return payload, err, true + } + } + return nil, nil, false +} diff --git a/pkg/rp/relaying_party.go b/pkg/rp/relaying_party.go new file mode 100644 index 0000000..d706839 --- /dev/null +++ b/pkg/rp/relaying_party.go @@ -0,0 +1,105 @@ +package rp + +import ( + "context" + "net/http" + + "github.com/caos/oidc/pkg/oidc" + + "golang.org/x/oauth2" +) + +//RelayingParty declares the minimal interface for oidc clients +type RelayingParty interface { + + //AuthURL returns the authorization endpoint with a given state + AuthURL(state string, opts ...AuthURLOpt) string + + //AuthURLHandler should implement the AuthURL func as http.HandlerFunc + //(redirecting to the auth endpoint) + AuthURLHandler(state string) http.HandlerFunc + + //CodeExchange implements the OIDC Token Request (oauth2 Authorization Code Grant) + //returning an `Access Token` and `ID Token Claims` + CodeExchange(ctx context.Context, code string, opts ...CodeExchangeOpt) (*oidc.Tokens, error) + + //CodeExchangeHandler extends the CodeExchange func, + //calling the provided callback func on success with additional returned `state` + CodeExchangeHandler(callback func(http.ResponseWriter, *http.Request, *oidc.Tokens, string)) http.HandlerFunc + + //ClientCredentials implements the oauth2 Client Credentials Grant + //requesting an `Access Token` for the client itself, without user context + ClientCredentials(ctx context.Context, scopes ...string) (*oauth2.Token, error) + + //Introspects calls the Introspect Endpoint + //for validating an (access) token + // Introspect(ctx context.Context, token string) (TokenIntrospectResponse, error) + + //Userinfo implements the OIDC Userinfo call + //returning the info of the user for the requested scopes of an access token + Userinfo() +} + +//PasswortGrantRP extends the `RelayingParty` interface with the oauth2 `Password Grant` +// +//This interface is separated from the standard `RelayingParty` interface as the `password grant` +//is part of the oauth2 and therefore OIDC specification, but should only be used when there's no +//other possibility, so IMHO never ever. Ever. +type PasswortGrantRP interface { + RelayingParty + + //PasswordGrant implements the oauth2 `Password Grant`, + //requesting an access token with the users `username` and `password` + PasswordGrant(context.Context, string, string) (*oauth2.Token, error) +} + +type Config struct { + ClientID string + ClientSecret string + CallbackURL string + Issuer string + Scopes []string +} + +type OptionFunc func(RelayingParty) + +type Endpoints struct { + oauth2.Endpoint + IntrospectURL string + UserinfoURL string + JKWsURL string +} + +func GetEndpoints(discoveryConfig *oidc.DiscoveryConfiguration) Endpoints { + return Endpoints{ + Endpoint: oauth2.Endpoint{ + AuthURL: discoveryConfig.AuthorizationEndpoint, + AuthStyle: oauth2.AuthStyleAutoDetect, + TokenURL: discoveryConfig.TokenEndpoint, + }, + IntrospectURL: discoveryConfig.IntrospectionEndpoint, + UserinfoURL: discoveryConfig.UserinfoEndpoint, + JKWsURL: discoveryConfig.JwksURI, + } +} + +type AuthURLOpt func() []oauth2.AuthCodeOption + +//WithCodeChallenge sets the `code_challenge` params in the auth request +func WithCodeChallenge(codeChallenge string) AuthURLOpt { + return func() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", codeChallenge), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + } + } +} + +type CodeExchangeOpt func() []oauth2.AuthCodeOption + +//WithCodeVerifier sets the `code_verifier` param in the token request +func WithCodeVerifier(codeVerifier string) CodeExchangeOpt { + return func() []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("code_verifier", codeVerifier)} + } +} diff --git a/pkg/rp/tockenexchange.go b/pkg/rp/tockenexchange.go new file mode 100644 index 0000000..d84b38e --- /dev/null +++ b/pkg/rp/tockenexchange.go @@ -0,0 +1,27 @@ +package rp + +import ( + "context" + + "golang.org/x/oauth2" + + "github.com/caos/oidc/pkg/oidc/grants/tokenexchange" +) + +//TokenExchangeRP extends the `RelayingParty` interface for the *draft* oauth2 `Token Exchange` +type TokenExchangeRP interface { + RelayingParty + + //TokenExchange implement the `Token Echange Grant` exchanging some token for an other + TokenExchange(context.Context, *tokenexchange.TokenExchangeRequest) (*oauth2.Token, error) +} + +//DelegationTokenExchangeRP extends the `TokenExchangeRP` interface +//for the specific `delegation token` request +type DelegationTokenExchangeRP interface { + TokenExchangeRP + + //DelegationTokenExchange implement the `Token Exchange Grant` + //providing an access token in request for a `delegation` token for a given resource / audience + DelegationTokenExchange(context.Context, string, ...tokenexchange.TokenExchangeOption) (*oauth2.Token, error) +} diff --git a/pkg/rp/verifier.go b/pkg/rp/verifier.go new file mode 100644 index 0000000..b82e6c2 --- /dev/null +++ b/pkg/rp/verifier.go @@ -0,0 +1,15 @@ +package rp + +import ( + "context" + + "github.com/caos/oidc/pkg/oidc" +) + +//Verifier implement the Token Response Validation as defined in OIDC specification +//https://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation +type Verifier interface { + + //Verify checks the access_token and id_token and returns the `id token claims` + Verify(ctx context.Context, accessToken, idTokenString string) (*oidc.IDTokenClaims, error) +} diff --git a/pkg/utils/cookie.go b/pkg/utils/cookie.go new file mode 100644 index 0000000..9e73e08 --- /dev/null +++ b/pkg/utils/cookie.go @@ -0,0 +1,110 @@ +package utils + +import ( + "errors" + "net/http" + + "github.com/gorilla/securecookie" +) + +type CookieHandler struct { + securecookie *securecookie.SecureCookie + secureOnly bool + sameSite http.SameSite + maxAge int + domain string +} + +func NewCookieHandler(hashKey, encryptKey []byte, opts ...CookieHandlerOpt) *CookieHandler { + c := &CookieHandler{ + securecookie: securecookie.New(hashKey, encryptKey), + secureOnly: true, + sameSite: http.SameSiteLaxMode, + } + + for _, opt := range opts { + opt(c) + } + return c +} + +type CookieHandlerOpt func(*CookieHandler) + +func WithUnsecure() CookieHandlerOpt { + return func(c *CookieHandler) { + c.secureOnly = false + } +} + +func WithSameSite(sameSite http.SameSite) CookieHandlerOpt { + return func(c *CookieHandler) { + c.sameSite = sameSite + } +} + +func WithMaxAge(maxAge int) CookieHandlerOpt { + return func(c *CookieHandler) { + c.maxAge = maxAge + c.securecookie.MaxAge(maxAge) + } +} + +func WithDomain(domain string) CookieHandlerOpt { + return func(c *CookieHandler) { + c.domain = domain + } +} + +func (c *CookieHandler) CheckCookie(r *http.Request, name string) (string, error) { + cookie, err := r.Cookie(name) + if err != nil { + return "", err + } + var value string + if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil { + return "", err + } + return value, nil +} + +func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string, error) { + value, err := c.CheckCookie(r, name) + if err != nil { + return "", err + } + if value != r.FormValue(name) { + return "", errors.New(name + " does not compare") + } + return value, nil +} + +func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error { + encoded, err := c.securecookie.Encode(name, value) + if err != nil { + return err + } + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: encoded, + Domain: c.domain, + Path: "/", + MaxAge: c.maxAge, + HttpOnly: true, + Secure: c.secureOnly, + SameSite: c.sameSite, + }) + return nil +} + +func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: "", + Domain: c.domain, + Path: "/", + MaxAge: -1, + HttpOnly: true, + Secure: c.secureOnly, + SameSite: c.sameSite, + }) +} diff --git a/pkg/utils/crypto.go b/pkg/utils/crypto.go new file mode 100644 index 0000000..05acb75 --- /dev/null +++ b/pkg/utils/crypto.go @@ -0,0 +1,70 @@ +package utils + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "io" +) + +func EncryptAES(data string, key string) (string, error) { + encrypted, err := EncryptBytesAES([]byte(data), key) + if err != nil { + return "", err + } + + return base64.URLEncoding.EncodeToString(encrypted), nil +} + +func EncryptBytesAES(plainText []byte, key string) ([]byte, error) { + + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return nil, err + } + + cipherText := make([]byte, aes.BlockSize+len(plainText)) + iv := cipherText[:aes.BlockSize] + if _, err = io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(cipherText[aes.BlockSize:], plainText) + + return cipherText, nil +} + +func DecryptAES(data string, key string) (string, error) { + text, err := base64.URLEncoding.DecodeString(data) + if err != nil { + return "", nil + } + decrypted, err := DecryptBytesAES(text, key) + if err != nil { + return "", err + } + return string(decrypted), nil +} + +func DecryptBytesAES(cipherText []byte, key string) ([]byte, error) { + + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return nil, err + } + + if len(cipherText) < aes.BlockSize { + err = errors.New("Ciphertext block size is too short!") + return nil, err + } + iv := cipherText[:aes.BlockSize] + cipherText = cipherText[aes.BlockSize:] + + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(cipherText, cipherText) + + return cipherText, err +} diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go new file mode 100644 index 0000000..bfdfacb --- /dev/null +++ b/pkg/utils/hash.go @@ -0,0 +1,30 @@ +package utils + +import ( + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "fmt" + "hash" + + "gopkg.in/square/go-jose.v2" +) + +func GetHashAlgorithm(sigAlgorithm jose.SignatureAlgorithm) (hash.Hash, error) { + switch sigAlgorithm { + case jose.RS256, jose.ES256, jose.PS256: + return sha256.New(), nil + case jose.RS384, jose.ES384, jose.PS384: + return sha512.New384(), nil + case jose.RS512, jose.ES512, jose.PS512: + return sha512.New(), nil + default: + return nil, fmt.Errorf("oidc: unsupported signing algorithm %q", sigAlgorithm) + } +} + +func HashString(hash hash.Hash, s string) string { + hash.Write([]byte(s)) // hash documents that Write will never return an error + sum := hash.Sum(nil)[:hash.Size()/2] + return base64.RawURLEncoding.EncodeToString(sum) +} diff --git a/pkg/utils/http.go b/pkg/utils/http.go new file mode 100644 index 0000000..6ad7083 --- /dev/null +++ b/pkg/utils/http.go @@ -0,0 +1,67 @@ +package utils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gorilla/schema" +) + +var ( + DefaultHTTPClient = &http.Client{ + Timeout: time.Duration(30 * time.Second), + } +) + +func FormRequest(endpoint string, request interface{}) (*http.Request, error) { + form := make(map[string][]string) + encoder := schema.NewEncoder() + if err := encoder.Encode(request, form); err != nil { + return nil, err + } + body := strings.NewReader(url.Values(form).Encode()) + req, err := http.NewRequest("POST", endpoint, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req, nil +} + +func HttpRequest(client *http.Client, req *http.Request, response interface{}) error { + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("unable to read response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("http status not ok: %s %s", resp.Status, body) + } + + err = json.Unmarshal(body, response) + if err != nil { + return fmt.Errorf("failed to unmarshal response: %v %s", err, body) + } + return nil +} + +func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) { + values := make(map[string][]string) + err := encoder.Encode(resp, values) + if err != nil { + return "", err + } + v := url.Values(values) + return v.Encode(), nil +} diff --git a/pkg/utils/marshal.go b/pkg/utils/marshal.go new file mode 100644 index 0000000..e279341 --- /dev/null +++ b/pkg/utils/marshal.go @@ -0,0 +1,21 @@ +package utils + +import ( + "encoding/json" + "net/http" + + "github.com/sirupsen/logrus" +) + +func MarshalJSON(w http.ResponseWriter, i interface{}) { + b, err := json.Marshal(i) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("content-type", "application/json") + _, err = w.Write(b) + if err != nil { + logrus.Error("error writing response") + } +} diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go new file mode 100644 index 0000000..5ffcd37 --- /dev/null +++ b/pkg/utils/strings.go @@ -0,0 +1,10 @@ +package utils + +func Contains(list []string, needle string) bool { + for _, item := range list { + if item == needle { + return true + } + } + return false +} diff --git a/pkg/utils/strings_test.go b/pkg/utils/strings_test.go new file mode 100644 index 0000000..86af2af --- /dev/null +++ b/pkg/utils/strings_test.go @@ -0,0 +1,48 @@ +package utils + +import "testing" + +func TestContains(t *testing.T) { + type args struct { + list []string + needle string + } + tests := []struct { + name string + args args + want bool + }{ + { + "empty list false", + args{[]string{}, "needle"}, + false, + }, + { + "list not containing false", + args{[]string{"list"}, "needle"}, + false, + }, + { + "list not containing empty needle false", + args{[]string{"list", "needle"}, ""}, + false, + }, + { + "list containing true", + args{[]string{"list", "needle"}, "needle"}, + true, + }, + { + "list containing empty needle true", + args{[]string{"list", "needle", ""}, ""}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Contains(tt.args.list, tt.args.needle); got != tt.want { + t.Errorf("Contains() = %v, want %v", got, tt.want) + } + }) + } +}