chore: test all routes

Co-authored-by: David Sharnoff <dsharnoff@singlestore.com>
This commit is contained in:
Tim Möhlmann 2023-03-15 15:32:14 +02:00 committed by GitHub
parent 711a194b50
commit 26d8e32636
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 467 additions and 62 deletions

View file

@ -37,8 +37,8 @@ type AuthRequest struct {
Nonce string
CodeChallenge *OIDCCodeChallenge
passwordChecked bool
authTime time.Time
done bool
authTime time.Time
}
func (a *AuthRequest) GetID() string {
@ -51,7 +51,7 @@ func (a *AuthRequest) GetACR() string {
func (a *AuthRequest) GetAMR() []string {
// this example only uses password for authentication
if a.passwordChecked {
if a.done {
return []string{"pwd"}
}
return nil
@ -102,7 +102,7 @@ func (a *AuthRequest) GetSubject() string {
}
func (a *AuthRequest) Done() bool {
return a.passwordChecked // this example only uses password for authentication
return a.done
}
func PromptToInternal(oidcPrompt oidc.SpaceDelimitedArray) []string {

View file

@ -28,8 +28,8 @@ var serviceKey1 = &rsa.PublicKey{
E: 65537,
}
// var _ op.Storage = &storage{}
// var _ op.ClientCredentialsStorage = &storage{}
var _ op.Storage = &Storage{}
var _ op.ClientCredentialsStorage = &Storage{}
// storage implements the op.Storage interface
// typically you would implement this as a layer on top of your database
@ -46,6 +46,7 @@ type Storage struct {
signingKey signingKey
deviceCodes map[string]deviceAuthorizationEntry
userCodes map[string]string
serviceUsers map[string]*Client
}
type signingKey struct {
@ -109,6 +110,16 @@ func NewStorage(userStore UserStore) *Storage {
},
deviceCodes: make(map[string]deviceAuthorizationEntry),
userCodes: make(map[string]string),
serviceUsers: map[string]*Client{
"sid1": {
id: "sid1",
secret: "verysecret",
grantTypes: []oidc.GrantType{
oidc.GrantTypeClientCredentials,
},
accessTokenType: op.AccessTokenTypeBearer,
},
},
}
}
@ -133,7 +144,7 @@ func (s *Storage) CheckUsernamePassword(username, password, id string) error {
// you will have to change some state on the request to guide the user through possible multiple steps of the login process
// in this example we'll simply check the username / password and set a boolean to true
// therefore we will also just check this boolean if the request / login has been finished
request.passwordChecked = true
request.done = true
return nil
}
return fmt.Errorf("username or password wrong")
@ -847,3 +858,44 @@ func (s *Storage) DenyDeviceAuthorization(ctx context.Context, userCode string)
s.deviceCodes[s.userCodes[userCode]].state.Denied = true
return nil
}
// AuthRequestDone is used by testing and is not required to implement op.Storage
func (s *Storage) AuthRequestDone(id string) error {
s.lock.Lock()
defer s.lock.Unlock()
if req, ok := s.authRequests[id]; ok {
req.done = true
return nil
}
return errors.New("request not found")
}
func (s *Storage) ClientCredentials(ctx context.Context, clientID, clientSecret string) (op.Client, error) {
s.lock.Lock()
defer s.lock.Unlock()
client, ok := s.serviceUsers[clientID]
if !ok {
return nil, errors.New("wrong service user or password")
}
if client.secret != clientSecret {
return nil, errors.New("wrong service user or password")
}
return client, nil
}
func (s *Storage) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scopes []string) (op.TokenRequest, error) {
client, ok := s.serviceUsers[clientID]
if !ok {
return nil, errors.New("wrong service user or password")
}
return &oidc.JWTTokenRequest{
Subject: client.id,
Audience: []string{clientID},
Scopes: scopes,
}, nil
}

2
go.mod
View file

@ -10,7 +10,7 @@ require (
github.com/gorilla/schema v1.2.0
github.com/gorilla/securecookie v1.1.1
github.com/jeremija/gosubmit v0.2.7
github.com/muhlemmer/gu v0.3.0
github.com/muhlemmer/gu v0.3.1
github.com/rs/cors v1.8.3
github.com/sirupsen/logrus v1.9.0
github.com/stretchr/testify v1.8.1

4
go.sum
View file

@ -123,8 +123,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/muhlemmer/gu v0.3.0 h1:UwNv9xXGp1WDgHKgk7ljjh3duh1w4ZAY1k1NsWBYl3Y=
github.com/muhlemmer/gu v0.3.0/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM=
github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM=
github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View file

@ -7,16 +7,16 @@ import (
type key int
var (
issuer key = 0
const (
issuerKey key = 0
)
type IssuerInterceptor struct {
issuerFromRequest IssuerFromRequest
}
//NewIssuerInterceptor will set the issuer into the context
//by the provided IssuerFromRequest (e.g. returned from StaticIssuer or IssuerFromHost)
// NewIssuerInterceptor will set the issuer into the context
// by the provided IssuerFromRequest (e.g. returned from StaticIssuer or IssuerFromHost)
func NewIssuerInterceptor(issuerFromRequest IssuerFromRequest) *IssuerInterceptor {
return &IssuerInterceptor{
issuerFromRequest: issuerFromRequest,
@ -35,15 +35,19 @@ func (i *IssuerInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc
}
}
//IssuerFromContext reads the issuer from the context (set by an IssuerInterceptor)
//it will return an empty string if not found
// IssuerFromContext reads the issuer from the context (set by an IssuerInterceptor)
// it will return an empty string if not found
func IssuerFromContext(ctx context.Context) string {
ctxIssuer, _ := ctx.Value(issuer).(string)
ctxIssuer, _ := ctx.Value(issuerKey).(string)
return ctxIssuer
}
// ContextWithIssuer returns a new context with issuer set to it.
func ContextWithIssuer(ctx context.Context, issuer string) context.Context {
return context.WithValue(ctx, issuerKey, issuer)
}
func (i *IssuerInterceptor) setIssuerCtx(w http.ResponseWriter, r *http.Request, next http.Handler) {
ctx := context.WithValue(r.Context(), issuer, i.issuerFromRequest(r))
r = r.WithContext(ctx)
r = r.WithContext(ContextWithIssuer(r.Context(), i.issuerFromRequest(r)))
next.ServeHTTP(w, r)
}

View file

@ -3,7 +3,6 @@ package op_test
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"io"
mr "math/rand"
@ -16,52 +15,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"golang.org/x/text/language"
)
var testProvider op.OpenIDProvider
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
config := &op.Config{
CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true,
AuthMethodPost: true,
AuthMethodPrivateKeyJWT: true,
GrantTypeRefreshToken: true,
RequestObjectSupported: true,
SupportedUILocales: []language.Tag{language.English},
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormURL: testIssuer + "device",
UserCode: op.UserCodeBase20,
},
}
storage.RegisterClients(
storage.NativeClient("native"),
storage.WebClient("web", "secret"),
storage.WebClient("api", "secret"),
)
var err error
testProvider, err = op.NewOpenIDProvider(testIssuer, config,
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
)
if err != nil {
panic(err)
}
}
func Test_deviceAuthorizationHandler(t *testing.T) {
req := &oidc.DeviceAuthorizationRequest{
Scopes: []string{"foo", "bar"},

392
pkg/op/op_test.go Normal file
View file

@ -0,0 +1,392 @@
package op_test
import (
"context"
"crypto/sha256"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/example/server/storage"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"golang.org/x/text/language"
)
var testProvider op.OpenIDProvider
const (
testIssuer = "https://localhost:9998/"
pathLoggedOut = "/logged-out"
)
func init() {
config := &op.Config{
CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut,
CodeMethodS256: true,
AuthMethodPost: true,
AuthMethodPrivateKeyJWT: true,
GrantTypeRefreshToken: true,
RequestObjectSupported: true,
SupportedUILocales: []language.Tag{language.English},
DeviceAuthorization: op.DeviceAuthorizationConfig{
Lifetime: 5 * time.Minute,
PollInterval: 5 * time.Second,
UserFormURL: testIssuer + "device",
UserCode: op.UserCodeBase20,
},
}
storage.RegisterClients(
storage.NativeClient("native"),
storage.WebClient("web", "secret", "https://example.com"),
storage.WebClient("api", "secret"),
)
var err error
testProvider, err = op.NewOpenIDProvider(testIssuer, config,
storage.NewStorage(storage.NewUserStore(testIssuer)), op.WithAllowInsecure(),
)
if err != nil {
panic(err)
}
}
type routesTestStorage interface {
op.Storage
AuthRequestDone(id string) error
}
func mapAsValues(m map[string]string) string {
values := make(url.Values, len(m))
for k, v := range m {
values.Set(k, v)
}
return values.Encode()
}
func TestRoutes(t *testing.T) {
storage := testProvider.Storage().(routesTestStorage)
ctx := op.ContextWithIssuer(context.Background(), testIssuer)
client, err := storage.GetClientByClientID(ctx, "web")
require.NoError(t, err)
oidcAuthReq := &oidc.AuthRequest{
ClientID: client.GetID(),
RedirectURI: "https://example.com",
MaxAge: gu.Ptr[uint](300),
Scopes: oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess, oidc.ScopeEmail, oidc.ScopeProfile, oidc.ScopePhone},
ResponseType: oidc.ResponseTypeCode,
}
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
require.NoError(t, err)
storage.AuthRequestDone(authReq.GetID())
accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
accessTokenRevoke, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
idToken, err := op.CreateIDToken(ctx, testIssuer, authReq, time.Hour, accessToken, "123", storage, client)
require.NoError(t, err)
jwtToken, _, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeJWT, testProvider, client, "")
require.NoError(t, err)
oidcAuthReq.IDTokenHint = idToken
serverURL, err := url.Parse(testIssuer)
require.NoError(t, err)
type basicAuth struct {
username, password string
}
tests := []struct {
name string
method string
path string
basicAuth *basicAuth
header map[string]string
values map[string]string
body map[string]string
wantCode int
headerContains map[string]string
json string // test for exact json output
contains []string // when the body output is not constant, we just check for snippets to be present in the response
}{
{
name: "health",
method: http.MethodGet,
path: "/healthz",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "ready",
method: http.MethodGet,
path: "/ready",
wantCode: http.StatusOK,
json: `{"status":"ok"}`,
},
{
name: "discovery",
method: http.MethodGet,
path: oidc.DiscoveryEndpoint,
wantCode: http.StatusOK,
json: `{"issuer":"https://localhost:9998/","authorization_endpoint":"https://localhost:9998/authorize","token_endpoint":"https://localhost:9998/oauth/token","introspection_endpoint":"https://localhost:9998/oauth/introspect","userinfo_endpoint":"https://localhost:9998/userinfo","revocation_endpoint":"https://localhost:9998/revoke","end_session_endpoint":"https://localhost:9998/end_session","device_authorization_endpoint":"https://localhost:9998/device_authorization","jwks_uri":"https://localhost:9998/keys","scopes_supported":["openid","profile","email","phone","address","offline_access"],"response_types_supported":["code","id_token","id_token token"],"grant_types_supported":["authorization_code","implicit","refresh_token","client_credentials","urn:ietf:params:oauth:grant-type:token-exchange","urn:ietf:params:oauth:grant-type:jwt-bearer","urn:ietf:params:oauth:grant-type:device_code"],"subject_types_supported":["public"],"id_token_signing_alg_values_supported":["RS256"],"request_object_signing_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"token_endpoint_auth_signing_alg_values_supported":["RS256"],"revocation_endpoint_auth_methods_supported":["none","client_secret_basic","client_secret_post","private_key_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["RS256"],"introspection_endpoint_auth_methods_supported":["client_secret_basic","private_key_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["RS256"],"claims_supported":["sub","aud","exp","iat","iss","auth_time","nonce","acr","amr","c_hash","at_hash","act","scopes","client_id","azp","preferred_username","name","family_name","given_name","locale","email","email_verified","phone_number","phone_number_verified"],"code_challenge_methods_supported":["S256"],"ui_locales_supported":["en"],"request_parameter_supported":true,"request_uri_parameter_supported":false}`,
},
{
name: "authorization",
method: http.MethodGet,
path: testProvider.AuthorizationEndpoint().Relative(),
values: map[string]string{
"client_id": client.GetID(),
"redirect_uri": "https://example.com",
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"response_type": string(oidc.ResponseTypeCode),
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/login/username?authRequestID="},
},
{
name: "authorization callback",
method: http.MethodGet,
path: testProvider.AuthorizationEndpoint().Relative() + "/callback",
values: map[string]string{"id": authReq.GetID()},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "https://example.com?code="},
contains: []string{
`<a href="https://example.com?code=`,
">Found</a>.",
},
},
{
// This call will fail. A successfull test is already
// part of client/integration_test.go
name: "code exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeCode),
"code": "123",
},
wantCode: http.StatusUnauthorized,
json: `{"error":"invalid_client"}`,
},
{
name: "JWT authorization",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeBearer),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"assertion": jwtToken,
},
wantCode: http.StatusBadRequest,
json: "{\"error\":\"server_error\",\"error_description\":\"audience is not valid: Audience must contain client_id \\\"https://localhost:9998/\\\"\"}",
},
{
name: "Token exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeTokenExchange),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
"subject_token": jwtToken,
"subject_token_type": string(oidc.AccessTokenType),
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","issued_token_type":"urn:ietf:params:oauth:token-type:refresh_token","token_type":"Bearer","expires_in":299,"scope":"openid offline_access","refresh_token":"`,
},
},
{
name: "Client credentials exchange",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"sid1", "verysecret"},
values: map[string]string{
"grant_type": string(oidc.GrantTypeClientCredentials),
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
},
{
// This call will fail. A successfull test is already
// part of device_test.go
name: "device token",
method: http.MethodPost,
path: testProvider.TokenEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
header: map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
},
body: map[string]string{
"grant_type": string(oidc.GrantTypeDeviceCode),
"device_code": "123",
},
wantCode: http.StatusBadRequest,
json: `{"error":"access_denied","error_description":"The authorization request was denied."}`,
},
{
name: "missing grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
wantCode: http.StatusBadRequest,
json: `{"error":"invalid_request","error_description":"grant_type missing"}`,
},
{
name: "unsupported grant type",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": "foo",
},
wantCode: http.StatusBadRequest,
json: `{"error":"unsupported_grant_type","error_description":"foo not supported"}`,
},
{
name: "introspection",
method: http.MethodGet,
path: testProvider.IntrospectionEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessToken,
},
wantCode: http.StatusOK,
json: `{"active":true,"scope":"openid offline_access email profile phone","client_id":"web","sub":"id1","username":"test-user@localhost","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "user info",
method: http.MethodGet,
path: testProvider.UserinfoEndpoint().Relative(),
header: map[string]string{
"authorization": "Bearer " + accessToken,
},
wantCode: http.StatusOK,
json: `{"sub":"id1","name":"Test User","given_name":"Test","family_name":"User","locale":"de","preferred_username":"test-user@localhost","email":"test-user@zitadel.ch","email_verified":true}`,
},
{
name: "refresh token",
method: http.MethodGet,
path: testProvider.TokenEndpoint().Relative(),
values: map[string]string{
"grant_type": string(oidc.GrantTypeRefreshToken),
"refresh_token": refreshToken,
"client_id": client.GetID(),
"client_secret": "secret",
},
wantCode: http.StatusOK,
contains: []string{
`{"access_token":"`,
`","token_type":"Bearer","refresh_token":"`,
`","expires_in":299,"id_token":"`,
},
},
{
name: "revoke",
method: http.MethodGet,
path: testProvider.RevocationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"token": accessTokenRevoke,
},
wantCode: http.StatusOK,
},
{
name: "end session",
method: http.MethodGet,
path: testProvider.EndSessionEndpoint().Relative(),
values: map[string]string{
"id_token_hint": idToken,
"client_id": "web",
},
wantCode: http.StatusFound,
headerContains: map[string]string{"Location": "/logged-out"},
contains: []string{`<a href="/logged-out">Found</a>.`},
},
{
name: "keys",
method: http.MethodGet,
path: testProvider.KeysEndpoint().Relative(),
wantCode: http.StatusOK,
contains: []string{
`{"keys":[{"use":"sig","kty":"RSA","kid":"`,
`","alg":"RS256","n":"`, `","e":"AQAB"}]}`,
},
},
{
name: "device authorization",
method: http.MethodGet,
path: testProvider.DeviceAuthorizationEndpoint().Relative(),
basicAuth: &basicAuth{"web", "secret"},
values: map[string]string{
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.Encode(),
},
wantCode: http.StatusOK,
contains: []string{
`{"device_code":"`, `","user_code":"`,
`","verification_uri":"https://localhost:9998/device"`,
`"verification_uri_complete":"https://localhost:9998/device?user_code=`,
`","expires_in":300,"interval":5}`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u := gu.PtrCopy(serverURL)
u.Path = tt.path
if tt.values != nil {
u.RawQuery = mapAsValues(tt.values)
}
var body io.Reader
if tt.body != nil {
body = strings.NewReader(mapAsValues(tt.body))
}
req := httptest.NewRequest(tt.method, u.String(), body)
for k, v := range tt.header {
req.Header.Set(k, v)
}
if tt.basicAuth != nil {
req.SetBasicAuth(tt.basicAuth.username, tt.basicAuth.password)
}
rec := httptest.NewRecorder()
testProvider.HttpHandler().ServeHTTP(rec, req)
resp := rec.Result()
require.NoError(t, err)
assert.Equal(t, tt.wantCode, resp.StatusCode)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
respBodyString := string(respBody)
t.Log(respBodyString)
t.Log(resp.Header)
if tt.json != "" {
assert.JSONEq(t, tt.json, respBodyString)
}
for _, c := range tt.contains {
assert.Contains(t, respBodyString, c)
}
for k, v := range tt.headerContains {
assert.Contains(t, resp.Header.Get(k), v)
}
})
}
}