diff --git a/pkg/oidc/token_request.go b/pkg/oidc/token_request.go index 56c4f1c..989e792 100644 --- a/pkg/oidc/token_request.go +++ b/pkg/oidc/token_request.go @@ -32,6 +32,11 @@ const ( ClientAssertionTypeJWTAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ) +var AllGrantTypes = []GrantType{ + GrantTypeCode, GrantTypeRefreshToken, GrantTypeClientCredentials, + GrantTypeBearer, GrantTypeTokenExchange, GrantTypeImplicit, + ClientAssertionTypeJWTAssertion} + type GrantType string type TokenRequest interface { diff --git a/pkg/oidc/types.go b/pkg/oidc/types.go index b6a75f4..1260798 100644 --- a/pkg/oidc/types.go +++ b/pkg/oidc/types.go @@ -1,7 +1,9 @@ package oidc import ( + "database/sql/driver" "encoding/json" + "fmt" "strings" "time" @@ -95,6 +97,34 @@ func (s *SpaceDelimitedArray) UnmarshalJSON(data []byte) error { return nil } +func (s *SpaceDelimitedArray) Scan(src interface{}) error { + if src == nil { + *s = nil + return nil + } + switch v := src.(type) { + case string: + if len(v) == 0 { + *s = SpaceDelimitedArray{} + return nil + } + *s = strings.Split(v, " ") + case []byte: + if len(v) == 0 { + *s = SpaceDelimitedArray{} + return nil + } + *s = strings.Split(string(v), " ") + default: + return fmt.Errorf("cannot convert %T to SpaceDelimitedArray", src) + } + return nil +} + +func (s SpaceDelimitedArray) Value() (driver.Value, error) { + return strings.Join(s, " "), nil +} + type Time time.Time func (t *Time) UnmarshalJSON(data []byte) error { diff --git a/pkg/oidc/types_test.go b/pkg/oidc/types_test.go index c03a775..6c62c40 100644 --- a/pkg/oidc/types_test.go +++ b/pkg/oidc/types_test.go @@ -3,6 +3,8 @@ package oidc import ( "bytes" "encoding/json" + "strconv" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -228,6 +230,7 @@ func TestScopes_UnmarshalText(t *testing.T) { }) } } + func TestScopes_MarshalText(t *testing.T) { type args struct { scopes SpaceDelimitedArray @@ -294,3 +297,41 @@ func TestScopes_MarshalText(t *testing.T) { }) } } + +func TestSpaceDelimitatedArray_ValuerNotNil(t *testing.T) { + inputs := [][]string{ + {"two", "elements"}, + {"one"}, + { /*zero*/ }, + } + for _, input := range inputs { + t.Run(strconv.Itoa(len(input))+strings.Join(input, "_"), func(t *testing.T) { + sda := SpaceDelimitedArray(input) + dbValue, err := sda.Value() + if !assert.NoError(t, err, "Value") { + return + } + var reversed SpaceDelimitedArray + err = reversed.Scan(dbValue) + if assert.NoError(t, err, "Scan string") { + assert.Equal(t, sda, reversed, "scan string") + } + reversed = nil + dbValueString, ok := dbValue.(string) + if assert.True(t, ok, "dbValue is string") { + err = reversed.Scan([]byte(dbValueString)) + if assert.NoError(t, err, "Scan bytes") { + assert.Equal(t, sda, reversed, "scan bytes") + } + } + }) + } +} + +func TestSpaceDelimitatedArray_ValuerNil(t *testing.T) { + var reversed SpaceDelimitedArray + err := reversed.Scan(nil) + if assert.NoError(t, err, "Scan nil") { + assert.Equal(t, SpaceDelimitedArray(nil), reversed, "scan nil") + } +}