Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions deepfense/cli/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,24 @@ def validate_config(cfg):
@click.command()
@click.option("--config", "-c", required=True, type=click.Path(exists=True), help="Path to YAML config file")
@click.option("--resume", "-r", default=None, type=click.Path(exists=True), help="Resume from checkpoint")
def train(config, resume):
@click.option("--resume-mode", type=click.Choice(["1", "2"], case_sensitive=False), default="2",
help="Resume mode: 1 = restart from epoch 0 (new dataset), 2 = continue from checkpoint (default)")
def train(config, resume, resume_mode):
"""
Train a DeepFense model.

Resume modes:

Mode 1 (epoch_restart): Load model weights, start from epoch 0 - useful when changing dataset
Mode 2 (continue): Load model weights + training state, continue from checkpoint

Example:

deepfense train --config config/train.yaml

deepfense train --config config/train.yaml --resume outputs/exp/best_model.pth

deepfense train --config config/train.yaml --resume outputs/exp/best_model.pth --resume-mode 1
"""
# Load config
cfg = load_config(config)
Expand Down Expand Up @@ -150,7 +159,10 @@ def train(config, resume):
)

if resume:
trainer.load_checkpoint(resume)
resume_mode_int = int(resume_mode)
logger.info(f"Loading checkpoint: {resume}")
logger.info(f"Resume mode {resume_mode_int}: {'Restart from epoch 0' if resume_mode_int == 1 else 'Continue from checkpoint'}")
trainer.load_checkpoint(resume, resume_mode=resume_mode_int)

trainer.train()

23 changes: 21 additions & 2 deletions deepfense/training/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,29 @@ def save_checkpoint(self, state, is_best=False):
best_path = os.path.join(self.config.output_dir, "best.pth")
torch.save(state, best_path)

def load_checkpoint(self, path, resume_mode=2):
    """
    Restore model weights from a checkpoint and set the resume counters.

    Args:
        path: Path to the checkpoint file.
        resume_mode:
            1 = epoch_restart: load only the model weights and reset
                ``global_step`` and ``start_epoch`` to 0.
            2 = continue (default): load the model weights and take
                ``global_step`` / ``start_epoch`` from the checkpoint,
                falling back to 0 when a key is absent.
            Any value other than 1 is handled like mode 2.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects; only load checkpoints from trusted sources.
    state = torch.load(path, map_location=self.device)
    self.model.load_state_dict(state["model_state_dict"])

    if resume_mode == 1:
        # Mode 1: keep the weights, discard the saved training position.
        self.logger.info("[Resume Mode 1] Loaded model weights. Starting from epoch 0.")
        self.global_step = 0
        self.start_epoch = 0
    else:
        # Mode 2: continue from the training position stored in the checkpoint.
        self.logger.info("[Resume Mode 2] Loaded model weights and training state.")
        self.global_step = state.get("global_step", 0)
        # NOTE(review): reads the "start_epoch" key; confirm this matches the
        # key that callers of save_checkpoint actually store.
        self.start_epoch = state.get("start_epoch", 0)

def train_step(self, batch):
"""Override in subclass."""
Expand Down
85 changes: 74 additions & 11 deletions deepfense/training/standard_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def _unwrapped_model(self):

def train(self):
self.model.train()

# Initialize resume tracking
resume_batch_idx = getattr(self, '_resume_batch_idx', 0)

num_params = sum(p.numel() for p in self._unwrapped_model().parameters() if p.requires_grad)
if self.is_main:
self.logger.info(f"Trainable parameters: {num_params:,}")
Expand All @@ -143,7 +145,13 @@ def train(self):
# Initialize gradients
self.optimizer.zero_grad()

# SKIP BATCHES IF RESUMING MID-EPOCH
skip_batches = resume_batch_idx if epoch == self.start_epoch else 0

for batch_idx, batch in enumerate(loop):
# Skip already processed batches
if batch_idx < skip_batches:
continue
loss = self._train_step(batch, batch_idx, epoch)

epoch_loss_sum += loss
Expand Down Expand Up @@ -548,18 +556,40 @@ def save_checkpoint(self, epoch, step, is_best=False):

return fname

def load_checkpoint(self, path, load_optimizer=True, resume_mode=2):
    """
    Restore model weights and, depending on ``resume_mode``, training state.

    Args:
        path: Path to the checkpoint file.
        load_optimizer: Whether to restore the optimizer state. Only
            consulted in mode 2; mode 1 never touches the optimizer.
        resume_mode:
            1 = epoch_restart: load only the model weights, reset
                ``start_epoch`` and ``global_step`` to 0 and reset
                ``best_metric`` to the worst value for the monitor mode.
            2 = continue (default): load model weights plus epoch, step,
                best metric and (optionally) optimizer state.
            Any value other than 1 is handled like mode 2.

    Raises:
        KeyError: If the checkpoint has no ``"model_state"`` entry.
    """
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects; only load checkpoints from trusted sources.
    state = torch.load(path, map_location=self.device)
    self._unwrapped_model().load_state_dict(state["model_state"])

    if resume_mode == 1:
        # Mode 1: keep the weights, start the schedule over.
        self.start_epoch = 0
        self.global_step = 0
        # float("inf") avoids depending on a module-level `import math`.
        maximize = getattr(self.config, "monitor_mode", "max") == "max"
        self.best_metric = float("-inf") if maximize else float("inf")
        if self.is_main:
            self.logger.info(f"[Resume Mode 1] Loaded model weights from {path}")
            self.logger.info("[Resume Mode 1] Reset to epoch 0, global_step 0 (new dataset)")
        return

    # Mode 2: continue from the training state stored in the checkpoint.
    if load_optimizer:
        opt_state = state.get("optimizer_state", None)
        if opt_state:
            self.optimizer.load_state_dict(opt_state)
    self.start_epoch = state.get("epoch", 0)
    self.global_step = state.get("step", 0)
    # Keep the current best metric when the checkpoint does not carry one.
    self.best_metric = state.get("best_metric", self.best_metric)
    if self.is_main:
        self.logger.info(f"[Resume Mode 2] Loaded checkpoint from {path}")
        self.logger.info(f"[Resume Mode 2] Resuming from epoch {self.start_epoch}, step {self.global_step}")

def infer(self, x):
self.model.eval()
Expand All @@ -572,3 +602,36 @@ def _disable(module):
module.momentum = 0

self.model.apply(_disable)

def resume_from_checkpoint(self, checkpoint_path, resume_mode=2):
    """
    Load a checkpoint and record how many batches to skip on resume.

    Calls ``load_checkpoint`` and then stores ``_resume_batch_idx`` and
    ``_resume_epoch`` on the trainer.

    Args:
        checkpoint_path: Path to the checkpoint file.
        resume_mode:
            1 = epoch_restart: load model weights only, start from epoch 0.
            2 = continue (default): load model weights + training state
                (epoch, step) and compute the mid-epoch batch offset.
    """
    self.load_checkpoint(checkpoint_path, resume_mode=resume_mode)

    # A batch offset is only computed for mode 2 with a non-zero start epoch.
    # NOTE(review): a checkpoint taken during epoch 0 therefore restarts that
    # epoch from its first batch; confirm this is intended.
    if resume_mode == 2 and self.start_epoch > 0:
        # Optimizer steps per epoch (approximate: a trailing partial
        # accumulation window is not counted). Clamped to 1 so that a loader
        # shorter than accum_steps cannot cause a modulo-by-zero below.
        steps_per_epoch = max(1, len(self.train_loader) // self.accum_steps)

        # Optimizer steps already completed inside the current epoch.
        steps_in_current_epoch = self.global_step % steps_per_epoch

        # Convert optimizer steps back to a batch count.
        self._resume_batch_idx = steps_in_current_epoch * self.accum_steps
        self._resume_epoch = self.start_epoch
    else:
        self._resume_batch_idx = 0
        self._resume_epoch = 0

    # Log from the main process only, matching load_checkpoint.
    if not self.is_main:
        return
    if resume_mode == 1:
        self.logger.info("[Resume Mode 1] Starting training from epoch 0 with loaded model weights")
    else:
        self.logger.info(f"[Resume Mode 2] Resuming from epoch {self.start_epoch}, global_step {self.global_step}")
        self.logger.info(f"Will skip {self._resume_batch_idx} batches in epoch {self.start_epoch}")
6 changes: 5 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
parser.add_argument("--resume-mode", type=int, default=2, choices=[1, 2],
help="Resume mode: 1 = restart from epoch 0 (new dataset), 2 = continue from checkpoint (default)")
args = parser.parse_args()

ddp = setup_distributed()
Expand Down Expand Up @@ -231,7 +233,9 @@ def main():
)

if args.resume:
trainer.load_checkpoint(args.resume)
logger.info(f"Loading checkpoint: {args.resume}")
logger.info(f"Resume mode {args.resume_mode}: {'Restart from epoch 0' if args.resume_mode == 1 else 'Continue from checkpoint'}")
trainer.resume_from_checkpoint(args.resume, resume_mode=args.resume_mode)

trainer.train()

Expand Down