resolve review comments
This commit is contained in:
parent
c9ab349d63
commit
08fab97786
4 changed files with 9 additions and 178 deletions
|
@ -154,29 +154,3 @@ func ClientIDFromRequest(r *http.Request, p ClientProvider) (clientID string, au
|
|||
}
|
||||
return data.ClientID, false, nil
|
||||
}
|
||||
|
||||
/*
|
||||
// ClientFromRequest wraps ClientIDFromRequest and obtains the Client from storage.
|
||||
// If the client id was not authenticated, the client from storage does not have
|
||||
// oidc.AuthMethodNone set, an error is returned.
|
||||
func ClientFromRequest(r *http.Request, p ClientProvider) (Client, error) {
|
||||
clientID, authenticated, err := ClientIDFromRequest(r, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := p.Storage().GetClientByClientID(r.Context(), clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !authenticated {
|
||||
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
|
||||
return nil, oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
||||
WithDescription(fmt.Sprintf("required client auth method: %s", m))
|
||||
}
|
||||
}
|
||||
|
||||
return client, err
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -199,7 +199,7 @@ func TestClientIDFromRequest(t *testing.T) {
|
|||
wantAuthenticated: false,
|
||||
},
|
||||
{
|
||||
name: "unauthenticated",
|
||||
name: "authenticated",
|
||||
args: args{
|
||||
body: strings.NewReader(
|
||||
url.Values{}.Encode(),
|
||||
|
@ -251,144 +251,3 @@ func TestClientIDFromRequest(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func TestClientFromRequest(t *testing.T) {
|
||||
publicClient := func() op.Client {
|
||||
c := mock.NewMockClient(gomock.NewController(t))
|
||||
c.EXPECT().AuthMethod().Return(oidc.AuthMethodNone)
|
||||
return c
|
||||
}
|
||||
privateClient := func() op.Client {
|
||||
c := mock.NewMockClient(gomock.NewController(t))
|
||||
c.EXPECT().AuthMethod().Return(oidc.AuthMethodPrivateKeyJWT)
|
||||
return c
|
||||
}
|
||||
|
||||
type args struct {
|
||||
body io.Reader
|
||||
p op.ClientProvider
|
||||
}
|
||||
type basicAuth struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
basicAuth *basicAuth
|
||||
wantClient bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "missing client id",
|
||||
args: args{
|
||||
body: strings.NewReader(
|
||||
url.Values{}.Encode(),
|
||||
),
|
||||
p: testClientProvider{
|
||||
storage: mock.NewStorage(t),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "get client error",
|
||||
args: args{
|
||||
body: strings.NewReader(
|
||||
url.Values{}.Encode(),
|
||||
),
|
||||
p: testClientProvider{
|
||||
storage: func() op.Storage {
|
||||
s := mock.NewMockStorage(gomock.NewController(t))
|
||||
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(nil, errors.New("something"))
|
||||
return s
|
||||
}(),
|
||||
},
|
||||
},
|
||||
basicAuth: &basicAuth{
|
||||
username: "foo",
|
||||
password: "bar",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "authenticated",
|
||||
args: args{
|
||||
body: strings.NewReader(
|
||||
url.Values{}.Encode(),
|
||||
),
|
||||
p: testClientProvider{
|
||||
storage: func() op.Storage {
|
||||
s := mock.NewMockStorage(gomock.NewController(t))
|
||||
s.EXPECT().AuthorizeClientIDSecret(context.Background(), "foo", "bar").Return(nil)
|
||||
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(mock.NewClient(t), nil)
|
||||
return s
|
||||
}(),
|
||||
},
|
||||
},
|
||||
basicAuth: &basicAuth{
|
||||
username: "foo",
|
||||
password: "bar",
|
||||
},
|
||||
wantClient: true,
|
||||
},
|
||||
{
|
||||
name: "public",
|
||||
args: args{
|
||||
body: strings.NewReader(
|
||||
url.Values{
|
||||
"client_id": []string{"foo"},
|
||||
}.Encode(),
|
||||
),
|
||||
p: testClientProvider{
|
||||
storage: func() op.Storage {
|
||||
s := mock.NewMockStorage(gomock.NewController(t))
|
||||
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(publicClient(), nil)
|
||||
return s
|
||||
}(),
|
||||
},
|
||||
},
|
||||
wantClient: true,
|
||||
},
|
||||
{
|
||||
name: "false public",
|
||||
args: args{
|
||||
body: strings.NewReader(
|
||||
url.Values{
|
||||
"client_id": []string{"foo"},
|
||||
}.Encode(),
|
||||
),
|
||||
p: testClientProvider{
|
||||
storage: func() op.Storage {
|
||||
s := mock.NewMockStorage(gomock.NewController(t))
|
||||
s.EXPECT().GetClientByClientID(context.Background(), "foo").Return(privateClient(), nil)
|
||||
return s
|
||||
}(),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := httptest.NewRequest(http.MethodPost, "/foo", tt.args.body)
|
||||
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if tt.basicAuth != nil {
|
||||
r.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
|
||||
}
|
||||
|
||||
got, err := op.ClientFromRequest(r, tt.args.p)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
if tt.wantClient {
|
||||
assert.NotNil(t, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -16,8 +16,8 @@ import (
|
|||
)
|
||||
|
||||
type DeviceAuthorizationConfig struct {
|
||||
Lifetime int
|
||||
PollInterval int
|
||||
Lifetime time.Duration
|
||||
PollInterval time.Duration
|
||||
UserFormURL string // the URL where the user must go to authorize the device
|
||||
UserCode UserCodeConfig
|
||||
}
|
||||
|
@ -71,12 +71,12 @@ func DeviceAuthorization(w http.ResponseWriter, r *http.Request, o OpenIDProvide
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.CharAmount)
|
||||
userCode, err := NewUserCode([]rune(config.UserCode.CharSet), config.UserCode.CharAmount, config.UserCode.DashInterval)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expires := time.Now().Add(time.Duration(config.Lifetime) * time.Second)
|
||||
expires := time.Now().Add(config.Lifetime)
|
||||
err = storage.StoreDeviceAuthorization(r.Context(), req.ClientID, deviceCode, userCode, expires, req.Scopes)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -196,11 +196,9 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !clientAuthenticated {
|
||||
if m := client.AuthMethod(); m != oidc.AuthMethodNone { // Livio: Does this mean "public" client?
|
||||
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
||||
WithDescription(fmt.Sprintf("required client auth method: %s", m))
|
||||
}
|
||||
if !clientAuthenticated && !IsConfidentialType(client) {
|
||||
return oidc.ErrInvalidClient().WithParent(ErrNoClientCredentials).
|
||||
WithDescription("confidential client requires authentication")
|
||||
}
|
||||
|
||||
tokenRequest := &deviceAccessTokenRequest{
|
||||
|
|
|
@ -36,7 +36,7 @@ func TestNewDeviceCode(t *testing.T) {
|
|||
})
|
||||
})
|
||||
|
||||
t.Run("dirrent lengths, rand reader", func(t *testing.T) {
|
||||
t.Run("different lengths, rand reader", func(t *testing.T) {
|
||||
for i := 1; i <= 32; i++ {
|
||||
got, err := NewDeviceCode(i)
|
||||
require.NoError(t, err)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue