diff --git a/deepfense/cli/commands/train.py b/deepfense/cli/commands/train.py
index 2fc122e..0b211cf 100644
--- a/deepfense/cli/commands/train.py
+++ b/deepfense/cli/commands/train.py
@@ -79,15 +79,24 @@ def validate_config(cfg):
 @click.command()
 @click.option("--config", "-c", required=True, type=click.Path(exists=True), help="Path to YAML config file")
 @click.option("--resume", "-r", default=None, type=click.Path(exists=True), help="Resume from checkpoint")
-def train(config, resume):
+@click.option("--resume-mode", type=click.Choice(["1", "2"], case_sensitive=False), default="2",
+              help="Resume mode: 1 = restart from epoch 0 (new dataset), 2 = continue from checkpoint (default)")
+def train(config, resume, resume_mode):
     """
     Train a DeepFense model.
 
+    Resume modes:
+
+        Mode 1 (epoch_restart): Load model weights, start from epoch 0 - useful when changing dataset
+        Mode 2 (continue): Load model weights + training state, continue from checkpoint
+
     Example:
 
         deepfense train --config config/train.yaml
 
         deepfense train --config config/train.yaml --resume outputs/exp/best_model.pth
+
+        deepfense train --config config/train.yaml --resume outputs/exp/best_model.pth --resume-mode 1
     """
     # Load config
     cfg = load_config(config)
@@ -150,7 +159,12 @@ def train(config, resume):
     )
 
     if resume:
-        trainer.load_checkpoint(resume)
+        resume_mode_int = int(resume_mode)
+        logger.info(f"Loading checkpoint: {resume}")
+        logger.info(f"Resume mode {resume_mode_int}: {'Restart from epoch 0' if resume_mode_int == 1 else 'Continue from checkpoint'}")
+        # Same entry point as the train.py script so mid-epoch batch skipping is set up; plain load_checkpoint is the fallback for trainers without it
+        resume_fn = getattr(trainer, "resume_from_checkpoint", trainer.load_checkpoint)
+        resume_fn(resume, resume_mode=resume_mode_int)
 
     trainer.train()
 
diff --git a/deepfense/training/base_trainer.py b/deepfense/training/base_trainer.py
index 36e1c66..c42d49f 100644
--- a/deepfense/training/base_trainer.py
+++ b/deepfense/training/base_trainer.py
@@ -27,10 +27,30 @@ def save_checkpoint(self, state, is_best=False):
             best_path = os.path.join(self.config.output_dir, "best.pth")
             torch.save(state, best_path)
 
-    def load_checkpoint(self, path):
+    def load_checkpoint(self, path, resume_mode=2):
+        """
+        Load checkpoint with different resume modes.
+
+        Args:
+            path: Path to checkpoint file
+            resume_mode:
+                1 = epoch_restart: Load only model weights, reset epoch to 0
+                2 = continue: Load model weights + training state (epoch, global_step, etc.)
+        """
         state = torch.load(path, map_location=self.device)
         self.model.load_state_dict(state["model_state_dict"])
-        self.global_step = state.get("global_step", 0)
+
+        if resume_mode == 1:
+            # Mode 1: Start from epoch 0 - don't load training state
+            self.logger.info("[Resume Mode 1] Loaded model weights. Starting from epoch 0.")
+            self.global_step = 0
+            self.start_epoch = 0
+        else:
+            # Mode 2: Continue from checkpoint - load all training state
+            self.logger.info("[Resume Mode 2] Loaded model weights and training state.")
+            self.global_step = state.get("global_step", 0)
+            # StandardTrainer reads the epoch from "epoch"; accept that key first, then the legacy "start_epoch"
+            self.start_epoch = state.get("epoch", state.get("start_epoch", 0))
 
     def train_step(self, batch):
         """Override in subclass."""
diff --git a/deepfense/training/standard_trainer.py b/deepfense/training/standard_trainer.py
index 0f2a6b2..3701ebd 100644
--- a/deepfense/training/standard_trainer.py
+++ b/deepfense/training/standard_trainer.py
@@ -116,7 +116,9 @@ def _unwrapped_model(self):
 
     def train(self):
         self.model.train()
-        
+        # Initialize resume tracking
+        resume_batch_idx = getattr(self, '_resume_batch_idx', 0)
+
         num_params = sum(p.numel() for p in self._unwrapped_model().parameters() if p.requires_grad)
         if self.is_main:
             self.logger.info(f"Trainable parameters: {num_params:,}")
@@ -143,7 +145,13 @@ def train(self):
             # Initialize gradients
             self.optimizer.zero_grad()
 
+            # SKIP BATCHES IF RESUMING MID-EPOCH
+            skip_batches = resume_batch_idx if epoch == self.start_epoch else 0
+
             for batch_idx, batch in enumerate(loop):
+                # Skip already processed batches
+                if batch_idx < skip_batches:
+                    continue
                 loss = self._train_step(batch, batch_idx, epoch)
                 epoch_loss_sum += loss
 
@@ -548,18 +556,40 @@ def save_checkpoint(self, epoch, step, is_best=False):
 
         return fname
 
-    def load_checkpoint(self, path, load_optimizer=True):
+    def load_checkpoint(self, path, load_optimizer=True, resume_mode=2):
+        """
+        Load checkpoint with different resume modes.
+
+        Args:
+            path: Path to checkpoint file
+            load_optimizer: Whether to load optimizer state
+            resume_mode:
+                1 = epoch_restart: Load only model weights, reset epoch to 0 and step to 0
+                2 = continue: Load model weights + training state (epoch, step, optimizer)
+        """
         state = torch.load(path, map_location=self.device)
         self._unwrapped_model().load_state_dict(state["model_state"])
-        if load_optimizer:
-            opt_state = state.get("optimizer_state", None)
-            if opt_state:
-                self.optimizer.load_state_dict(opt_state)
-        self.start_epoch = state.get("epoch", 0)
-        self.global_step = state.get("step", 0)
-        self.best_metric = state.get("best_metric", self.best_metric)
-        if self.is_main:
-            self.logger.info(f"Loaded checkpoint from {path}")
+
+        if resume_mode == 1:
+            # Mode 1: Start from epoch 0 - load model only, reset training state (float("inf") needs no math import)
+            self.start_epoch = 0
+            self.global_step = 0
+            self.best_metric = float("-inf") if getattr(self.config, "monitor_mode", "max") == "max" else float("inf")
+            if self.is_main:
+                self.logger.info(f"[Resume Mode 1] Loaded model weights from {path}")
+                self.logger.info("[Resume Mode 1] Reset to epoch 0, global_step 0 (new dataset)")
+        else:
+            # Mode 2: Continue from checkpoint - load all training state
+            if load_optimizer:
+                opt_state = state.get("optimizer_state", None)
+                if opt_state:
+                    self.optimizer.load_state_dict(opt_state)
+            self.start_epoch = state.get("epoch", 0)
+            self.global_step = state.get("step", 0)
+            self.best_metric = state.get("best_metric", self.best_metric)
+            if self.is_main:
+                self.logger.info(f"[Resume Mode 2] Loaded checkpoint from {path}")
+                self.logger.info(f"[Resume Mode 2] Resuming from epoch {self.start_epoch}, step {self.global_step}")
 
     def infer(self, x):
         self.model.eval()
@@ -572,3 +602,40 @@ def _disable(module):
                 module.momentum = 0
 
         self.model.apply(_disable)
+
+    def resume_from_checkpoint(self, checkpoint_path, resume_mode=2):
+        """
+        Load checkpoint and prepare to resume from the saved step or epoch 0.
+
+        Args:
+            checkpoint_path: Path to checkpoint file
+            resume_mode:
+                1 = epoch_restart: Load model weights only, start from epoch 0
+                2 = continue: Load model weights + training state (epoch, step) - default
+        """
+        self.load_checkpoint(checkpoint_path, resume_mode=resume_mode)
+
+        # Calculate how many batches to skip in the current epoch (only for mode 2).
+        # Keyed on global_step so a checkpoint saved part-way through epoch 0 also resumes mid-epoch.
+        if resume_mode == 2 and self.global_step > 0:
+            # Optimizer steps per epoch (approximate); clamped to 1 so a loader shorter
+            # than accum_steps cannot make the modulo below divide by zero
+            steps_per_epoch = max(1, len(self.train_loader) // self.accum_steps)
+
+            # Steps completed in the current epoch
+            steps_in_current_epoch = self.global_step % steps_per_epoch
+
+            # Store to skip in train loop
+            self._resume_batch_idx = steps_in_current_epoch * self.accum_steps
+            self._resume_epoch = self.start_epoch
+        else:
+            self._resume_batch_idx = 0
+            self._resume_epoch = 0
+
+        # Log from the main process only, as load_checkpoint does
+        if self.is_main:
+            if resume_mode == 1:
+                self.logger.info("[Resume Mode 1] Starting training from epoch 0 with loaded model weights")
+            else:
+                self.logger.info(f"[Resume Mode 2] Resuming from epoch {self.start_epoch}, global_step {self.global_step}")
+                self.logger.info(f"Will skip {self._resume_batch_idx} batches in epoch {self.start_epoch}")
diff --git a/train.py b/train.py
index e7913eb..2c0b0d0 100644
--- a/train.py
+++ b/train.py
@@ -127,6 +127,8 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
     parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
+    parser.add_argument("--resume-mode", type=int, default=2, choices=[1, 2],
+                        help="Resume mode: 1 = restart from epoch 0 (new dataset), 2 = continue from checkpoint (default)")
     args = parser.parse_args()
 
     ddp = setup_distributed()
@@ -231,7 +233,9 @@ def main():
     )
 
     if args.resume:
-        trainer.load_checkpoint(args.resume)
+        logger.info(f"Loading checkpoint: {args.resume}")
+        logger.info(f"Resume mode {args.resume_mode}: {'Restart from epoch 0' if args.resume_mode == 1 else 'Continue from checkpoint'}")
+        trainer.resume_from_checkpoint(args.resume, resume_mode=args.resume_mode)
 
     trainer.train()