en/decoding abstraction

This commit is contained in:
Livio Amstutz 2020-07-13 07:30:00 +02:00
parent 2966355b0e
commit 4a10ffaaa2
13 changed files with 367 additions and 69 deletions

View file

@ -8,7 +8,6 @@ import (
"strings" "strings"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp" "github.com/caos/oidc/pkg/rp"
@ -17,33 +16,37 @@ import (
type Authorizer interface { type Authorizer interface {
Storage() Storage Storage() Storage
Decoder() *schema.Decoder Decoder() utils.Decoder
Encoder() *schema.Encoder Encoder() utils.Encoder
Signer() Signer Signer() Signer
IDTokenVerifier() rp.Verifier IDTokenVerifier() rp.Verifier
Crypto() Crypto Crypto() Crypto
Issuer() string Issuer() string
} }
type ValidationAuthorizer interface { //AuthorizeValidator is an extension of Authorizer interface
//implementing it's own validation mechanism for the auth request
type AuthorizeValidator interface {
Authorizer Authorizer
ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, rp.Verifier) (string, error) ValidateAuthRequest(context.Context, *oidc.AuthRequest, Storage, rp.Verifier) (string, error)
} }
//ValidationAuthorizer is an extension of Authorizer interface
//implementing it's own validation mechanism for the auth request
//
//Deprecated: ValidationAuthorizer exists for historical compatibility. Use ValidationAuthorizer itself
type ValidationAuthorizer AuthorizeValidator
//Authorize handles the authorization request, including
//parsing, validating, storing and finally redirecting to the login handler
func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) { func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
err := r.ParseForm() authReq, err := ParseAuthorizeRequest(r, authorizer.Decoder())
if err != nil { if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest("cannot parse form"), authorizer.Encoder()) AuthRequestError(w, r, authReq, err, authorizer.Encoder())
return
}
authReq := new(oidc.AuthRequest)
err = authorizer.Decoder().Decode(authReq, r.Form)
if err != nil {
AuthRequestError(w, r, nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err)), authorizer.Encoder())
return return
} }
validation := ValidateAuthRequest validation := ValidateAuthRequest
if validater, ok := authorizer.(ValidationAuthorizer); ok { if validater, ok := authorizer.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest validation = validater.ValidateAuthRequest
} }
userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenVerifier()) userID, err := validation(r.Context(), authReq, authorizer.Storage(), authorizer.IDTokenVerifier())
@ -64,6 +67,19 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
RedirectToLogin(req.GetID(), client, w, r) RedirectToLogin(req.GetID(), client, w, r)
} }
func ParseAuthorizeRequest(r *http.Request, decoder utils.Decoder) (*oidc.AuthRequest, error) {
err := r.ParseForm()
if err != nil {
return nil, ErrInvalidRequest("cannot parse form")
}
authReq := new(oidc.AuthRequest)
err = decoder.Decode(authReq, r.Form)
if err != nil {
return nil, ErrInvalidRequest(fmt.Sprintf("cannot parse auth request: %v", err))
}
return authReq, nil
}
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier rp.Verifier) (string, error) { func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier rp.Verifier) (string, error) {
if err := ValidateAuthReqScopes(authReq.Scopes); err != nil { if err := ValidateAuthReqScopes(authReq.Scopes); err != nil {
return "", err return "", err
@ -121,10 +137,16 @@ func ValidateAuthReqRedirectURI(ctx context.Context, uri, client_id string, resp
} }
func ValidateAuthReqResponseType(responseType oidc.ResponseType) error { func ValidateAuthReqResponseType(responseType oidc.ResponseType) error {
if responseType == "" { switch responseType {
return ErrInvalidRequest("response_type empty") case oidc.ResponseTypeCode,
} oidc.ResponseTypeIDToken,
oidc.ResponseTypeIDTokenOnly:
return nil return nil
case "":
return ErrInvalidRequest("response_type empty")
default:
return ErrInvalidRequest("response_type invalid")
}
} }
func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier rp.Verifier) (string, error) { func ValidateAuthReqIDTokenHint(ctx context.Context, idTokenHint string, verifier rp.Verifier) (string, error) {

View file

@ -3,66 +3,140 @@ package op_test
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "net/url"
"reflect"
"testing" "testing"
"github.com/gorilla/schema"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/op" "github.com/caos/oidc/pkg/op"
"github.com/caos/oidc/pkg/op/mock" "github.com/caos/oidc/pkg/op/mock"
"github.com/caos/oidc/pkg/rp" "github.com/caos/oidc/pkg/rp"
rp_mock "github.com/caos/oidc/pkg/rp/mock"
"github.com/caos/oidc/pkg/utils"
) )
func TestAuthorize(t *testing.T) { //
// testCallback := func(t *testing.T, clienID string) callbackHandler { //func TestAuthorize(t *testing.T) {
// return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) { // // testCallback := func(t *testing.T, clienID string) callbackHandler {
// // require.Equal(t, clientID, client.) // // return func(authReq *oidc.AuthRequest, client oidc.Client, w http.ResponseWriter, r *http.Request) {
// } // // // require.Equal(t, clientID, client.)
// } // // }
// testErr := func(t *testing.T, expected error) errorHandler { // // }
// return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) { // // testErr := func(t *testing.T, expected error) errorHandler {
// require.Equal(t, expected, err) // // return func(w http.ResponseWriter, r *http.Request, authReq *oidc.AuthRequest, err error) {
// } // // require.Equal(t, expected, err)
// } // // }
// // }
// type args struct {
// w http.ResponseWriter
// r *http.Request
// authorizer op.Authorizer
// }
// tests := []struct {
// name string
// args args
// }{
// {
// "parsing fails",
// args{
// httptest.NewRecorder(),
// &http.Request{Method: "POST", Body: nil},
// mock.NewAuthorizerExpectValid(t, true),
// // testCallback(t, ""),
// // testErr(t, ErrInvalidRequest("cannot parse form")),
// },
// },
// {
// "decoding fails",
// args{
// httptest.NewRecorder(),
// func() *http.Request {
// r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo"))
// r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// return r
// }(),
// mock.NewAuthorizerExpectValid(t, true),
// // testCallback(t, ""),
// // testErr(t, ErrInvalidRequest("cannot parse auth request")),
// },
// },
// // {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}},
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer)
// })
// }
//}
func TestParseAuthorizeRequest(t *testing.T) {
type args struct { type args struct {
w http.ResponseWriter
r *http.Request r *http.Request
authorizer op.Authorizer decoder utils.Decoder
}
type res struct {
want *oidc.AuthRequest
err bool
} }
tests := []struct { tests := []struct {
name string name string
args args args args
res res
}{ }{
{ {
"parsing fails", "parsing form error",
args{ args{
httptest.NewRecorder(), &http.Request{URL: &url.URL{RawQuery: "invalid=%%param"}},
&http.Request{Method: "POST", Body: nil}, schema.NewDecoder(),
mock.NewAuthorizerExpectValid(t, true), },
// testCallback(t, ""), res{
// testErr(t, ErrInvalidRequest("cannot parse form")), nil,
true,
}, },
}, },
{ {
"decoding fails", "decoding error",
args{ args{
httptest.NewRecorder(), &http.Request{URL: &url.URL{RawQuery: "unknown=value"}},
func() *http.Request { func() utils.Decoder {
r := httptest.NewRequest("POST", "/authorize", strings.NewReader("client_id=foo")) decoder := schema.NewDecoder()
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") decoder.IgnoreUnknownKeys(false)
return r return decoder
}(), }(),
mock.NewAuthorizerExpectValid(t, true), },
// testCallback(t, ""), res{
// testErr(t, ErrInvalidRequest("cannot parse auth request")), nil,
true,
},
},
{
"parsing ok",
args{
&http.Request{URL: &url.URL{RawQuery: "scope=openid"}},
func() utils.Decoder {
decoder := schema.NewDecoder()
decoder.IgnoreUnknownKeys(false)
return decoder
}(),
},
res{
&oidc.AuthRequest{Scopes: oidc.Scopes{"openid"}},
false,
}, },
}, },
// {"decoding fails", args{httptest.NewRecorder(), &http.Request{}, mock.NewAuthorizerExpectValid(t), nil, testErr(t, nil)}},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
op.Authorize(tt.args.w, tt.args.r, tt.args.authorizer) got, err := op.ParseAuthorizeRequest(tt.args.r, tt.args.decoder)
if (err != nil) != tt.res.err {
t.Errorf("ParseAuthorizeRequest() error = %v, wantErr %v", err, tt.res.err)
}
if !reflect.DeepEqual(got, tt.res.want) {
t.Errorf("ParseAuthorizeRequest() got = %v, want %v", got, tt.res.want)
}
}) })
} }
} }
@ -228,6 +302,115 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
} }
} }
func TestValidateAuthReqResponseType(t *testing.T) {
type args struct {
responseType oidc.ResponseType
}
type res struct {
err bool
}
tests := []struct {
name string
args args
res res
}{
{
"code no error",
args{"code"},
res{false},
},
{
"id_token token no error",
args{"id_token token"},
res{false},
},
{
"id_token no error",
args{"id_token"},
res{false},
},
{
"no response_type error",
args{},
res{true},
},
{
"invalid response_type error",
args{"invalid"},
res{true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := op.ValidateAuthReqResponseType(tt.args.responseType); (err != nil) != tt.res.err {
t.Errorf("ValidateAuthReqResponseType() error = %v, wantErr %v", err, tt.res.err)
}
})
}
}
func TestValidateAuthReqIDTokenHint(t *testing.T) {
type args struct {
idTokenHint string
verifier rp.Verifier
}
type res struct {
userID string
err bool
}
tests := []struct {
name string
args args
res res
}{
{
"no id_token_hint, no id and ok",
args{
"",
nil,
},
res{
"",
false,
},
},
{
"invalid id_token_hint, no id and error",
args{
"invalid",
rp_mock.NewMockVerifierExpectInvalid(t),
},
res{
"",
true,
},
},
{
"no id_token_hint ok",
args{
"valid",
rp_mock.NewMockVerifierExpectValid(t),
},
res{
"id",
false,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := op.ValidateAuthReqIDTokenHint(nil, tt.args.idTokenHint, tt.args.verifier)
if (err != nil) != tt.res.err {
t.Errorf("ValidateAuthReqIDTokenHint() error = %v, wantErr %v", err, tt.res.err)
return
}
if got != tt.res.userID {
t.Errorf("ValidateAuthReqIDTokenHint() got = %v, want %v", got, tt.res.userID)
}
})
}
}
func TestRedirectToLogin(t *testing.T) { func TestRedirectToLogin(t *testing.T) {
type args struct { type args struct {
authReqID string authReqID string

View file

@ -56,7 +56,7 @@ func TestValidateIssuer(t *testing.T) {
{ {
"localhost with http ok", "localhost with http ok",
args{"http://localhost:9999"}, args{"http://localhost:9999"},
true, false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -13,6 +13,7 @@ import (
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp" "github.com/caos/oidc/pkg/rp"
"github.com/caos/oidc/pkg/utils"
) )
const ( const (
@ -247,11 +248,11 @@ func (p *DefaultOP) VerifySignature(ctx context.Context, jws *jose.JSONWebSignat
return payload, err return payload, err
} }
func (p *DefaultOP) Decoder() *schema.Decoder { func (p *DefaultOP) Decoder() utils.Decoder {
return p.decoder return p.decoder
} }
func (p *DefaultOP) Encoder() *schema.Encoder { func (p *DefaultOP) Encoder() utils.Encoder {
return p.encoder return p.encoder
} }

View file

@ -4,8 +4,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
@ -45,7 +43,7 @@ type ErrAuthRequest interface {
GetState() string GetState() string
} }
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder *schema.Encoder) { func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, encoder utils.Encoder) {
if authReq == nil { if authReq == nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return

View file

@ -7,8 +7,8 @@ package mock
import ( import (
op "github.com/caos/oidc/pkg/op" op "github.com/caos/oidc/pkg/op"
rp "github.com/caos/oidc/pkg/rp" rp "github.com/caos/oidc/pkg/rp"
utils "github.com/caos/oidc/pkg/utils"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
schema "github.com/gorilla/schema"
reflect "reflect" reflect "reflect"
) )
@ -50,10 +50,10 @@ func (mr *MockAuthorizerMockRecorder) Crypto() *gomock.Call {
} }
// Decoder mocks base method // Decoder mocks base method
func (m *MockAuthorizer) Decoder() *schema.Decoder { func (m *MockAuthorizer) Decoder() utils.Decoder {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Decoder") ret := m.ctrl.Call(m, "Decoder")
ret0, _ := ret[0].(*schema.Decoder) ret0, _ := ret[0].(utils.Decoder)
return ret0 return ret0
} }
@ -64,10 +64,10 @@ func (mr *MockAuthorizerMockRecorder) Decoder() *gomock.Call {
} }
// Encoder mocks base method // Encoder mocks base method
func (m *MockAuthorizer) Encoder() *schema.Encoder { func (m *MockAuthorizer) Encoder() utils.Encoder {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Encoder") ret := m.ctrl.Call(m, "Encoder")
ret0, _ := ret[0].(*schema.Encoder) ret0, _ := ret[0].(utils.Encoder)
return ret0 return ret0
} }

View file

@ -6,11 +6,11 @@ import (
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp" "github.com/caos/oidc/pkg/rp"
"github.com/gorilla/schema" "github.com/caos/oidc/pkg/utils"
) )
type SessionEnder interface { type SessionEnder interface {
Decoder() *schema.Decoder Decoder() utils.Decoder
Storage() Storage Storage() Storage
IDTokenVerifier() rp.Verifier IDTokenVerifier() rp.Verifier
DefaultLogoutRedirectURI() string DefaultLogoutRedirectURI() string
@ -39,7 +39,7 @@ func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
http.Redirect(w, r, session.RedirectURI, http.StatusFound) http.Redirect(w, r, session.RedirectURI, http.StatusFound)
} }
func ParseEndSessionRequest(r *http.Request, decoder *schema.Decoder) (*oidc.EndSessionRequest, error) { func ParseEndSessionRequest(r *http.Request, decoder utils.Decoder) (*oidc.EndSessionRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error parsing form") return nil, ErrInvalidRequest("error parsing form")

View file

@ -5,8 +5,6 @@ import (
"errors" "errors"
"net/http" "net/http"
"github.com/gorilla/schema"
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
) )
@ -14,7 +12,7 @@ import (
type Exchanger interface { type Exchanger interface {
Issuer() string Issuer() string
Storage() Storage Storage() Storage
Decoder() *schema.Decoder Decoder() utils.Decoder
Signer() Signer Signer() Signer
Crypto() Crypto Crypto() Crypto
AuthMethodPostSupported() bool AuthMethodPostSupported() bool
@ -42,7 +40,7 @@ func CodeExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
utils.MarshalJSON(w, resp) utils.MarshalJSON(w, resp)
} }
func ParseAccessTokenRequest(r *http.Request, decoder *schema.Decoder) (*oidc.AccessTokenRequest, error) { func ParseAccessTokenRequest(r *http.Request, decoder utils.Decoder) (*oidc.AccessTokenRequest, error) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
return nil, ErrInvalidRequest("error parsing form") return nil, ErrInvalidRequest("error parsing form")

View file

@ -7,11 +7,10 @@ import (
"github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/utils" "github.com/caos/oidc/pkg/utils"
"github.com/gorilla/schema"
) )
type UserinfoProvider interface { type UserinfoProvider interface {
Decoder() *schema.Decoder Decoder() utils.Decoder
Crypto() Crypto Crypto() Crypto
Storage() Storage Storage() Storage
} }
@ -35,7 +34,7 @@ func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoP
utils.MarshalJSON(w, info) utils.MarshalJSON(w, info)
} }
func getAccessToken(r *http.Request, decoder *schema.Decoder) (string, error) { func getAccessToken(r *http.Request, decoder utils.Decoder) (string, error) {
authHeader := r.Header.Get("authorization") authHeader := r.Header.Get("authorization")
if authHeader != "" { if authHeader != "" {
parts := strings.Split(authHeader, "Bearer ") parts := strings.Split(authHeader, "Bearer ")

3
pkg/rp/mock/generate.go Normal file
View file

@ -0,0 +1,3 @@
package mock
//go:generate mockgen -package mock -destination ./verifier.mock.go github.com/caos/oidc/pkg/rp Verifier

View file

@ -0,0 +1,50 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/oidc/pkg/rp (interfaces: Verifier)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
oidc "github.com/caos/oidc/pkg/oidc"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockVerifier is a mock of Verifier interface
type MockVerifier struct {
ctrl *gomock.Controller
recorder *MockVerifierMockRecorder
}
// MockVerifierMockRecorder is the mock recorder for MockVerifier
type MockVerifierMockRecorder struct {
mock *MockVerifier
}
// NewMockVerifier creates a new mock instance
func NewMockVerifier(ctrl *gomock.Controller) *MockVerifier {
mock := &MockVerifier{ctrl: ctrl}
mock.recorder = &MockVerifierMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockVerifier) EXPECT() *MockVerifierMockRecorder {
return m.recorder
}
// Verify mocks base method
func (m *MockVerifier) Verify(arg0 context.Context, arg1, arg2 string) (*oidc.IDTokenClaims, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Verify", arg0, arg1, arg2)
ret0, _ := ret[0].(*oidc.IDTokenClaims)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Verify indicates an expected call of Verify
func (mr *MockVerifierMockRecorder) Verify(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockVerifier)(nil).Verify), arg0, arg1, arg2)
}

View file

@ -0,0 +1,37 @@
package mock
import (
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/caos/oidc/pkg/oidc"
"github.com/caos/oidc/pkg/rp"
)
func NewVerifier(t *testing.T) rp.Verifier {
return NewMockVerifier(gomock.NewController(t))
}
func NewMockVerifierExpectInvalid(t *testing.T) rp.Verifier {
m := NewVerifier(t)
ExpectVerifyInvalid(m)
return m
}
func ExpectVerifyInvalid(v rp.Verifier) {
mock := v.(*MockVerifier)
mock.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("invalid"))
}
func NewMockVerifierExpectValid(t *testing.T) rp.Verifier {
m := NewVerifier(t)
ExpectVerifyValid(m)
return m
}
func ExpectVerifyValid(v rp.Verifier) {
mock := v.(*MockVerifier)
mock.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any()).Return(&oidc.IDTokenClaims{Userinfo: oidc.Userinfo{Subject: "id"}}, nil)
}

View file

@ -18,6 +18,13 @@ var (
} }
) )
type Decoder interface {
Decode(dst interface{}, src map[string][]string) error
}
type Encoder interface {
Encode(src interface{}, dst map[string][]string) error
}
func FormRequest(endpoint string, request interface{}) (*http.Request, error) { func FormRequest(endpoint string, request interface{}) (*http.Request, error) {
form := make(map[string][]string) form := make(map[string][]string)
encoder := schema.NewEncoder() encoder := schema.NewEncoder()
@ -56,7 +63,7 @@ func HttpRequest(client *http.Client, req *http.Request, response interface{}) e
return nil return nil
} }
func URLEncodeResponse(resp interface{}, encoder *schema.Encoder) (string, error) { func URLEncodeResponse(resp interface{}, encoder Encoder) (string, error) {
values := make(map[string][]string) values := make(map[string][]string)
err := encoder.Encode(resp, values) err := encoder.Encode(resp, values)
if err != nil { if err != nil {