Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tutorials/machine_learning/ml_dataloader_resampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
### \file
### \ingroup tutorial_ml
### \notebook -nodraw
### Example of resampling when one class is underrepresented in the dataset.
###
### \macro_code
### \macro_output
### \author Jonah Ascoli

import os

import ROOT
import torch

torch.manual_seed(42)

# Create an imbalanced dataset with two classes, one of which is underrepresented.
# Here, we'll create two files, one with even numbers and one with odd numbers,
# and then merge them to form a dataset with underrepresented odd numbers.
ROOT.RDataFrame(100000).Define("b1", "(int) 2 * rdfentry_").Define("b2", "(int) b1%2").Snapshot("tree", "major.root")
ROOT.RDataFrame(100).Define("b1", "(int) 2 * rdfentry_ + 1").Define("b2", "(int) b1%2").Snapshot("tree", "minor.root")
df_major = ROOT.RDataFrame("tree", "major.root")
df_minor = ROOT.RDataFrame("tree", "minor.root")

batch_size = 16
batches_in_memory = 10
num_epochs = 10

loss_fn = torch.nn.BCEWithLogitsLoss()


def train_model(model, optimizer, dataloader):
train, val = dataloader.train_test_split(test_size=0.2)
for _ in range(num_epochs):
model.train()
for X, y in train.as_torch():
optimizer.zero_grad()
loss = loss_fn(model(X), y)
loss.backward()
optimizer.step()
losses = []
for X, y in val.as_torch():
with torch.no_grad():
loss = loss_fn(model(X), y)
losses.append(loss.item())
print(f"Validation Loss: {sum(losses) / len(losses)}")


# First, let's try to create a dataloader without resampling and see how it handles the underrepresented class.
dl = ROOT.Experimental.ML.RDataLoader(
[df_major, df_minor],
batch_size=batch_size,
batches_in_memory=batches_in_memory,
target="b2",
set_seed=42,
)

basic_model = torch.nn.Linear(1, 1) # Simple linear model for binary classification
basic_optimizer = torch.optim.Adam(basic_model.parameters())

print("Training without resampling:")
train_model(basic_model, basic_optimizer, dl)

# Now, let's try the same thing with oversampling
# Strategy: more batches of the underrepresented class
# Takes more time per epoch, but each epoch is more effective
dl_oversampled = ROOT.Experimental.ML.RDataLoader(
[df_major, df_minor],
batch_size=batch_size,
batches_in_memory=batches_in_memory,
target="b2",
set_seed=42,
load_eager=True, # Must be enabled for resampling
sampling_type="oversampling", # Can also be "undersampling"
)

oversampling_model = torch.nn.Linear(1, 1)
oversampling_optimizer = torch.optim.Adam(oversampling_model.parameters())

print("Training with oversampling:")
train_model(oversampling_model, oversampling_optimizer, dl_oversampled)

os.remove("major.root")
os.remove("minor.root")
Loading