diff --git a/diffusion_planner/diffusion_planner/model/module/decoder.py b/diffusion_planner/diffusion_planner/model/module/decoder.py index 4f4077e1..52b7c385 100644 --- a/diffusion_planner/diffusion_planner/model/module/decoder.py +++ b/diffusion_planner/diffusion_planner/model/module/decoder.py @@ -276,6 +276,7 @@ def __init__(self, config): hidden_dim=config.hidden_dim, heads=config.num_heads, dropout=dpr, + use_activation_checkpointing=getattr(config, "use_activation_checkpointing", True), ) self.turn_indicator_predictor = nn.Linear( 2 * (self._future_len // 10) + config.hidden_dim, TURN_INDICATOR_OUTPUT_DIM diff --git a/diffusion_planner/diffusion_planner/model/module/dit.py b/diffusion_planner/diffusion_planner/model/module/dit.py index 233c926e..4bdac8f7 100644 --- a/diffusion_planner/diffusion_planner/model/module/dit.py +++ b/diffusion_planner/diffusion_planner/model/module/dit.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn from timm.models.layers import Mlp +from torch.utils.checkpoint import checkpoint def modulate(x, shift, scale): @@ -43,13 +44,19 @@ def forward(self, x, cross_c, y, attn_mask): x = ( x + gate_msa - * self.attn(modulated_x, modulated_x, modulated_x, key_padding_mask=attn_mask)[0] + * self.attn( + modulated_x, + modulated_x, + modulated_x, + key_padding_mask=attn_mask, + need_weights=False, + )[0] ) modulated_x = modulate(self.norm2(x), shift_mlp, scale_mlp) x = x + gate_mlp * self.mlp1(modulated_x) - x = x + self.cross_attn(self.norm3(x), cross_c, cross_c)[0] + x = x + self.cross_attn(self.norm3(x), cross_c, cross_c, need_weights=False)[0] x = x + self.mlp2(self.norm4(x)) return x @@ -93,6 +100,7 @@ def __init__( heads=6, dropout=0.1, mlp_ratio=4.0, + use_activation_checkpointing=True, ): super().__init__() @@ -117,6 +125,7 @@ def __init__( [DiTBlock(hidden_dim, heads, dropout, mlp_ratio) for i in range(depth)] ) self.final_layer = FinalLayer(hidden_dim, output_dim) + self.use_activation_checkpointing = use_activation_checkpointing def forward(self, x, t, cross_c, neighbor_current_mask): """ @@ -150,7 +159,10 @@ def forward(self, x, t, cross_c, neighbor_current_mask): attn_mask[:, 1:] = neighbor_current_mask for block in self.blocks: - x = block(x, cross_c, t, attn_mask) + if self.training and self.use_activation_checkpointing: + x = checkpoint(block, x, cross_c, t, attn_mask, use_reentrant=False) + else: + x = block(x, cross_c, t, attn_mask) x = self.final_layer(x, t) # (B, P, output_dim) x = x.reshape(B, P, T, D) diff --git a/diffusion_planner/diffusion_planner/model/module/encoder.py b/diffusion_planner/diffusion_planner/model/module/encoder.py index 87f3d04f..f09e901d 100644 --- a/diffusion_planner/diffusion_planner/model/module/encoder.py +++ b/diffusion_planner/diffusion_planner/model/module/encoder.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from timm.layers import DropPath from timm.models.layers import Mlp +from torch.utils.checkpoint import checkpoint from diffusion_planner.dimensions import * from diffusion_planner.model.module.mixer import MixerBlock @@ -45,6 +46,7 @@ def __init__(self, config): self.use_ego_history = config.use_ego_history self.ego_history_dropout_rate = config.ego_history_dropout_rate self.use_turn_indicators = config.use_turn_indicators + self.use_activation_checkpointing = getattr(config, "use_activation_checkpointing", True) ego_num = 1 goal_pose_num = 1 @@ -315,7 +317,9 @@ def __init__(self, dim, heads, dropout): ) def forward(self, x, mask): - x = x + self.drop_path(self.attn(self.norm1(x), x, x, key_padding_mask=mask)[0]) + x = x + self.drop_path( + self.attn(self.norm1(x), x, x, key_padding_mask=mask, need_weights=False)[0] + ) x = x + self.drop_path(self.mlp(self.norm2(x))) return x @@ -792,6 +796,9 @@ def forward(self, x, mask): mask[:, 0] = False for b in self.blocks: - x = b(x, mask) + if self.training and self.use_activation_checkpointing: + x = checkpoint(b, x, mask, use_reentrant=False) + else: + x = b(x, mask) return self.norm(x) diff --git a/diffusion_planner/diffusion_planner/train_epoch.py b/diffusion_planner/diffusion_planner/train_epoch.py index c3df0ecf..8e437ed3 100644 --- a/diffusion_planner/diffusion_planner/train_epoch.py +++ b/diffusion_planner/diffusion_planner/train_epoch.py @@ -28,6 +28,12 @@ def heading_to_cos_sin(x): def train_epoch(data_loader, model, optimizer, args, ema, aug: StatePerturbation = None): epoch_loss = [] + use_bf16 = bool( + args.use_bf16 + and str(args.device).startswith("cuda") + and torch.cuda.is_available() + and torch.cuda.is_bf16_supported() + ) model.train() @@ -59,7 +65,12 @@ def train_epoch(data_loader, model, optimizer, args, ema, aug: StatePerturbation # call the model optimizer.zero_grad() - loss = compute_training_loss(model, inputs, (ego_future, neighbors_future, mask), args) + with torch.autocast( + device_type="cuda", + dtype=torch.bfloat16, + enabled=use_bf16, + ): + loss = compute_training_loss(model, inputs, (ego_future, neighbors_future, mask), args) loss["loss"] = ( args.alpha_neighbor_loss * loss["neighbor_prediction_loss"] diff --git a/diffusion_planner/train_predictor.py b/diffusion_planner/train_predictor.py index c2b4f830..ed6ad092 100644 --- a/diffusion_planner/train_predictor.py +++ b/diffusion_planner/train_predictor.py @@ -132,6 +132,18 @@ def get_args(): parser.add_argument("--device", type=str, help="run on which device", default="cuda") parser.add_argument("--use_ema", default=True, type=boolean) + parser.add_argument( + "--use_bf16", + default=True, + type=boolean, + help="Enable bfloat16 autocast during training when supported by the GPU", + ) + parser.add_argument( + "--use_activation_checkpointing", + default=True, + type=boolean, + help="Recompute encoder/decoder activations during backward to reduce GPU memory", + ) # Model parser.add_argument("--encoder_mixer_depth", type=int, default=6) @@ -272,7 +284,7 @@ def model_training(args): model_ema = ModelEma( diffusion_planner, decay=0.999, - device=args.device, + device="cpu", ) if global_rank == 0: