diff --git a/common/rpc/encryption/fixedTLSConfigProvider.go b/common/rpc/encryption/fixedTLSConfigProvider.go index adeb01a55a..1d6819ff1c 100644 --- a/common/rpc/encryption/fixedTLSConfigProvider.go +++ b/common/rpc/encryption/fixedTLSConfigProvider.go @@ -41,7 +41,11 @@ func (f *FixedTLSConfigProvider) GetFrontendClientConfig() (*tls.Config, error) // GetRemoteClusterClientConfig implements [TLSConfigProvider.GetRemoteClusterClientConfig]. func (f *FixedTLSConfigProvider) GetRemoteClusterClientConfig(hostname string) (*tls.Config, error) { - return f.RemoteClusterClientConfigs[hostname], nil + if cfg, ok := f.RemoteClusterClientConfigs[hostname]; ok { + return cfg, nil + } + // Fall back to default config + return f.RemoteClusterClientConfigs[defaultRemoteCluster], nil } // GetExpiringCerts implements [TLSConfigProvider.GetExpiringCerts]. diff --git a/common/rpc/encryption/local_store_tls_provider.go b/common/rpc/encryption/local_store_tls_provider.go index dab35c79fa..65ff08ee05 100644 --- a/common/rpc/encryption/local_store_tls_provider.go +++ b/common/rpc/encryption/local_store_tls_provider.go @@ -48,6 +48,8 @@ type localStoreTlsProvider struct { var _ TLSConfigProvider = (*localStoreTlsProvider)(nil) var _ CertExpirationChecker = (*localStoreTlsProvider)(nil) +const defaultRemoteCluster = "default" + func NewLocalStoreTlsProvider(tlsConfig *config.RootTLS, metricsHandler metrics.Handler, logger log.Logger, certProviderFactory CertProviderFactory, ) (TLSConfigProvider, error) { @@ -139,16 +141,22 @@ func (s *localStoreTlsProvider) GetFrontendClientConfig() (*tls.Config, error) { } func (s *localStoreTlsProvider) GetRemoteClusterClientConfig(hostname string) (*tls.Config, error) { + certProviderKey := hostname groupTLS, ok := s.settings.RemoteClusters[hostname] if !ok { - return nil, nil + // Fall back to default/wildcard config if present + groupTLS, ok = s.settings.RemoteClusters[defaultRemoteCluster] + if !ok { + return nil, nil + } + certProviderKey = defaultRemoteCluster } return s.getOrCreateRemoteClusterClientConfig( hostname, func() (*tls.Config, error) { return newClientTLSConfig( - s.remoteClusterClientCertProvider[hostname], + s.remoteClusterClientCertProvider[certProviderKey], groupTLS.Client.ServerName, groupTLS.Server.RequireClientAuth, false, diff --git a/common/rpc/encryption/tls_config_test.go b/common/rpc/encryption/tls_config_test.go index 12d728be96..f2924781a1 100644 --- a/common/rpc/encryption/tls_config_test.go +++ b/common/rpc/encryption/tls_config_test.go @@ -1,11 +1,16 @@ package encryption import ( + "crypto/tls" + "crypto/x509" "testing" + "time" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/metrics" ) type ( @@ -218,3 +223,101 @@ func (s *tlsConfigTest) TestSystemWorkerTLSConfig() { client.RootCAData = []string{""} s.Error(validateRootTLS(cfg)) } + +// stubCertProvider is a no-op CertProvider for use in unit tests. +type stubCertProvider struct{} + +func (s *stubCertProvider) FetchServerCertificate() (*tls.Certificate, error) { return nil, nil } +func (s *stubCertProvider) FetchClientCAs() (*x509.CertPool, error) { return nil, nil } +func (s *stubCertProvider) FetchClientCertificate(_ bool) (*tls.Certificate, error) { + return nil, nil +} +func (s *stubCertProvider) FetchServerRootCAsForClient(_ bool) (*x509.CertPool, error) { + return nil, nil +} +func (s *stubCertProvider) GetExpiringCerts(_ time.Duration) (expiring CertExpirationMap, expired CertExpirationMap, err error) { + return nil, nil, nil +} + +func stubCertProviderFactory(_ *config.GroupTLS, _ *config.WorkerTLS, _ *config.ClientTLS, _ time.Duration, _ log.Logger) CertProvider { + return &stubCertProvider{} +} + +func newTestTLSProvider(t *testing.T, cfg config.RootTLS) TLSConfigProvider { + t.Helper() + provider, err := NewLocalStoreTlsProvider(&cfg, metrics.NoopMetricsHandler, log.NewTestLogger(), stubCertProviderFactory) + require.NoError(t, err) + return provider +} + +func TestGetRemoteClusterClientConfig_NoConfig(t *testing.T) { + provider := newTestTLSProvider(t, config.RootTLS{}) + tlsCfg, err := provider.GetRemoteClusterClientConfig("some-host") + require.NoError(t, err) + require.Nil(t, tlsCfg) +} + +func TestGetRemoteClusterClientConfig_UnknownHostNoDefault(t *testing.T) { + cfg := config.RootTLS{ + RemoteClusters: map[string]config.GroupTLS{ + "cluster-a.example.com": {Client: config.ClientTLS{ForceTLS: true}}, + }, + } + provider := newTestTLSProvider(t, cfg) + + tlsCfg, err := provider.GetRemoteClusterClientConfig("unknown-host.example.com") + require.NoError(t, err) + require.Nil(t, tlsCfg) +} + +func TestGetRemoteClusterClientConfig_ExactMatch(t *testing.T) { + cfg := config.RootTLS{ + RemoteClusters: map[string]config.GroupTLS{ + "cluster-a.example.com": {Client: config.ClientTLS{ForceTLS: true}}, + }, + } + provider := newTestTLSProvider(t, cfg) + + tlsCfg, err := provider.GetRemoteClusterClientConfig("cluster-a.example.com") + require.NoError(t, err) + require.NotNil(t, tlsCfg) + + // Unknown host with no default → nil + tlsCfg, err = provider.GetRemoteClusterClientConfig("cluster-b.example.com") + require.NoError(t, err) + require.Nil(t, tlsCfg) +} + +func TestGetRemoteClusterClientConfig_DefaultFallback(t *testing.T) { + cfg := config.RootTLS{ + RemoteClusters: map[string]config.GroupTLS{ + defaultRemoteCluster: {Client: config.ClientTLS{ForceTLS: true}}, + }, + } + provider := newTestTLSProvider(t, cfg) + + tlsCfg, err := provider.GetRemoteClusterClientConfig("any-unknown-host") + require.NoError(t, err) + require.NotNil(t, tlsCfg) +} + +func TestGetRemoteClusterClientConfig_ExactMatchTakesPriority(t *testing.T) { + cfg := config.RootTLS{ + RemoteClusters: map[string]config.GroupTLS{ + "cluster-a.example.com": {Client: config.ClientTLS{ForceTLS: false}}, + // Default has ForceTLS: false so IsClientEnabled() returns false → nil config + defaultRemoteCluster: {Client: config.ClientTLS{ForceTLS: true}}, + }, + } + provider := newTestTLSProvider(t, cfg) + + // Exact match → nil (ForceTLS: false) + tlsCfg, err := provider.GetRemoteClusterClientConfig("cluster-a.example.com") + require.NoError(t, err) + require.Nil(t, tlsCfg) + + // Unknown host falls back to default (ForceTLS: true) → non-nil + tlsCfg, err = provider.GetRemoteClusterClientConfig("unknown-host") + require.NoError(t, err) + require.NotNil(t, tlsCfg) +}