diff --git a/common/global_policy.go b/common/global_policy.go new file mode 100644 index 0000000..07e7d02 --- /dev/null +++ b/common/global_policy.go @@ -0,0 +1,19 @@ +package common + +import "time" + +type Policy struct { + // MuxManagerStartDelay is how long the MuxManager's Start() waits for the + // underlying MuxProvider to publish initial connections before returning. + // A value of 0 tells Start() to skip the wait entirely. + MuxManagerStartDelay time.Duration +} + +func (p *Policy) UpdateMuxManagerStartDelay(d time.Duration) { + p.MuxManagerStartDelay = d +} + +// GlobalPolicy is the process-wide policy singleton. +var GlobalPolicy = &Policy{ + MuxManagerStartDelay: time.Minute, +} diff --git a/config/config.go b/config/config.go index edf64c1..f741894 100644 --- a/config/config.go +++ b/config/config.go @@ -2,7 +2,9 @@ package config import ( "bytes" + "fmt" "os" + "time" "github.com/urfave/cli/v2" "gopkg.in/yaml.v3" @@ -48,8 +50,18 @@ type ( Logging LoggingConfig `yaml:"logging"` LogConfigs map[string]LoggingConfig `yaml:"logConfigs"` ClusterConnections []ClusterConnConfig `yaml:"clusterConnections"` + // MuxManagerStartDelay overrides the time the mux manager waits for + // initial connections before serving. Accepts Go duration strings + // (e.g. "30s", "1m", or "0s" to skip the wait entirely). When the + // field is omitted, the in-process default (time.Minute) is used. + MuxManagerStartDelay *Duration `yaml:"muxManagerStartDelay,omitempty"` } + // Duration is a time.Duration that unmarshals from YAML duration strings + // (e.g. "30s", "1m500ms") rather than the integer-nanoseconds form yaml.v3 + // would otherwise produce. + Duration time.Duration + SATranslationConfig struct { NamespaceMappings []SANamespaceMapping `yaml:"namespaceMappings"` cachedBiMap SearchAttributeTranslation @@ -201,6 +213,25 @@ func (c *ProfilingConfig) UnmarshalYAML(unmarshal func(interface{}) error) error return unmarshal((*plain)(c)) } +func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + if s == "" { + return nil + } + parsed, err := time.ParseDuration(s) + if err != nil { + return fmt.Errorf("invalid duration %q: %w", s, err) + } + *d = Duration(parsed) + return nil +} + +// AsDuration returns the value as a standard time.Duration. +func (d Duration) AsDuration() time.Duration { return time.Duration(d) } + func (s *SATranslationConfig) IsEnabled() bool { return len(s.NamespaceMappings) > 0 } diff --git a/config/config_test.go b/config/config_test.go index 8433e85..4c33275 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -135,3 +136,59 @@ func TestExampleChart(t *testing.T) { require.Equal(t, ConnectionType("mux-client"), cc.Remote.ConnectionType) require.Equal(t, "s2s-proxy-sample.example.tmprl.cloud:8233", cc.Remote.MuxAddressInfo.ConnectionString) } + +func TestDurationUnmarshalYAML(t *testing.T) { + type wrapper struct { + Delay Duration `yaml:"delay"` + } + cases := []struct { + name string + yamlInput string + want time.Duration + wantErr string + }{ + {name: "seconds", yamlInput: `delay: "30s"`, want: 30 * time.Second}, + {name: "minutes", yamlInput: `delay: "1m"`, want: time.Minute}, + {name: "mixed_units", yamlInput: `delay: "1m500ms"`, want: time.Minute + 500*time.Millisecond}, + {name: "explicit_zero", yamlInput: `delay: "0s"`, want: 0}, + {name: "explicit_zero_unitless", yamlInput: `delay: "0"`, want: 0}, + {name: "empty_string", yamlInput: `delay: ""`, want: 0}, + {name: "absent_field", yamlInput: `{}`, want: 0}, + {name: "invalid_format", yamlInput: `delay: "not-a-duration"`, wantErr: `invalid duration "not-a-duration"`}, + {name: "missing_unit_quoted", yamlInput: `delay: "30"`, wantErr: `invalid duration "30"`}, + // yaml.v3 coerces a bare integer scalar to a string when our UnmarshalYAML + // asks for one, so a unitless int hits the same "missing unit" path. + {name: "missing_unit_int", yamlInput: `delay: 30`, wantErr: `invalid duration "30"`}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var w wrapper + err := yaml.Unmarshal([]byte(c.yamlInput), &w) + if c.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), c.wantErr) + return + } + require.NoError(t, err) + require.Equal(t, c.want, time.Duration(w.Delay)) + }) + } +} + +func TestDurationAsDuration(t *testing.T) { + cases := []struct { + name string + in Duration + want time.Duration + }{ + {name: "zero", in: 0, want: 0}, + {name: "positive", in: Duration(2*time.Minute + 30*time.Second), want: 2*time.Minute + 30*time.Second}, + {name: "negative", in: Duration(-500 * time.Millisecond), want: -500 * time.Millisecond}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + require.Equal(t, c.want, c.in.AsDuration()) + }) + } +} diff --git a/develop/config/cluster-b-mux-server-proxy.yaml b/develop/config/cluster-b-mux-server-proxy.yaml index 6da2885..9cb2f68 100644 --- a/develop/config/cluster-b-mux-server-proxy.yaml +++ b/develop/config/cluster-b-mux-server-proxy.yaml @@ -1,3 +1,4 @@ +muxManagerStartDelay: "0s" clusterConnections: - name: "b-inbound-server/b-outbound-server" local: diff --git a/proxy/cluster_connection_test.go b/proxy/cluster_connection_test.go index 8eeb640..0fcf407 100644 --- a/proxy/cluster_connection_test.go +++ b/proxy/cluster_connection_test.go @@ -14,12 +14,12 @@ import ( "go.temporal.io/server/common/log/tag" "google.golang.org/grpc" + "github.com/temporalio/s2s-proxy/common" "github.com/temporalio/s2s-proxy/config" "github.com/temporalio/s2s-proxy/endtoendtest/testservices" "github.com/temporalio/s2s-proxy/logging" "github.com/temporalio/s2s-proxy/metrics" "github.com/temporalio/s2s-proxy/transport/grpcutil" - "github.com/temporalio/s2s-proxy/transport/mux" ) const ( @@ -29,7 +29,7 @@ const ( func init() { _ = os.Setenv("TEMPORAL_TEST_LOG_LEVEL", "error") - mux.MuxManagerStartDelay = 0 + common.GlobalPolicy.MuxManagerStartDelay = 0 } func getDynamicPorts(t *testing.T, num int) []string { diff --git a/proxy/proxy.go b/proxy/proxy.go index af9e736..95af0c7 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -10,6 +10,7 @@ import ( "go.temporal.io/server/common/log/tag" + "github.com/temporalio/s2s-proxy/common" "github.com/temporalio/s2s-proxy/config" "github.com/temporalio/s2s-proxy/logging" "github.com/temporalio/s2s-proxy/metrics" @@ -39,6 +40,9 @@ type ( func NewProxy(configProvider config.ConfigProvider, logProvider logging.LoggerProvider) *Proxy { s2sConfig := configProvider.GetS2SProxyConfig() + if s2sConfig.MuxManagerStartDelay != nil { + common.GlobalPolicy.UpdateMuxManagerStartDelay(s2sConfig.MuxManagerStartDelay.AsDuration()) + } ctx, cancel := context.WithCancel(context.Background()) proxy := &Proxy{ lifetime: ctx, diff --git a/proxy/test/echo_proxy_test.go b/proxy/test/echo_proxy_test.go index e52a5e2..8fd7ea0 100644 --- a/proxy/test/echo_proxy_test.go +++ b/proxy/test/echo_proxy_test.go @@ -14,16 +14,16 @@ import ( "go.temporal.io/server/client/history" "go.temporal.io/server/common/log" + "github.com/temporalio/s2s-proxy/common" "github.com/temporalio/s2s-proxy/config" "github.com/temporalio/s2s-proxy/encryption" "github.com/temporalio/s2s-proxy/endtoendtest" - "github.com/temporalio/s2s-proxy/transport/mux" ) func init() { // silence info log spam _ = os.Setenv("TEMPORAL_TEST_LOG_LEVEL", "error") - mux.MuxManagerStartDelay = 0 + common.GlobalPolicy.MuxManagerStartDelay = 0 } var ( diff --git a/transport/mux/grpc_mux_test.go b/transport/mux/grpc_mux_test.go index 085467b..3c1a538 100644 --- a/transport/mux/grpc_mux_test.go +++ b/transport/mux/grpc_mux_test.go @@ -20,7 +20,7 @@ import ( ) func init() { - MuxManagerStartDelay = 0 + common.GlobalPolicy.MuxManagerStartDelay = 0 } func TestGRPCMux(t *testing.T) { diff --git a/transport/mux/multi_mux_manager.go b/transport/mux/multi_mux_manager.go index 8c9312e..01c264e 100644 --- a/transport/mux/multi_mux_manager.go +++ b/transport/mux/multi_mux_manager.go @@ -14,12 +14,11 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + "github.com/temporalio/s2s-proxy/common" "github.com/temporalio/s2s-proxy/metrics" "github.com/temporalio/s2s-proxy/transport/mux/session" ) -var MuxManagerStartDelay = time.Minute - type ( multiMuxManager struct { lifetime context.Context @@ -172,8 +171,11 @@ func (m *multiMuxManager) Start() { tag.Name(m.name), tag.NewStringTag("sessions", sb.String())) } }() - // Allow the mux provider some time to provide connections - <-time.After(MuxManagerStartDelay) + // Allow the mux provider some time to provide connections. + // A MuxManagerStartDelay of 0 means skip the wait entirely. + if common.GlobalPolicy.MuxManagerStartDelay > 0 { + <-time.After(common.GlobalPolicy.MuxManagerStartDelay) + } }) } func (m *multiMuxManager) Describe() string { diff --git a/transport/mux/multi_mux_manager_test.go b/transport/mux/multi_mux_manager_test.go index 391ceac..8b79e6a 100644 --- a/transport/mux/multi_mux_manager_test.go +++ b/transport/mux/multi_mux_manager_test.go @@ -11,13 +11,14 @@ import ( "github.com/stretchr/testify/require" "go.temporal.io/server/common/log" + "github.com/temporalio/s2s-proxy/common" "github.com/temporalio/s2s-proxy/endtoendtest/proxyassert" "github.com/temporalio/s2s-proxy/transport/mux/session" ) func init() { _ = os.Setenv("TEMPORAL_TEST_LOG_LEVEL", "error") - MuxManagerStartDelay = 0 + common.GlobalPolicy.MuxManagerStartDelay = 0 } func TestMultiMuxManager(t *testing.T) {