feat: Token Revocation, Request Object and OP Certification (#130)

FEATURES (and FIXES):
- support OAuth 2.0 Token Revocation [RFC 7009](https://datatracker.ietf.org/doc/html/rfc7009)
- handle request object using `request` parameter [OIDC Core 1.0 Request Object](https://openid.net/specs/openid-connect-core-1_0.html#RequestObject)
- handle response mode
- added some information to the discovery endpoint:
  - revocation_endpoint (added with token revocation) 
  - revocation_endpoint_auth_methods_supported (added with token revocation)
  - revocation_endpoint_auth_signing_alg_values_supported (added with token revocation)
  - token_endpoint_auth_signing_alg_values_supported (was missing)
  - introspection_endpoint_auth_signing_alg_values_supported (was missing)
  - request_object_signing_alg_values_supported (added with request object)
  - request_parameter_supported (added with request object)
 - fixed `removeUserinfoScopes ` now returns the scopes without "userinfo" scopes (profile, email, phone, addedd) [source diff](https://github.com/caos/oidc/pull/130/files#diff-fad50c8c0f065d4dbc49d6c6a38f09c992c8f5d651a479ba00e31b500543559eL170-R171)
- improved error handling (pkg/oidc/error.go) and fixed some wrong OAuth errors (e.g. `invalid_grant` instead of `invalid_request`)
- improved MarshalJSON and added MarshalJSONWithStatus
- removed deprecated PEM decryption from `BytesToPrivateKey`  [source diff](https://github.com/caos/oidc/pull/130/files#diff-fe246e428e399ccff599627c71764de51387b60b4df84c67de3febd0954e859bL11-L19)
- NewAccessTokenVerifier now uses correct (internal) `accessTokenVerifier` [source diff](https://github.com/caos/oidc/pull/130/files#diff-3a01c7500ead8f35448456ef231c7c22f8d291710936cac91de5edeef52ffc72L52-R52)

BREAKING CHANGE:
- move functions from `utils` package into separate packages
- added various methods to the (OP) `Configuration` interface [source diff](https://github.com/caos/oidc/pull/130/files#diff-2538e0dfc772fdc37f057aecd6fcc2943f516c24e8be794cce0e368a26d20a82R19-R32)
- added revocationEndpoint to `WithCustomEndpoints ` [source diff](https://github.com/caos/oidc/pull/130/files#diff-19ae13a743eb7cebbb96492798b1bec556673eb6236b1387e38d722900bae1c3L355-R391)
- remove unnecessary context parameter from JWTProfileExchange [source diff](https://github.com/caos/oidc/pull/130/files#diff-4ed8f6affa4a9631fa8a034b3d5752fbb6a819107141aae00029014e950f7b4cL14)
This commit is contained in:
Livio Amstutz 2021-11-02 13:21:35 +01:00 committed by GitHub
parent 763d3334e7
commit eb10752e48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
63 changed files with 1738 additions and 624 deletions

110
pkg/http/cookie.go Normal file
View file

@ -0,0 +1,110 @@
package http
import (
"errors"
"net/http"
"github.com/gorilla/securecookie"
)
type CookieHandler struct {
securecookie *securecookie.SecureCookie
secureOnly bool
sameSite http.SameSite
maxAge int
domain string
}
func NewCookieHandler(hashKey, encryptKey []byte, opts ...CookieHandlerOpt) *CookieHandler {
c := &CookieHandler{
securecookie: securecookie.New(hashKey, encryptKey),
secureOnly: true,
sameSite: http.SameSiteLaxMode,
}
for _, opt := range opts {
opt(c)
}
return c
}
type CookieHandlerOpt func(*CookieHandler)
func WithUnsecure() CookieHandlerOpt {
return func(c *CookieHandler) {
c.secureOnly = false
}
}
func WithSameSite(sameSite http.SameSite) CookieHandlerOpt {
return func(c *CookieHandler) {
c.sameSite = sameSite
}
}
func WithMaxAge(maxAge int) CookieHandlerOpt {
return func(c *CookieHandler) {
c.maxAge = maxAge
c.securecookie.MaxAge(maxAge)
}
}
func WithDomain(domain string) CookieHandlerOpt {
return func(c *CookieHandler) {
c.domain = domain
}
}
func (c *CookieHandler) CheckCookie(r *http.Request, name string) (string, error) {
cookie, err := r.Cookie(name)
if err != nil {
return "", err
}
var value string
if err := c.securecookie.Decode(name, cookie.Value, &value); err != nil {
return "", err
}
return value, nil
}
func (c *CookieHandler) CheckQueryCookie(r *http.Request, name string) (string, error) {
value, err := c.CheckCookie(r, name)
if err != nil {
return "", err
}
if value != r.FormValue(name) {
return "", errors.New(name + " does not compare")
}
return value, nil
}
func (c *CookieHandler) SetCookie(w http.ResponseWriter, name, value string) error {
encoded, err := c.securecookie.Encode(name, value)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: encoded,
Domain: c.domain,
Path: "/",
MaxAge: c.maxAge,
HttpOnly: true,
Secure: c.secureOnly,
SameSite: c.sameSite,
})
return nil
}
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Domain: c.domain,
Path: "/",
MaxAge: -1,
HttpOnly: true,
Secure: c.secureOnly,
SameSite: c.sameSite,
})
}

107
pkg/http/http.go Normal file
View file

@ -0,0 +1,107 @@
package http
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/url"
"strings"
"time"
)
var (
DefaultHTTPClient = &http.Client{
Timeout: 30 * time.Second,
}
)
type Decoder interface {
Decode(dst interface{}, src map[string][]string) error
}
type Encoder interface {
Encode(src interface{}, dst map[string][]string) error
}
type FormAuthorization func(url.Values)
type RequestAuthorization func(*http.Request)
func AuthorizeBasic(user, password string) RequestAuthorization {
return func(req *http.Request) {
req.SetBasicAuth(url.QueryEscape(user), url.QueryEscape(password))
}
}
func FormRequest(endpoint string, request interface{}, encoder Encoder, authFn interface{}) (*http.Request, error) {
form := url.Values{}
if err := encoder.Encode(request, form); err != nil {
return nil, err
}
if fn, ok := authFn.(FormAuthorization); ok {
fn(form)
}
body := strings.NewReader(form.Encode())
req, err := http.NewRequest("POST", endpoint, body)
if err != nil {
return nil, err
}
if fn, ok := authFn.(RequestAuthorization); ok {
fn(req)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req, nil
}
func HttpRequest(client *http.Client, req *http.Request, response interface{}) error {
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("unable to read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("http status not ok: %s %s", resp.Status, body)
}
err = json.Unmarshal(body, response)
if err != nil {
return fmt.Errorf("failed to unmarshal response: %v %s", err, body)
}
return nil
}
func URLEncodeResponse(resp interface{}, encoder Encoder) (string, error) {
values := make(map[string][]string)
err := encoder.Encode(resp, values)
if err != nil {
return "", err
}
v := url.Values(values)
return v.Encode(), nil
}
func StartServer(ctx context.Context, port string) {
server := &http.Server{Addr: port}
go func() {
if err := server.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalf("ListenAndServe(): %v", err)
}
}()
go func() {
<-ctx.Done()
ctxShutdown, cancelShutdown := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelShutdown()
err := server.Shutdown(ctxShutdown)
if err != nil {
log.Fatalf("Shutdown(): %v", err)
}
}()
}

45
pkg/http/marshal.go Normal file
View file

@ -0,0 +1,45 @@
package http
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"reflect"
)
func MarshalJSON(w http.ResponseWriter, i interface{}) {
MarshalJSONWithStatus(w, i, http.StatusOK)
}
func MarshalJSONWithStatus(w http.ResponseWriter, i interface{}, status int) {
w.Header().Set("content-type", "application/json")
w.WriteHeader(status)
if i == nil || (reflect.ValueOf(i).Kind() == reflect.Ptr && reflect.ValueOf(i).IsNil()) {
return
}
err := json.NewEncoder(w).Encode(i)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func ConcatenateJSON(first, second []byte) ([]byte, error) {
if !bytes.HasSuffix(first, []byte{'}'}) {
return nil, fmt.Errorf("jws: invalid JSON %s", first)
}
if !bytes.HasPrefix(second, []byte{'{'}) {
return nil, fmt.Errorf("jws: invalid JSON %s", second)
}
// check empty
if len(first) == 2 {
return second, nil
}
if len(second) == 2 {
return first, nil
}
first[len(first)-1] = ','
first = append(first, second[1:]...)
return first, nil
}

156
pkg/http/marshal_test.go Normal file
View file

@ -0,0 +1,156 @@
package http
import (
"bytes"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConcatenateJSON(t *testing.T) {
type args struct {
first []byte
second []byte
}
tests := []struct {
name string
args args
want []byte
wantErr bool
}{
{
"invalid first part, error",
args{
[]byte(`invalid`),
[]byte(`{"some": "thing"}`),
},
nil,
true,
},
{
"invalid second part, error",
args{
[]byte(`{"some": "thing"}`),
[]byte(`invalid`),
},
nil,
true,
},
{
"both valid, merged",
args{
[]byte(`{"some": "thing"}`),
[]byte(`{"another": "thing"}`),
},
[]byte(`{"some": "thing","another": "thing"}`),
false,
},
{
"first empty",
args{
[]byte(`{}`),
[]byte(`{"some": "thing"}`),
},
[]byte(`{"some": "thing"}`),
false,
},
{
"second empty",
args{
[]byte(`{"some": "thing"}`),
[]byte(`{}`),
},
[]byte(`{"some": "thing"}`),
false,
},
{
"both empty",
args{
[]byte(`{}`),
[]byte(`{}`),
},
[]byte(`{}`),
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ConcatenateJSON(tt.args.first, tt.args.second)
if (err != nil) != tt.wantErr {
t.Errorf("ConcatenateJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !bytes.Equal(got, tt.want) {
t.Errorf("ConcatenateJSON() got = %v, want %v", string(got), tt.want)
}
})
}
}
func TestMarshalJSONWithStatus(t *testing.T) {
type args struct {
i interface{}
status int
}
type res struct {
statusCode int
body string
}
tests := []struct {
name string
args args
res res
}{
{
"empty ok",
args{
nil,
200,
},
res{
200,
"",
},
},
{
"string ok",
args{
"ok",
200,
},
res{
200,
`"ok"
`,
},
},
{
"struct ok",
args{
struct {
Test string `json:"test"`
}{"ok"},
200,
},
res{
200,
`{"test":"ok"}
`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
MarshalJSONWithStatus(w, tt.args.i, tt.args.status)
assert.Equal(t, tt.res.statusCode, w.Result().StatusCode)
assert.Equal(t, "application/json", w.Header().Get("content-type"))
assert.Equal(t, tt.res.body, w.Body.String())
})
}
}