diff --git a/cluster/cluster/adaptivesvc/cluster_invoker.go b/cluster/cluster/adaptivesvc/cluster_invoker.go index ef960e0291..4ea8e0e163 100644 --- a/cluster/cluster/adaptivesvc/cluster_invoker.go +++ b/cluster/cluster/adaptivesvc/cluster_invoker.go @@ -40,7 +40,7 @@ import ( ) type adaptiveServiceClusterInvoker struct { - base.BaseClusterInvoker + *base.BaseClusterInvoker } func newAdaptiveServiceClusterInvoker(directory directory.Directory) protocolbase.Invoker { diff --git a/cluster/cluster/available/cluster_invoker.go b/cluster/cluster/available/cluster_invoker.go index d645977ec3..d8840a8e45 100644 --- a/cluster/cluster/available/cluster_invoker.go +++ b/cluster/cluster/available/cluster_invoker.go @@ -34,7 +34,7 @@ import ( ) type availableClusterInvoker struct { - base.BaseClusterInvoker + *base.BaseClusterInvoker } // NewClusterInvoker returns a availableCluster invoker instance diff --git a/cluster/cluster/base/cluster_invoker.go b/cluster/cluster/base/cluster_invoker.go index 80f0ca89e4..1492100af0 100644 --- a/cluster/cluster/base/cluster_invoker.go +++ b/cluster/cluster/base/cluster_invoker.go @@ -18,12 +18,16 @@ // Package base implements invoker for the manipulation of cluster strategy. package base +import ( + "sync/atomic" +) + import ( "github.com/dubbogo/gost/log/logger" perrors "github.com/pkg/errors" - "go.uber.org/atomic" + uberatomic "go.uber.org/atomic" ) import ( @@ -38,15 +42,15 @@ import ( type BaseClusterInvoker struct { Directory directory.Directory AvailableCheck bool - Destroyed *atomic.Bool - StickyInvoker base.Invoker + Destroyed *uberatomic.Bool + StickyInvoker atomic.Pointer[base.Invoker] } -func NewBaseClusterInvoker(directory directory.Directory) BaseClusterInvoker { - return BaseClusterInvoker{ +func NewBaseClusterInvoker(directory directory.Directory) *BaseClusterInvoker { + return &BaseClusterInvoker{ Directory: directory, AvailableCheck: true, - Destroyed: atomic.NewBool(false), + Destroyed: uberatomic.NewBool(false), } } @@ -62,8 +66,8 @@ func (invoker *BaseClusterInvoker) Destroy() { } func (invoker *BaseClusterInvoker) IsAvailable() bool { - if invoker.StickyInvoker != nil { - return invoker.StickyInvoker.IsAvailable() + if sticky := invoker.StickyInvoker.Load(); sticky != nil { + return (*sticky).IsAvailable() } return invoker.Directory.IsAvailable() } @@ -100,19 +104,21 @@ func (invoker *BaseClusterInvoker) DoSelect(lb loadbalance.LoadBalance, invocati // Get the service method sticky config if have sticky = url.GetMethodParamBool(invocation.MethodName(), constant.StickyKey, sticky) - if invoker.StickyInvoker != nil && !isInvoked(invoker.StickyInvoker, invokers) { - invoker.StickyInvoker = nil + stickyInvoker := invoker.StickyInvoker.Load() + if stickyInvoker != nil && !isInvoked(*stickyInvoker, invokers) { + invoker.StickyInvoker.Store(nil) + stickyInvoker = nil } if sticky && invoker.AvailableCheck && - invoker.StickyInvoker != nil && invoker.StickyInvoker.IsAvailable() && - (invoked == nil || !isInvoked(invoker.StickyInvoker, invoked)) { - return invoker.StickyInvoker + stickyInvoker != nil && (*stickyInvoker).IsAvailable() && + (invoked == nil || !isInvoked(*stickyInvoker, invoked)) { + return *stickyInvoker } selectedInvoker = invoker.doSelectInvoker(lb, invocation, invokers, invoked) if sticky { - invoker.StickyInvoker = selectedInvoker + invoker.StickyInvoker.Store(&selectedInvoker) } return selectedInvoker } diff --git a/cluster/cluster/base/cluster_invoker_test.go b/cluster/cluster/base/cluster_invoker_test.go index 454ae7df7a..a4399a3480 100644 --- a/cluster/cluster/base/cluster_invoker_test.go +++ b/cluster/cluster/base/cluster_invoker_test.go @@ -19,6 +19,7 @@ package base import ( "fmt" + "sync" "testing" ) @@ -73,3 +74,89 @@ func TestStickyNormalWhenError(t *testing.T) { result1 := base.DoSelect(random.NewRandomLoadBalance(), invocation.NewRPCInvocation(baseClusterInvokerMethodName, nil, nil), invokers, invoked) assert.NotEqual(t, result, result1) } + +// TestStickyConcurrentDoSelect verifies that concurrent calls to DoSelect +// with sticky enabled do not cause a data race on StickyInvoker. +func TestStickyConcurrentDoSelect(t *testing.T) { + var invokers []protocolbase.Invoker + for i := 0; i < 10; i++ { + url, _ := common.NewURL(fmt.Sprintf(baseClusterInvokerFormat, i)) + url.SetParam("sticky", "true") + invokers = append(invokers, clusterpkg.NewMockInvoker(url, 1)) + } + base := &BaseClusterInvoker{} + base.AvailableCheck = true + + lb := random.NewRandomLoadBalance() + invocation1 := invocation.NewRPCInvocation(baseClusterInvokerMethodName, nil, nil) + + const concurrency = 100 + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + invoked := make([]protocolbase.Invoker, 0) + result := base.DoSelect(lb, invocation1, invokers, invoked) + assert.NotNil(t, result) + }() + } + wg.Wait() +} + +// TestStickyConcurrentIsAvailableAndDoSelect verifies that concurrent +// IsAvailable and DoSelect calls do not cause a data race on StickyInvoker. +func TestStickyConcurrentIsAvailableAndDoSelect(t *testing.T) { + var invokers []protocolbase.Invoker + for i := 0; i < 10; i++ { + url, _ := common.NewURL(fmt.Sprintf(baseClusterInvokerFormat, i)) + url.SetParam("sticky", "true") + invokers = append(invokers, clusterpkg.NewMockInvoker(url, 1)) + } + + // Use NewBaseClusterInvoker so that Directory is initialized, + // allowing IsAvailable() to work without panicking. + dir := newMockDirectory(invokers) + base := NewBaseClusterInvoker(dir) + base.AvailableCheck = true + + lb := random.NewRandomLoadBalance() + invocation1 := invocation.NewRPCInvocation(baseClusterInvokerMethodName, nil, nil) + + // First DoSelect to set the sticky invoker so IsAvailable uses the sticky path + invoked := make([]protocolbase.Invoker, 0) + base.DoSelect(lb, invocation1, invokers, invoked) + + const concurrency = 100 + var wg sync.WaitGroup + wg.Add(concurrency * 2) + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + base.IsAvailable() + }() + go func() { + defer wg.Done() + base.DoSelect(lb, invocation1, invokers, invoked) + }() + } + wg.Wait() +} + +// mockDirectory is a minimal directory.Directory implementation for testing. +type mockDirectory struct { + invokers []protocolbase.Invoker + url *common.URL +} + +func newMockDirectory(invokers []protocolbase.Invoker) *mockDirectory { + url, _ := common.NewURL(baseClusterInvokerFormat) + url.SetParam("sticky", "true") + return &mockDirectory{invokers: invokers, url: url} +} + +func (d *mockDirectory) GetURL() *common.URL { return d.url } +func (d *mockDirectory) IsAvailable() bool { return true } +func (d *mockDirectory) Destroy() {} +func (d *mockDirectory) List(protocolbase.Invocation) []protocolbase.Invoker { return d.invokers } +func (d *mockDirectory) Subscribe(*common.URL) error { return nil } diff --git a/cluster/cluster/broadcast/cluster_invoker.go b/cluster/cluster/broadcast/cluster_invoker.go index bace1e63e3..d938cad198 100644 --- a/cluster/cluster/broadcast/cluster_invoker.go +++ b/cluster/cluster/broadcast/cluster_invoker.go @@ -33,7 +33,7 @@ import ( ) type broadcastClusterInvoker struct { - base.BaseClusterInvoker + *base.BaseClusterInvoker } func newBroadcastClusterInvoker(directory directory.Directory) protocolbase.Invoker { diff --git a/cluster/cluster/failback/cluster_invoker.go b/cluster/cluster/failback/cluster_invoker.go index 3509aad838..2b89ccc525 100644 --- a/cluster/cluster/failback/cluster_invoker.go +++ b/cluster/cluster/failback/cluster_invoker.go @@ -49,7 +49,7 @@ import ( * Failback */ type failbackClusterInvoker struct { - base.BaseClusterInvoker + *base.BaseClusterInvoker once sync.Once ticker *time.Ticker diff --git a/cluster/cluster/failfast/cluster_invoker.go b/cluster/cluster/failfast/cluster_invoker.go index 4d04695ecf..89f4b4c083 100644 --- a/cluster/cluster/failfast/cluster_invoker.go +++ b/cluster/cluster/failfast/cluster_invoker.go @@ -29,7 +29,7 @@ import ( ) type failfastClusterInvoker struct { - base.BaseClusterInvoker + *base.BaseClusterInvoker } func newFailfastClusterInvoker(directory directory.Directory) protocolbase.Invoker { diff --git a/cluster/cluster/failover/cluster_invoker.go b/cluster/cluster/failover/cluster_invoker.go index 946c4e8f59..6799176049 100644 --- a/cluster/cluster/failover/cluster_invoker.go +++ b/cluster/cluster/failover/cluster_invoker.go @@ -39,7 +39,7 @@ import ( ) type failoverClusterInvoker struct { - base.BaseClusterInvoker + *base.BaseClusterInvoker } func newFailoverClusterInvoker(directory directory.Directory) protocolbase.Invoker { diff --git a/cluster/cluster/failsafe/cluster_invoker.go b/cluster/cluster/failsafe/cluster_invoker.go index 537690726f..71463afb64 100644 --- a/cluster/cluster/failsafe/cluster_invoker.go +++ b/cluster/cluster/failsafe/cluster_invoker.go @@ -42,7 +42,7 @@ import ( * */ type failsafeClusterInvoker struct { - base.BaseClusterInvoker + *base.BaseClusterInvoker } func newFailsafeClusterInvoker(directory directory.Directory) protocolbase.Invoker { diff --git a/cluster/cluster/forking/cluster_invoker.go b/cluster/cluster/forking/cluster_invoker.go index 326a89f682..4c5f98e9a2 100644 --- a/cluster/cluster/forking/cluster_invoker.go +++ b/cluster/cluster/forking/cluster_invoker.go @@ -38,7 +38,7 @@ import ( ) type forkingClusterInvoker struct { - base.BaseClusterInvoker + *base.BaseClusterInvoker } func newForkingClusterInvoker(directory directory.Directory) protocolbase.Invoker { diff --git a/cluster/cluster/zoneaware/cluster_invoker.go b/cluster/cluster/zoneaware/cluster_invoker.go index 3086c39383..3da8bb57b9 100644 --- a/cluster/cluster/zoneaware/cluster_invoker.go +++ b/cluster/cluster/zoneaware/cluster_invoker.go @@ -38,7 +38,7 @@ import ( // 3. Evenly balance traffic between all registries based on each registry's weight. // 4. Pick anyone that's available. type zoneawareClusterInvoker struct { - base.BaseClusterInvoker + *base.BaseClusterInvoker } func newZoneawareClusterInvoker(directory directory.Directory) protocolbase.Invoker {