fix: use the same schema encoder everywhere (#299)
properly register SpaceDelimitedArray for all instances of schema.Encoder inside the oidc framework. Closes #295
This commit is contained in:
parent
fc1a80d274
commit
4dca29f1f9
5 changed files with 33 additions and 13 deletions
|
@ -8,11 +8,9 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/schema"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
|
@ -21,13 +19,7 @@ import (
|
||||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Encoder = func() httphelper.Encoder {
|
var Encoder = httphelper.Encoder(oidc.NewEncoder())
|
||||||
e := schema.NewEncoder()
|
|
||||||
e.RegisterEncoder(oidc.SpaceDelimitedArray{}, func(value reflect.Value) string {
|
|
||||||
return value.Interface().(oidc.SpaceDelimitedArray).Encode()
|
|
||||||
})
|
|
||||||
return e
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
|
// Discover calls the discovery endpoint of the provided issuer and returns its configuration
|
||||||
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
|
// It accepts an optional argument "wellknownUrl" which can be used to overide the dicovery endpoint url
|
||||||
|
|
|
@ -4,9 +4,11 @@ import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/schema"
|
||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
)
|
)
|
||||||
|
@ -125,6 +127,16 @@ func (s SpaceDelimitedArray) Value() (driver.Value, error) {
|
||||||
return strings.Join(s, " "), nil
|
return strings.Join(s, " "), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewEncoder returns a schema Encoder with
|
||||||
|
// a registered encoder for SpaceDelimitedArray.
|
||||||
|
func NewEncoder() *schema.Encoder {
|
||||||
|
e := schema.NewEncoder()
|
||||||
|
e.RegisterEncoder(SpaceDelimitedArray{}, func(value reflect.Value) string {
|
||||||
|
return value.Interface().(SpaceDelimitedArray).Encode()
|
||||||
|
})
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
type Time time.Time
|
type Time time.Time
|
||||||
|
|
||||||
func (t *Time) UnmarshalJSON(data []byte) error {
|
func (t *Time) UnmarshalJSON(data []byte) error {
|
||||||
|
|
|
@ -3,10 +3,12 @@ package oidc
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/schema"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
)
|
)
|
||||||
|
@ -335,3 +337,20 @@ func TestSpaceDelimitatedArray_ValuerNil(t *testing.T) {
|
||||||
assert.Equal(t, SpaceDelimitedArray(nil), reversed, "scan nil")
|
assert.Equal(t, SpaceDelimitedArray(nil), reversed, "scan nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewEncoder(t *testing.T) {
|
||||||
|
type request struct {
|
||||||
|
Scopes SpaceDelimitedArray `schema:"scope"`
|
||||||
|
}
|
||||||
|
a := request{
|
||||||
|
Scopes: SpaceDelimitedArray{"foo", "bar"},
|
||||||
|
}
|
||||||
|
|
||||||
|
values := make(url.Values)
|
||||||
|
NewEncoder().Encode(a, values)
|
||||||
|
assert.Equal(t, url.Values{"scope": []string{"foo bar"}}, values)
|
||||||
|
|
||||||
|
var b request
|
||||||
|
schema.NewDecoder().Decode(&b, values)
|
||||||
|
assert.Equal(t, a, b)
|
||||||
|
}
|
||||||
|
|
|
@ -98,8 +98,6 @@ func TestParseDeviceCodeRequest(t *testing.T) {
|
||||||
name: "empty request",
|
name: "empty request",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
/* decoding a SpaceDelimitedArray is broken
|
|
||||||
https://github.com/zitadel/oidc/issues/295
|
|
||||||
{
|
{
|
||||||
name: "success",
|
name: "success",
|
||||||
req: &oidc.DeviceAuthorizationRequest{
|
req: &oidc.DeviceAuthorizationRequest{
|
||||||
|
@ -107,7 +105,6 @@ func TestParseDeviceCodeRequest(t *testing.T) {
|
||||||
ClientID: "web",
|
ClientID: "web",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
|
@ -189,7 +189,7 @@ func newProvider(ctx context.Context, config *Config, storage Storage, issuer fu
|
||||||
o.decoder = schema.NewDecoder()
|
o.decoder = schema.NewDecoder()
|
||||||
o.decoder.IgnoreUnknownKeys(true)
|
o.decoder.IgnoreUnknownKeys(true)
|
||||||
|
|
||||||
o.encoder = schema.NewEncoder()
|
o.encoder = oidc.NewEncoder()
|
||||||
|
|
||||||
o.crypto = NewAESCrypto(config.CryptoKey)
|
o.crypto = NewAESCrypto(config.CryptoKey)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue