diff --git a/random_weighted.go b/random_weighted.go index 61ee717..d5802a8 100644 --- a/random_weighted.go +++ b/random_weighted.go @@ -62,6 +62,8 @@ func (rw *RandW) All() map[interface{}]int { func (rw *RandW) RemoveAll() { rw.items = make([]*randWeighted, 0) rw.r = rand.New(rand.NewSource(time.Now().Unix())) + rw.sumOfWeights = 0 + rw.n = 0 } // Reset resets the balancing algorithm. diff --git a/random_weighted_test.go b/random_weighted_test.go index 5b4d51a..d161b08 100644 --- a/random_weighted_test.go +++ b/random_weighted_test.go @@ -31,6 +31,19 @@ func TestRandW_Next(t *testing.T) { t.Error("the algorithm is wrong", results) } + all := w.All() + countOK := 0 + for index := range all { + if (index == "server1" && all[index] == 5) || + (index == "server2" && all[index] == 2) || + (index == "server3" && all[index] == 3) { + countOK++ + } + } + if countOK != 3 { + t.Error("the algorithm is wrong") + } + w.RemoveAll() w.Add("server1", 7) w.Add("server2", 9) @@ -48,6 +61,17 @@ func TestRandW_Next(t *testing.T) { // } t.Log("the results: ", results) + + w.RemoveAll() + next := w.Next() + if next != nil { + t.Error("the algorithm is wrong") + } + w.Add("server1", 3) + next = w.Next() + if next == nil { + t.Error("the algorithm is wrong") + } } func checkResults(v, min, max int) bool { diff --git a/roundrobin_weighted.go b/roundrobin_weighted.go index 4213acf..498658c 100644 --- a/roundrobin_weighted.go +++ b/roundrobin_weighted.go @@ -93,14 +93,8 @@ func (w *RRW) Next() interface{} { } func gcd(x, y int) int { - var t int - for { - t = (x % y) - if t > 0 { - x = y - y = t - } else { - return y - } + for y != 0 { + x, y = y, x%y } + return x } diff --git a/roundrobin_weighted_test.go b/roundrobin_weighted_test.go index 0f156f4..cecc952 100644 --- a/roundrobin_weighted_test.go +++ b/roundrobin_weighted_test.go @@ -31,6 +31,19 @@ func TestRRW_Next(t *testing.T) { t.Error("the algorithm is wrong", results) } + all := w.All() + countOK := 0 + for index := range all { + if (index == "server1" && all[index] == 5) || + (index == "server2" && all[index] == 2) || + (index == "server3" && all[index] == 3) { + countOK++ + } + } + if countOK != 3 { + t.Error("the algorithm is wrong") + } + w.RemoveAll() w.Add("server1", 7) w.Add("server2", 9) @@ -46,4 +59,35 @@ func TestRRW_Next(t *testing.T) { if results["server1"] != 7000 || results["server2"] != 9000 || results["server3"] != 13000 { t.Error("the algorithm is wrong", results) } + + w.RemoveAll() + next := w.Next() + if next != nil { + t.Error("the algorithm is wrong") + } + w.Add("server1", 3) + next = w.Next() + if next == nil { + t.Error("the algorithm is wrong") + } +} + +func TestGCB(t *testing.T) { + tests := []struct { + name string + args [2]int + want int + }{ + {"0,0", [2]int{0, 0}, 0}, + {"1997,615", [2]int{1997, 6150}, 1}, + {"481,221", [2]int{481, 221}, 13}, + {"12,18", [2]int{12, 18}, 6}, + } + for _, tt := range tests { + n1 := tt.args[0] + n2 := tt.args[1] + if got := gcd(n1, n2); got != tt.want { + t.Errorf("gcb(%v, %v) = %v ; want = %v", n1, n2, got, tt.want) + } + } } diff --git a/smooth_weighted.go b/smooth_weighted.go index 0633c46..cbb8bcb 100644 --- a/smooth_weighted.go +++ b/smooth_weighted.go @@ -66,23 +66,14 @@ func (w *SW) All() map[interface{}]int { // Next returns next selected server. func (w *SW) Next() interface{} { - i := w.nextWeighted() - if i == nil { + switch w.n { + case 0: return nil + case 1: + return w.items[0].Item + default: + return nextSmoothWeighted(w.items).Item } - return i.Item -} - -// nextWeighted returns next selected weighted object. -func (w *SW) nextWeighted() *smoothWeighted { - if w.n == 0 { - return nil - } - if w.n == 1 { - return w.items[0] - } - - return nextSmoothWeighted(w.items) } //https://github.com/phusion/nginx/commit/27e94984486058d73157038f7950a0a36ecc6e35 @@ -92,10 +83,6 @@ func nextSmoothWeighted(items []*smoothWeighted) (best *smoothWeighted) { for i := 0; i < len(items); i++ { w := items[i] - if w == nil { - continue - } - w.CurrentWeight += w.EffectiveWeight total += w.EffectiveWeight if w.EffectiveWeight < w.Weight { diff --git a/smooth_weighted_test.go b/smooth_weighted_test.go index ee288e8..2f9d2c0 100644 --- a/smooth_weighted_test.go +++ b/smooth_weighted_test.go @@ -31,6 +31,19 @@ func TestSW_Next(t *testing.T) { t.Error("the algorithm is wrong") } + all := w.All() + countOK := 0 + for index := range all { + if (index == "server1" && all[index] == 5) || + (index == "server2" && all[index] == 2) || + (index == "server3" && all[index] == 3) { + countOK++ + } + } + if countOK != 3 { + t.Error("the algorithm is wrong") + } + w.RemoveAll() w.Add("server1", 7) w.Add("server2", 9) @@ -46,4 +59,15 @@ func TestSW_Next(t *testing.T) { if results["server1"] != 7000 || results["server2"] != 9000 || results["server3"] != 13000 { t.Error("the algorithm is wrong") } + + w.RemoveAll() + next := w.Next() + if next != nil { + t.Error("the algorithm is wrong") + } + w.Add("server1", 3) + next = w.Next() + if next == nil { + t.Error("the algorithm is wrong") + } }