Skip to content

Commit 2f697f8

Browse files
authored
fix: gphedge duplicates (#613)
* red test * fix
1 parent 5b8fb73 commit 2f697f8

2 files changed

Lines changed: 39 additions & 0 deletions

File tree

bayes_opt/acquisition.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,15 @@ def suggest(
13161316
]
13171317
self.previous_candidates = np.array(x_max)
13181318
idx = self._sample_idx_from_softmax_gains(random_state=random_state)
1319+
if not target_space._allow_duplicate_points and x_max[idx] in target_space:
1320+
# If every candidate is a duplicate, keep the original choice.
1321+
# Avoiding duplicates then requires generating new candidates, which
1322+
# is outside GPHedge's candidate-selection responsibility.
1323+
non_duplicate_idx = [idx for idx, x in enumerate(x_max) if x not in target_space]
1324+
if len(non_duplicate_idx) > 0:
1325+
cumsum_softmax_g = np.cumsum(softmax(self.gains[non_duplicate_idx]))
1326+
r = random_state.rand()
1327+
idx = non_duplicate_idx[np.argmax(r <= cumsum_softmax_g)]
13191328
return x_max[idx]
13201329

13211330
def get_acquisition_params(self) -> dict[str, Any]:

tests/test_acquisition.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,23 @@ def constrained_target_space(target_func):
6969
)
7070

7171

72+
class FixedSuggestion(acquisition.AcquisitionFunction):
73+
def __init__(self, suggestion):
74+
super().__init__()
75+
self.suggestion = np.array(suggestion)
76+
77+
def suggest(self, *args, **kwargs):
78+
return self.suggestion
79+
80+
def base_acq(self, mean, std):
81+
return mean
82+
83+
def get_acquisition_params(self):
84+
return {}
85+
86+
def set_acquisition_params(self, params): ...
87+
88+
7289
class MockAcquisition(acquisition.AcquisitionFunction):
7390
def __init__(self):
7491
super().__init__()
@@ -355,6 +372,19 @@ def predict(self, x):
355372
assert good_index == acq._sample_idx_from_softmax_gains(random_state=random_state)
356373

357374

375+
def test_gphedge_skips_duplicate_candidate_when_unique_candidate_exists(gp, target_space, random_state):
376+
duplicate = np.array([2.5, 0.5])
377+
unique = np.array([3.0, 1.0])
378+
target_space.register(duplicate, target=sum(duplicate))
379+
380+
acq = acquisition.GPHedge(base_acquisitions=[FixedSuggestion(duplicate), FixedSuggestion(unique)])
381+
acq.gains = np.array([100.0, 0.0])
382+
383+
suggestion = acq.suggest(gp=gp, target_space=target_space, fit_gp=False, random_state=random_state)
384+
385+
np.testing.assert_array_equal(suggestion, unique)
386+
387+
358388
def test_gphedge_integration(gp, target_space, random_state):
359389
base_acq1 = acquisition.UpperConfidenceBound()
360390
base_acq2 = acquisition.ProbabilityOfImprovement(xi=0.01)

0 commit comments

Comments
 (0)