Merge branch 'main' into fix-empty-locale

This commit is contained in:
Tim Möhlmann 2025-03-12 12:54:09 +02:00 committed by GitHub
commit 7dcd1ec03c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 284 additions and 143 deletions

View file

@ -2,7 +2,7 @@ package client
import (
"encoding/json"
"io/ioutil"
"os"
)
const (
@ -24,7 +24,7 @@ type KeyFile struct {
}
func ConfigFromKeyFile(path string) (*KeyFile, error) {
data, err := ioutil.ReadFile(path)
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}

View file

@ -133,6 +133,7 @@ type Error struct {
ErrorType errorType `json:"error" schema:"error"`
Description string `json:"error_description,omitempty" schema:"error_description,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
SessionState string `json:"session_state,omitempty" schema:"session_state,omitempty"`
redirectDisabled bool `schema:"-"`
returnParent bool `schema:"-"`
}
@ -142,11 +143,13 @@ func (e *Error) MarshalJSON() ([]byte, error) {
Error errorType `json:"error"`
ErrorDescription string `json:"error_description,omitempty"`
State string `json:"state,omitempty"`
SessionState string `json:"session_state,omitempty"`
Parent string `json:"parent,omitempty"`
}{
Error: e.ErrorType,
ErrorDescription: e.Description,
State: e.State,
SessionState: e.SessionState,
}
if e.returnParent {
m.Parent = e.Parent.Error()
@ -176,7 +179,8 @@ func (e *Error) Is(target error) bool {
}
return e.ErrorType == t.ErrorType &&
(e.Description == t.Description || t.Description == "") &&
(e.State == t.State || t.State == "")
(e.State == t.State || t.State == "") &&
(e.SessionState == t.SessionState || t.SessionState == "")
}
func (e *Error) WithParent(err error) *Error {
@ -242,6 +246,9 @@ func (e *Error) LogValue() slog.Value {
if e.State != "" {
attrs = append(attrs, slog.String("state", e.State))
}
if e.SessionState != "" {
attrs = append(attrs, slog.String("session_state", e.SessionState))
}
if e.redirectDisabled {
attrs = append(attrs, slog.Bool("redirect_disabled", e.redirectDisabled))
}

View file

@ -230,12 +230,13 @@ func (c *ActorClaims) UnmarshalJSON(data []byte) error {
}
type AccessTokenResponse struct {
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
State string `json:"state,omitempty" schema:"state,omitempty"`
Scope SpaceDelimitedArray `json:"scope,omitempty" schema:"scope,omitempty"`
}
type JWTProfileAssertionClaims struct {

View file

@ -7,12 +7,11 @@ import (
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
"time"
jose "github.com/go-jose/go-jose/v4"
str "github.com/zitadel/oidc/v3/pkg/strings"
)
type Claims interface {
@ -84,7 +83,7 @@ type ACRVerifier func(string) error
// if none of the provided values matches the acr claim
func DefaultACRVerifier(possibleValues []string) ACRVerifier {
return func(acr string) error {
if !str.Contains(possibleValues, acr) {
if !slices.Contains(possibleValues, acr) {
return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr)
}
return nil
@ -123,7 +122,7 @@ func CheckIssuer(claims Claims, issuer string) error {
}
func CheckAudience(claims Claims, clientID string) error {
if !str.Contains(claims.GetAudience(), clientID) {
if !slices.Contains(claims.GetAudience(), clientID) {
return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID)
}

View file

@ -18,7 +18,6 @@ import (
"github.com/bmatcuk/doublestar/v4"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
str "github.com/zitadel/oidc/v3/pkg/strings"
)
type AuthRequest interface {
@ -39,6 +38,13 @@ type AuthRequest interface {
Done() bool
}
// AuthRequestSessionState should be implemented if [OpenID Connect Session Management](https://openid.net/specs/openid-connect-session-1_0.html) is supported
type AuthRequestSessionState interface {
// GetSessionState returns session_state.
// session_state is related to OpenID Connect Session Management.
GetSessionState() string
}
type Authorizer interface {
Storage() Storage
Decoder() httphelper.Decoder
@ -104,8 +110,8 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
}
return ValidateAuthRequestClient(ctx, authReq, client, verifier)
}
if validater, ok := authorizer.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest
if validator, ok := authorizer.(AuthorizeValidator); ok {
validation = validator.ValidateAuthRequest
}
userID, err := validation(ctx, authReq, authorizer.Storage(), authorizer.IDTokenHintVerifier(ctx))
if err != nil {
@ -156,7 +162,7 @@ func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage
if requestObject.Issuer != requestObject.ClientID {
return oidc.ErrInvalidRequest().WithDescription("missing or wrong issuer in request")
}
if !str.Contains(requestObject.Audience, issuer) {
if !slices.Contains(requestObject.Audience, issuer) {
return oidc.ErrInvalidRequest().WithDescription("issuer missing in audience")
}
keySet := &jwtProfileKeySet{storage: storage, clientID: requestObject.Issuer}
@ -170,7 +176,7 @@ func ParseRequestObject(ctx context.Context, authReq *oidc.AuthRequest, storage
// CopyRequestObjectToAuthRequest overwrites present values from the Request Object into the auth request
// and clears the `RequestParam` of the auth request
func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oidc.RequestObject) {
if str.Contains(authReq.Scopes, oidc.ScopeOpenID) && len(requestObject.Scopes) > 0 {
if slices.Contains(authReq.Scopes, oidc.ScopeOpenID) && len(requestObject.Scopes) > 0 {
authReq.Scopes = requestObject.Scopes
}
if requestObject.RedirectURI != "" {
@ -288,7 +294,7 @@ func ValidateAuthReqScopes(client Client, scopes []string) ([]string, error) {
// checkURIAgainstRedirects just checks aginst the valid redirect URIs and ignores
// other factors.
func checkURIAgainstRedirects(client Client, uri string) error {
if str.Contains(client.RedirectURIs(), uri) {
if slices.Contains(client.RedirectURIs(), uri) {
return nil
}
if globClient, ok := client.(HasRedirectGlobs); ok {
@ -313,12 +319,12 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
return oidc.ErrInvalidRequestRedirectURI().WithDescription("The redirect_uri is missing in the request. " +
"Please ensure it is added to the request. If you have any questions, you may contact the administrator of the application.")
}
if strings.HasPrefix(uri, "https://") {
return checkURIAgainstRedirects(client, uri)
}
if client.ApplicationType() == ApplicationTypeNative {
return validateAuthReqRedirectURINative(client, uri)
}
if strings.HasPrefix(uri, "https://") {
return checkURIAgainstRedirects(client, uri)
}
if err := checkURIAgainstRedirects(client, uri); err != nil {
return err
}
@ -339,12 +345,15 @@ func ValidateAuthReqRedirectURI(client Client, uri string, responseType oidc.Res
// ValidateAuthReqRedirectURINative validates the passed redirect_uri and response_type to the registered uris and client type
func validateAuthReqRedirectURINative(client Client, uri string) error {
parsedURL, isLoopback := HTTPLoopbackOrLocalhost(uri)
isCustomSchema := !strings.HasPrefix(uri, "http://")
isCustomSchema := !(strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://"))
if err := checkURIAgainstRedirects(client, uri); err == nil {
if client.DevMode() {
return nil
}
// The RedirectURIs are only valid for native clients when localhost or non-"http://"
if !isLoopback && strings.HasPrefix(uri, "https://") {
return nil
}
// The RedirectURIs are only valid for native clients when localhost or non-"http://" and "https://"
if isLoopback || isCustomSchema {
return nil
}
@ -374,11 +383,11 @@ func HTTPLoopbackOrLocalhost(rawURL string) (*url.URL, bool) {
if err != nil {
return nil, false
}
if parsedURL.Scheme != "http" {
return nil, false
if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
hostName := parsedURL.Hostname()
return parsedURL, hostName == "localhost" || net.ParseIP(hostName).IsLoopback()
}
hostName := parsedURL.Hostname()
return parsedURL, hostName == "localhost" || net.ParseIP(hostName).IsLoopback()
return nil, false
}
// ValidateAuthReqResponseType validates the passed response_type to the registered response types
@ -479,12 +488,19 @@ func AuthResponseCode(w http.ResponseWriter, r *http.Request, authReq AuthReques
AuthRequestError(w, r, authReq, err, authorizer)
return
}
var sessionState string
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
if ok {
sessionState = authRequestSessionState.GetSessionState()
}
codeResponse := struct {
Code string `schema:"code"`
State string `schema:"state,omitempty"`
Code string `schema:"code"`
State string `schema:"state,omitempty"`
SessionState string `schema:"session_state,omitempty"`
}{
Code: code,
State: authReq.GetState(),
Code: code,
State: authReq.GetState(),
SessionState: sessionState,
}
if authReq.GetResponseMode() == oidc.ResponseModeFormPost {

View file

@ -433,6 +433,24 @@ func TestValidateAuthReqRedirectURI(t *testing.T) {
},
false,
},
{
"code flow registered https loopback v4 native ok",
args{
"https://127.0.0.1:4200/callback",
mock.NewClientWithConfig(t, []string{"https://127.0.0.1/callback"}, op.ApplicationTypeNative, nil, false),
oidc.ResponseTypeCode,
},
false,
},
{
"code flow registered https loopback v6 native ok",
args{
"https://[::1]:4200/callback",
mock.NewClientWithConfig(t, []string{"https://[::1]/callback"}, op.ApplicationTypeNative, nil, false),
oidc.ResponseTypeCode,
},
false,
},
{
"code flow unregistered http native fails",
args{
@ -1072,6 +1090,34 @@ func TestAuthResponseCode(t *testing.T) {
wantBody: "",
},
},
{
name: "success with state and session_state",
args: args{
authReq: &storage.AuthRequestWithSessionState{
AuthRequest: &storage.AuthRequest{
ID: "id1",
TransferState: "state1",
},
SessionState: "session_state1",
},
authorizer: func(t *testing.T) op.Authorizer {
ctrl := gomock.NewController(t)
storage := mock.NewMockStorage(ctrl)
storage.EXPECT().SaveAuthCode(gomock.Any(), "id1", "id1")
authorizer := mock.NewMockAuthorizer(ctrl)
authorizer.EXPECT().Storage().Return(storage)
authorizer.EXPECT().Crypto().Return(&mockCrypto{})
authorizer.EXPECT().Encoder().Return(schema.NewEncoder())
return authorizer
},
},
res: res{
wantCode: http.StatusFound,
wantLocationHeader: "/auth/callback/?code=id1&session_state=session_state1&state=state1",
wantBody: "",
},
},
{
name: "success without state", // reproduce issue #415
args: args{

View file

@ -30,6 +30,7 @@ type Configuration interface {
EndSessionEndpoint() *Endpoint
KeysEndpoint() *Endpoint
DeviceAuthorizationEndpoint() *Endpoint
CheckSessionIframe() *Endpoint
AuthMethodPostSupported() bool
CodeMethodS256Supported() bool

View file

@ -9,12 +9,12 @@ import (
"math/big"
"net/http"
"net/url"
"slices"
"strings"
"time"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
strs "github.com/zitadel/oidc/v3/pkg/strings"
)
type DeviceAuthorizationConfig struct {
@ -276,7 +276,7 @@ func (r *DeviceAuthorizationState) GetAMR() []string {
}
func (r *DeviceAuthorizationState) GetAudience() []string {
if !strs.Contains(r.Audience, r.ClientID) {
if !slices.Contains(r.Audience, r.ClientID) {
r.Audience = append(r.Audience, r.ClientID)
}
return r.Audience
@ -344,10 +344,11 @@ func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, c
RefreshToken: refreshToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
Scope: tokenRequest.GetScopes(),
}
// TODO(v4): remove type assertion
if idTokenRequest, ok := tokenRequest.(IDTokenRequest); ok && strs.Contains(tokenRequest.GetScopes(), oidc.ScopeOpenID) {
if idTokenRequest, ok := tokenRequest.(IDTokenRequest); ok && slices.Contains(tokenRequest.GetScopes(), oidc.ScopeOpenID) {
response.IDToken, err = CreateIDToken(ctx, IssuerFromContext(ctx), idTokenRequest, client.IDTokenLifetime(), accessToken, "", creator.Storage(), client)
if err != nil {
return nil, err

View file

@ -45,6 +45,7 @@ func CreateDiscoveryConfig(ctx context.Context, config Configuration, storage Di
EndSessionEndpoint: config.EndSessionEndpoint().Absolute(issuer),
JwksURI: config.KeysEndpoint().Absolute(issuer),
DeviceAuthorizationEndpoint: config.DeviceAuthorizationEndpoint().Absolute(issuer),
CheckSessionIframe: config.CheckSessionIframe().Absolute(issuer),
ScopesSupported: Scopes(config),
ResponseTypesSupported: ResponseTypes(config),
GrantTypesSupported: GrantTypes(config),
@ -100,7 +101,11 @@ func createDiscoveryConfigV2(ctx context.Context, config Configuration, storage
}
func Scopes(c Configuration) []string {
return DefaultSupportedScopes // TODO: config
provider, ok := c.(*Provider)
if ok && provider.config.SupportedScopes != nil {
return provider.config.SupportedScopes
}
return DefaultSupportedScopes
}
func ResponseTypes(c Configuration) []string {
@ -135,7 +140,7 @@ func GrantTypes(c Configuration) []oidc.GrantType {
}
func SubjectTypes(c Configuration) []string {
return []string{"public"} //TODO: config
return []string{"public"} // TODO: config
}
func SigAlgorithms(ctx context.Context, storage DiscoverStorage) []string {

View file

@ -81,6 +81,11 @@ func Test_scopes(t *testing.T) {
args{},
op.DefaultSupportedScopes,
},
{
"custom scopes",
args{newTestProvider(&op.Config{SupportedScopes: []string{"test1", "test2"}})},
[]string{"test1", "test2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -46,6 +46,12 @@ func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthReq
return
}
e.State = authReq.GetState()
var sessionState string
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
if ok {
sessionState = authRequestSessionState.GetSessionState()
}
e.SessionState = sessionState
var responseMode oidc.ResponseMode
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode()
@ -92,6 +98,12 @@ func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error,
}
e.State = authReq.GetState()
var sessionState string
authRequestSessionState, ok := authReq.(AuthRequestSessionState)
if ok {
sessionState = authRequestSessionState.GetSessionState()
}
e.SessionState = sessionState
var responseMode oidc.ResponseMode
if rm, ok := authReq.(interface{ GetResponseMode() oidc.ResponseMode }); ok {
responseMode = rm.GetResponseMode()

View file

@ -106,6 +106,20 @@ func (mr *MockConfigurationMockRecorder) BackChannelLogoutSupported() *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackChannelLogoutSupported", reflect.TypeOf((*MockConfiguration)(nil).BackChannelLogoutSupported))
}
// CheckSessionIframe mocks base method.
func (m *MockConfiguration) CheckSessionIframe() *op.Endpoint {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckSessionIframe")
ret0, _ := ret[0].(*op.Endpoint)
return ret0
}
// CheckSessionIframe indicates an expected call of CheckSessionIframe.
func (mr *MockConfigurationMockRecorder) CheckSessionIframe() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckSessionIframe", reflect.TypeOf((*MockConfiguration)(nil).CheckSessionIframe))
}
// CodeMethodS256Supported mocks base method.
func (m *MockConfiguration) CodeMethodS256Supported() bool {
m.ctrl.T.Helper()

View file

@ -167,6 +167,7 @@ type Config struct {
RequestObjectSupported bool
SupportedUILocales []language.Tag
SupportedClaims []string
SupportedScopes []string
DeviceAuthorization DeviceAuthorizationConfig
BackChannelLogoutSupported bool
BackChannelLogoutSessionSupported bool
@ -338,6 +339,10 @@ func (o *Provider) DeviceAuthorizationEndpoint() *Endpoint {
return o.endpoints.DeviceAuthorization
}
func (o *Provider) CheckSessionIframe() *Endpoint {
return o.endpoints.CheckSessionIframe
}
func (o *Provider) KeysEndpoint() *Endpoint {
return o.endpoints.JwksURI
}

View file

@ -232,7 +232,7 @@ func TestRoutes(t *testing.T) {
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299,"scope":"openid offline_access"}`},
},
{
// This call will fail. A successful test is already

View file

@ -145,7 +145,7 @@ func TestServerRoutes(t *testing.T) {
"assertion": jwtProfileToken,
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299}`},
contains: []string{`{"access_token":`, `"token_type":"Bearer","expires_in":299,"scope":"openid"}`},
},
{
name: "Token exchange",
@ -174,7 +174,7 @@ func TestServerRoutes(t *testing.T) {
"scope": oidc.SpaceDelimitedArray{oidc.ScopeOpenID, oidc.ScopeOfflineAccess}.String(),
},
wantCode: http.StatusOK,
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299}`},
contains: []string{`{"access_token":"`, `","token_type":"Bearer","expires_in":299,"scope":"openid offline_access"}`},
},
{
// This call will fail. A successful test is already

View file

@ -2,11 +2,11 @@ package op
import (
"context"
"slices"
"time"
"github.com/zitadel/oidc/v3/pkg/crypto"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/strings"
)
type TokenCreator interface {
@ -65,6 +65,7 @@ func CreateTokenResponse(ctx context.Context, request IDTokenRequest, client Cli
TokenType: oidc.BearerToken,
ExpiresIn: exp,
State: state,
Scope: request.GetScopes(),
}, nil
}
@ -82,13 +83,13 @@ func createTokens(ctx context.Context, tokenRequest TokenRequest, storage Storag
func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool {
switch req := tokenRequest.(type) {
case AuthRequest:
return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
return slices.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && req.GetResponseType() == oidc.ResponseTypeCode && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
case TokenExchangeRequest:
return req.GetRequestedTokenType() == oidc.RefreshTokenType
case RefreshTokenRequest:
return true
case *DeviceAuthorizationState:
return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
return slices.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
default:
return false
}

View file

@ -120,5 +120,6 @@ func CreateClientCredentialsTokenResponse(ctx context.Context, tokenRequest Toke
AccessToken: accessToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
Scope: tokenRequest.GetScopes(),
}, nil
}

View file

@ -89,6 +89,7 @@ func CreateJWTTokenResponse(ctx context.Context, tokenRequest TokenRequest, crea
AccessToken: accessToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
Scope: tokenRequest.GetScopes(),
}, nil
}

View file

@ -4,11 +4,11 @@ import (
"context"
"errors"
"net/http"
"slices"
"time"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/strings"
)
type RefreshTokenRequest interface {
@ -85,7 +85,7 @@ func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTok
return nil
}
for _, scope := range requestedScopes {
if !strings.Contains(authRequest.GetScopes(), scope) {
if !slices.Contains(authRequest.GetScopes(), scope) {
return oidc.ErrInvalidScope()
}
}

View file

@ -1,10 +1,9 @@
package strings
import "slices"
// Deprecated: Use standard library [slices.Contains] instead.
func Contains(list []string, needle string) bool {
for _, item := range list {
if item == needle {
return true
}
}
return false
// TODO(v4): remove package.
return slices.Contains(list, needle)
}