diff --git a/gcsfs/tests/perf/macrobenchmarks/workloads/hf-pytorch-lightning-cpu/helm_chart/llama_3_1_8b_cpu_sim.py b/gcsfs/tests/perf/macrobenchmarks/workloads/hf-pytorch-lightning-cpu/helm_chart/llama_3_1_8b_cpu_sim.py
index 79bab470..feb7c5b9 100644
--- a/gcsfs/tests/perf/macrobenchmarks/workloads/hf-pytorch-lightning-cpu/helm_chart/llama_3_1_8b_cpu_sim.py
+++ b/gcsfs/tests/perf/macrobenchmarks/workloads/hf-pytorch-lightning-cpu/helm_chart/llama_3_1_8b_cpu_sim.py
@@ -284,6 +284,12 @@ def __init__(self):
         super().__init__()
         self.ckpt_time = 0.0
 
+    def on_train_start(self, trainer, pl_module):
+        # Initialize timer at training start to avoid AttributeError when resuming mid-epoch
+        # (where on_train_epoch_start is skipped).
+        self.start_time = time.perf_counter()
+        self.ckpt_time = 0.0
+
     def on_train_epoch_start(self, trainer, pl_module):
         # Start timer at the beginning of the epoch to capture the first batch's data loading time
         self.start_time = time.perf_counter()