
Separating the function to loop over data from the one to create batches for training#92

Open
anagainaru wants to merge 1 commit into main from loop-batch

Conversation

@anagainaru
Collaborator

Summary

The current behavior is to have a single function, get_cur_data_loaders, that returns the data loaders used both for looping over data (inference) and for creating batches to train on during continual learning. This PR separates these into two functions.

Motivation & Context

Having two functions allows us to control the granularity of the drift detectors (e.g., inspecting the data element by element) without impacting training, which should use the same batch size as the original training.
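As a pure-Python illustration of the granularity argument (not project code; `make_batches` is a hypothetical stand-in for the real DataLoader construction), separate loaders let the drift-detection loop run element by element while training keeps its original batch size:

```python
# Hypothetical illustration: the loop loader can batch element by
# element while the training loader keeps the batch size used for the
# original training.
def make_batches(data, batch_size):
    """Split `data` into consecutive chunks of `batch_size`."""
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

samples = list(range(8))
loop_batches = make_batches(samples, 1)   # drift detector sees one element at a time
train_batches = make_batches(samples, 4)  # training keeps its original batch size
```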

Approach

Introduced a new function, get_cur_loop_loaders. If this function is not overridden in the model harness, it falls back to self.get_cur_data_loaders() by default.
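A minimal sketch of how such a default could look in the base harness (the class and method names follow the example diff, but this base-class body is an assumption, not the actual implementation):

```python
# Assumed sketch of the default fallback: if a harness does not
# override get_cur_loop_loaders, looping reuses the training loaders.
from typing import Tuple


class BaseModelHarness:
    def get_cur_data_loaders(self) -> Tuple[object, object]:
        raise NotImplementedError

    def get_cur_loop_loaders(self) -> Tuple[object, object]:
        # Default behavior: loop over data with the training loaders.
        return self.get_cur_data_loaders()
```

A harness that does not implement `get_cur_loop_loaders` then behaves exactly as before.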

API / CLI Changes

No changes, this is an optional function in the model harness.

Example usage:

diff --git a/examples/aeris/model.py b/examples/aeris/model.py
index 4de7fc4..2c8852c 100644
--- a/examples/aeris/model.py
+++ b/examples/aeris/model.py
@@ -91,18 +92,23 @@ class AERIS(BaseModelHarness):
     def get_cur_data_loaders(self) -> Tuple[DataLoader, DataLoader]:  # noqa: D102
         assert self._cur_train_loader is not None and self._cur_val_loader is not None
         return self._cur_train_loader, self._cur_val_loader

+    def get_cur_loop_loaders(self) -> Tuple[DataLoader, DataLoader]:  # noqa: D102
+        assert self._cur_loop_train_loader is not None and self._cur_loop_val_loader is not None
+        return self._cur_loop_train_loader, self._cur_loop_val_loader
+
     def get_hist_data_loaders(
         self,
     ) -> Tuple[Optional[DataLoader], Optional[DataLoader]]:
@@ -181,6 +185,14 @@ class AERIS(BaseModelHarness):
             ds_val, bs, shuffle=False, num_workers=nw, pin_memory=pin
         )

+        bs = self.cfg.data.batch_size
+        self._cur_loop_train_loader = make_loader(
+            ds_train, bs, shuffle=True, num_workers=nw, pin_memory=pin
+        )
+        self._cur_loop_val_loader = make_loader(
+            ds_val, bs, shuffle=False, num_workers=nw, pin_memory=pin
+        )
+
         self.window_idx += 1

     # --------------------------------------------------------------------- #

@anagainaru anagainaru requested a review from ScSteffen March 10, 2026 16:45
@ScSteffen
Collaborator

I think this is a necessary change - but very intrusive in the sense that all examples need to be adapted.

Do you have an example that runs currently, so I can validate the changes locally?

@ScSteffen
Collaborator

I think with this change we can further simplify:

==> the loop loader only needs a val_loader, since we are not training on this data

@anagainaru
Collaborator Author

> I think this is a necessary change - but very intrusive in the sense that all examples need to be adapted.
>
> Do you have an example that runs currently, so I can validate the changes locally?

You do not need to change the examples; if this second function is not implemented, we fall back to the current behavior of looping over data with the training batch size.

@anagainaru
Collaborator Author

> I think with this change we can further simplify:
>
> ==> the loop loader only needs a val_loader, since we are not training on this data

Agreed. I will make the change and update the mnist example so you can loop with different batches by specifying a data batch size in the toml config.
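The agreed simplification might look roughly like this (a sketch under the assumption that only a single validation-style loader is kept for looping; the names are illustrative, not the project's actual code):

```python
# Hypothetical sketch of the simplified loop loader: only one
# val-style loader is stored, since no training happens while looping.
from typing import Optional


class SimplifiedHarness:
    def __init__(self) -> None:
        self._cur_loop_val_loader: Optional[list] = None

    def get_cur_loop_loader(self) -> list:
        # A single loader suffices for the drift-detection loop.
        assert self._cur_loop_val_loader is not None
        return self._cur_loop_val_loader
```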
