Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions tests/python/test_final_model_selection.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -130,7 +130,8 @@ def test_final_model_selection_best_validation_ci_replicated(scorer, class_weigh
}
loss_f = loss_f_dict[est.parameters_.scorer]

def eval(individual, sample=None, log=False):
def eval_with_sklearn(individual, sample=None, log=False):

if sample is None:
sample = np.arange(len(data.y))

Expand All @@ -141,14 +142,19 @@ def eval(individual, sample=None, log=False):

if est.parameters_.scorer in ["log", "average_precision_score"]:
y_pred = np.array(individual.predict_proba(data)).astype(float)

if est.parameters_.scorer == "log":
eps = 0.001
y_pred = np.clip(y_pred, eps, 1-eps)
else:
y_pred = np.array(individual.predict(data)).astype(float)

if log: # silencing eval() during bootstrap, but enabling detailed info when re-calculating losses and comparing with brush's metrics

if log: # detailed output is printed only when log=True: the bootstrap loop calls eval_with_sklearn() with the default log=False, while the loss re-calculations that are compared with brush's metrics pass log=True
print('evaluating', individual.program.get_model())
print(np.round(y, 2))
print(np.round(y_pred, 2))
print('(sorted)', np.sort(y_pred))
print('rounded y', np.round(y, 2))
print('rounded preds', np.round(y_pred, 2))
print('sorted preds', np.sort(y_pred))

if est.class_weights not in ['unbalanced'] and est.parameters_.scorer not in ['balanced_accuracy']:
sample_weight = None
Expand Down Expand Up @@ -181,35 +187,39 @@ def eval(individual, sample=None, log=False):
print('(eval) sample weights', sample_weight)
print('(eval) loss', loss_f(y[sample], y_pred[sample], sample_weight=sample_weight[sample]))

return loss_f(y[sample], y_pred[sample], sample_weight=sample_weight[sample])
calculated_loss = loss_f(y[sample], y_pred[sample], sample_weight=sample_weight[sample])
print('calculated loss:', calculated_loss)
return calculated_loss
else: # Cases where we ignore weights
if log:
print('(eval) using no class weights')
print('(eval) sample weights not defined. using unbalanced version')
print('(eval) loss', loss_f(y[sample], y_pred[sample]))

return loss_f(y[sample], y_pred[sample])
calculated_loss = loss_f(y[sample], y_pred[sample])
print('calculated loss:', calculated_loss)
return calculated_loss

# Bootstrap validation samples
print("scorer and class weights;", scorer, class_weights)
print("original loss", est.best_estimator_.fitness.loss)
print("original loss_v", est.best_estimator_.fitness.loss_v)
print("recalculated loss", eval(est.best_estimator_))
print("recalculated loss", eval_with_sklearn(est.best_estimator_, log=True))

np.random.seed(0)
val_samples = [eval(est.best_estimator_, np.random.randint(len(y), size=len(y)))
val_samples = [eval_with_sklearn(est.best_estimator_, np.random.randint(len(y), size=len(y)))
for _ in range(100)]

lower_ci, upper_ci = np.quantile(val_samples, 0.05), np.quantile(val_samples, 0.95)
print(f"CI bounds: {lower_ci:.4f}, {upper_ci:.4f}")

# Evaluate all archive members
new_losses = [eval(ind, log=True) for ind in est.archive_]
new_losses = [eval_with_sklearn(ind, log=True) for ind in est.archive_]
candidates = [(l, p) for l, p in zip(new_losses, est.archive_) if lower_ci <= l <= upper_ci]

print('first arch ind', est.archive_[0].get_model())
print("Original losses from archive (brush's auprc) ", [ind.fitness.loss for ind in est.archive_])
print("Original losses_v from archive (brush's auprc) ", [ind.fitness.loss_v for ind in est.archive_])
print("Original losses from archive (brush's metric) ", [ind.fitness.loss for ind in est.archive_])
print("Original losses_v from archive (brush's metric) ", [ind.fitness.loss_v for ind in est.archive_])
print("Recalculated losses with sklearn (should match)", new_losses)
print(f"Num candidates in CI: {len(candidates)}")

Expand Down
Loading