From 4d8020f48a4b692f0f5b48bb946420c12f3b3469 Mon Sep 17 00:00:00 2001 From: Feidi Kallel Date: Thu, 18 Jun 2026 10:47:39 +0200 Subject: [PATCH 1/2] resume from checkpoint (steps) --- deepfense/training/standard_trainer.py | 32 +++++++++++++++++++++++++- train.py | 2 +- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/deepfense/training/standard_trainer.py b/deepfense/training/standard_trainer.py index 0f2a6b2..c0b28c2 100644 --- a/deepfense/training/standard_trainer.py +++ b/deepfense/training/standard_trainer.py @@ -116,7 +116,9 @@ def _unwrapped_model(self): def train(self): self.model.train() - + # Initialize resume tracking + resume_batch_idx = getattr(self, '_resume_batch_idx', 0) + num_params = sum(p.numel() for p in self._unwrapped_model().parameters() if p.requires_grad) if self.is_main: self.logger.info(f"Trainable parameters: {num_params:,}") @@ -143,7 +145,13 @@ def train(self): # Initialize gradients self.optimizer.zero_grad() + # SKIP BATCHES IF RESUMING MID-EPOCH + skip_batches = resume_batch_idx if epoch == self.start_epoch else 0 + for batch_idx, batch in enumerate(loop): + # Skip already processed batches + if batch_idx < skip_batches: + continue loss = self._train_step(batch, batch_idx, epoch) epoch_loss_sum += loss @@ -572,3 +580,25 @@ def _disable(module): module.momentum = 0 self.model.apply(_disable) + + def resume_from_checkpoint(self, checkpoint_path): + """Load checkpoint and prepare to resume exactly from the saved step.""" + self.load_checkpoint(checkpoint_path) + + # Calculate how many batches to skip in the current epoch + if self.start_epoch > 0: + # Get total steps per epoch (approximate) + steps_per_epoch = len(self.train_loader) // self.accum_steps + + # Steps completed in the current epoch + steps_in_current_epoch = self.global_step % steps_per_epoch + + # Store to skip in train loop + self._resume_batch_idx = steps_in_current_epoch * self.accum_steps + self._resume_epoch = self.start_epoch + else: + self._resume_batch_idx = 0 + self._resume_epoch = 0 + + self.logger.info(f"Resuming from epoch {self.start_epoch}, global_step {self.global_step}") + self.logger.info(f"Will skip {self._resume_batch_idx} batches in epoch {self.start_epoch}") diff --git a/train.py b/train.py index e7913eb..cbad4df 100644 --- a/train.py +++ b/train.py @@ -231,7 +231,7 @@ def main(): ) if args.resume: - trainer.load_checkpoint(args.resume) + trainer.resume_from_checkpoint(args.resume) trainer.train() From 3538bcf7d2e30439ddd926bc5d5a8715f1cfb7eb Mon Sep 17 00:00:00 2001 From: Feidi Kallel Date: Thu, 18 Jun 2026 11:46:08 +0200 Subject: [PATCH 2/2] implement two resume modes (1 = restart, 2=continue(default)) --- deepfense/cli/commands/train.py | 16 ++++++- deepfense/training/base_trainer.py | 23 ++++++++- deepfense/training/standard_trainer.py | 65 +++++++++++++++++++------- train.py | 6 ++- 4 files changed, 89 insertions(+), 21 deletions(-) diff --git a/deepfense/cli/commands/train.py b/deepfense/cli/commands/train.py index 2fc122e..0b211cf 100644 --- a/deepfense/cli/commands/train.py +++ b/deepfense/cli/commands/train.py @@ -79,15 +79,24 @@ def validate_config(cfg): @click.command() @click.option("--config", "-c", required=True, type=click.Path(exists=True), help="Path to YAML config file") @click.option("--resume", "-r", default=None, type=click.Path(exists=True), help="Resume from checkpoint") -def train(config, resume): +@click.option("--resume-mode", type=click.Choice(["1", "2"], case_sensitive=False), default="2", + help="Resume mode: 1 = restart from epoch 0 (new dataset), 2 = continue from checkpoint (default)") +def train(config, resume, resume_mode): """ Train a DeepFense model. + Resume modes: + + Mode 1 (epoch_restart): Load model weights, start from epoch 0 - useful when changing dataset + Mode 2 (continue): Load model weights + training state, continue from checkpoint + Example: deepfense train --config config/train.yaml deepfense train --config config/train.yaml --resume outputs/exp/best_model.pth + + deepfense train --config config/train.yaml --resume outputs/exp/best_model.pth --resume-mode 1 """ # Load config cfg = load_config(config) @@ -150,7 +159,10 @@ def train(config, resume): ) if resume: - trainer.load_checkpoint(resume) + resume_mode_int = int(resume_mode) + logger.info(f"Loading checkpoint: {resume}") + logger.info(f"Resume mode {resume_mode_int}: {'Restart from epoch 0' if resume_mode_int == 1 else 'Continue from checkpoint'}") + trainer.load_checkpoint(resume, resume_mode=resume_mode_int) trainer.train() diff --git a/deepfense/training/base_trainer.py b/deepfense/training/base_trainer.py index 36e1c66..c42d49f 100644 --- a/deepfense/training/base_trainer.py +++ b/deepfense/training/base_trainer.py @@ -27,10 +27,29 @@ def save_checkpoint(self, state, is_best=False): best_path = os.path.join(self.config.output_dir, "best.pth") torch.save(state, best_path) - def load_checkpoint(self, path): + def load_checkpoint(self, path, resume_mode=2): + """ + Load checkpoint with different resume modes. + + Args: + path: Path to checkpoint file + resume_mode: + 1 = epoch_restart: Load only model weights, reset epoch to 0 + 2 = continue: Load model weights + training state (epoch, global_step, etc.) + """ state = torch.load(path, map_location=self.device) self.model.load_state_dict(state["model_state_dict"]) - self.global_step = state.get("global_step", 0) + + if resume_mode == 1: + # Mode 1: Start from epoch 0 - don't load training state + self.logger.info("[Resume Mode 1] Loaded model weights. Starting from epoch 0.") + self.global_step = 0 + self.start_epoch = 0 + else: + # Mode 2: Continue from checkpoint - load all training state + self.logger.info("[Resume Mode 2] Loaded model weights and training state.") + self.global_step = state.get("global_step", 0) + self.start_epoch = state.get("start_epoch", 0) def train_step(self, batch): """Override in subclass.""" diff --git a/deepfense/training/standard_trainer.py b/deepfense/training/standard_trainer.py index c0b28c2..3701ebd 100644 --- a/deepfense/training/standard_trainer.py +++ b/deepfense/training/standard_trainer.py @@ -556,18 +556,40 @@ def save_checkpoint(self, epoch, step, is_best=False): return fname - def load_checkpoint(self, path, load_optimizer=True): + def load_checkpoint(self, path, load_optimizer=True, resume_mode=2): + """ + Load checkpoint with different resume modes. + + Args: + path: Path to checkpoint file + load_optimizer: Whether to load optimizer state + resume_mode: + 1 = epoch_restart: Load only model weights, reset epoch to 0 and step to 0 + 2 = continue: Load model weights + training state (epoch, step, optimizer) + """ state = torch.load(path, map_location=self.device) self._unwrapped_model().load_state_dict(state["model_state"]) - if load_optimizer: - opt_state = state.get("optimizer_state", None) - if opt_state: - self.optimizer.load_state_dict(opt_state) - self.start_epoch = state.get("epoch", 0) - self.global_step = state.get("step", 0) - self.best_metric = state.get("best_metric", self.best_metric) - if self.is_main: - self.logger.info(f"Loaded checkpoint from {path}") + + if resume_mode == 1: + # Mode 1: Start from epoch 0 - load model only, reset training state + self.start_epoch = 0 + self.global_step = 0 + self.best_metric = -math.inf if getattr(self.config, "monitor_mode", "max") == "max" else math.inf + if self.is_main: + self.logger.info(f"[Resume Mode 1] Loaded model weights from {path}") + self.logger.info("[Resume Mode 1] Reset to epoch 0, global_step 0 (new dataset)") + else: + # Mode 2: Continue from checkpoint - load all training state + if load_optimizer: + opt_state = state.get("optimizer_state", None) + if opt_state: + self.optimizer.load_state_dict(opt_state) + self.start_epoch = state.get("epoch", 0) + self.global_step = state.get("step", 0) + self.best_metric = state.get("best_metric", self.best_metric) + if self.is_main: + self.logger.info(f"[Resume Mode 2] Loaded checkpoint from {path}") + self.logger.info(f"[Resume Mode 2] Resuming from epoch {self.start_epoch}, step {self.global_step}") def infer(self, x): self.model.eval() @@ -581,12 +603,20 @@ def _disable(module): self.model.apply(_disable) - def resume_from_checkpoint(self, checkpoint_path): - """Load checkpoint and prepare to resume exactly from the saved step.""" - self.load_checkpoint(checkpoint_path) + def resume_from_checkpoint(self, checkpoint_path, resume_mode=2): + """ + Load checkpoint and prepare to resume from the saved step or epoch 0. - # Calculate how many batches to skip in the current epoch - if self.start_epoch > 0: + Args: + checkpoint_path: Path to checkpoint file + resume_mode: + 1 = epoch_restart: Load model weights only, start from epoch 0 + 2 = continue: Load model weights + training state (epoch, step) - default + """ + self.load_checkpoint(checkpoint_path, resume_mode=resume_mode) + + # Calculate how many batches to skip in the current epoch (only for mode 2) + if resume_mode == 2 and self.start_epoch > 0: # Get total steps per epoch (approximate) steps_per_epoch = len(self.train_loader) // self.accum_steps @@ -600,5 +630,8 @@ def resume_from_checkpoint(self, checkpoint_path): self._resume_batch_idx = 0 self._resume_epoch = 0 - self.logger.info(f"Resuming from epoch {self.start_epoch}, global_step {self.global_step}") + if resume_mode == 1: + self.logger.info(f"[Resume Mode 1] Starting training from epoch 0 with loaded model weights") + else: + self.logger.info(f"[Resume Mode 2] Resuming from epoch {self.start_epoch}, global_step {self.global_step}") self.logger.info(f"Will skip {self._resume_batch_idx} batches in epoch {self.start_epoch}") diff --git a/train.py b/train.py index cbad4df..2c0b0d0 100644 --- a/train.py +++ b/train.py @@ -127,6 +127,8 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True, help="Path to YAML config") parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint") + parser.add_argument("--resume-mode", type=int, default=2, choices=[1, 2], + help="Resume mode: 1 = restart from epoch 0 (new dataset), 2 = continue from checkpoint (default)") args = parser.parse_args() ddp = setup_distributed() @@ -231,7 +233,9 @@ def main(): ) if args.resume: - trainer.resume_from_checkpoint(args.resume) + logger.info(f"Loading checkpoint: {args.resume}") + logger.info(f"Resume mode {args.resume_mode}: {'Restart from epoch 0' if args.resume_mode == 1 else 'Continue from checkpoint'}") + trainer.resume_from_checkpoint(args.resume, resume_mode=args.resume_mode) trainer.train()