From 65e9e4cb5dedad015fc9895b53b3c1aad9001bb9 Mon Sep 17 00:00:00 2001 From: Yukihiro Saito Date: Tue, 31 Mar 2026 13:31:55 +0900 Subject: [PATCH 1/4] Enable bf16 autocast for training --- diffusion_planner/diffusion_planner/train_epoch.py | 13 ++++++++++++- diffusion_planner/train_predictor.py | 6 ++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/diffusion_planner/diffusion_planner/train_epoch.py b/diffusion_planner/diffusion_planner/train_epoch.py index c3df0ecfd..8e437ed3e 100644 --- a/diffusion_planner/diffusion_planner/train_epoch.py +++ b/diffusion_planner/diffusion_planner/train_epoch.py @@ -28,6 +28,12 @@ def heading_to_cos_sin(x): def train_epoch(data_loader, model, optimizer, args, ema, aug: StatePerturbation = None): epoch_loss = [] + use_bf16 = bool( + args.use_bf16 + and str(args.device).startswith("cuda") + and torch.cuda.is_available() + and torch.cuda.is_bf16_supported() + ) model.train() @@ -59,7 +65,12 @@ def train_epoch(data_loader, model, optimizer, args, ema, aug: StatePerturbation # call the model optimizer.zero_grad() - loss = compute_training_loss(model, inputs, (ego_future, neighbors_future, mask), args) + with torch.autocast( + device_type="cuda", + dtype=torch.bfloat16, + enabled=use_bf16, + ): + loss = compute_training_loss(model, inputs, (ego_future, neighbors_future, mask), args) loss["loss"] = ( args.alpha_neighbor_loss * loss["neighbor_prediction_loss"] diff --git a/diffusion_planner/train_predictor.py b/diffusion_planner/train_predictor.py index c2b4f8300..6be7a041c 100644 --- a/diffusion_planner/train_predictor.py +++ b/diffusion_planner/train_predictor.py @@ -132,6 +132,12 @@ def get_args(): parser.add_argument("--device", type=str, help="run on which device", default="cuda") parser.add_argument("--use_ema", default=True, type=boolean) + parser.add_argument( + "--use_bf16", + default=True, + type=boolean, + help="Enable bfloat16 autocast during training when supported by the GPU", + ) # Model parser.add_argument("--encoder_mixer_depth", type=int, default=6) From 9887bfe4c2ce1f64f0d9396754a94d9c05a1d201 Mon Sep 17 00:00:00 2001 From: Yukihiro Saito Date: Tue, 31 Mar 2026 13:32:18 +0900 Subject: [PATCH 2/4] Disable unused attention weight outputs --- .../diffusion_planner/model/module/dit.py | 10 ++++++++-- .../diffusion_planner/model/module/encoder.py | 4 +++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/diffusion_planner/diffusion_planner/model/module/dit.py b/diffusion_planner/diffusion_planner/model/module/dit.py index 233c926e5..0d63f7708 100644 --- a/diffusion_planner/diffusion_planner/model/module/dit.py +++ b/diffusion_planner/diffusion_planner/model/module/dit.py @@ -43,13 +43,19 @@ def forward(self, x, cross_c, y, attn_mask): x = ( x + gate_msa - * self.attn(modulated_x, modulated_x, modulated_x, key_padding_mask=attn_mask)[0] + * self.attn( + modulated_x, + modulated_x, + modulated_x, + key_padding_mask=attn_mask, + need_weights=False, + )[0] ) modulated_x = modulate(self.norm2(x), shift_mlp, scale_mlp) x = x + gate_mlp * self.mlp1(modulated_x) - x = x + self.cross_attn(self.norm3(x), cross_c, cross_c)[0] + x = x + self.cross_attn(self.norm3(x), cross_c, cross_c, need_weights=False)[0] x = x + self.mlp2(self.norm4(x)) return x diff --git a/diffusion_planner/diffusion_planner/model/module/encoder.py b/diffusion_planner/diffusion_planner/model/module/encoder.py index 87f3d04fa..41d7dc266 100644 --- a/diffusion_planner/diffusion_planner/model/module/encoder.py +++ b/diffusion_planner/diffusion_planner/model/module/encoder.py @@ -315,7 +315,9 @@ def __init__(self, dim, heads, dropout): ) def forward(self, x, mask): - x = x + self.drop_path(self.attn(self.norm1(x), x, x, key_padding_mask=mask)[0]) + x = x + self.drop_path( + self.attn(self.norm1(x), x, x, key_padding_mask=mask, need_weights=False)[0] + ) x = x + self.drop_path(self.mlp(self.norm2(x))) return x From 528e9b2217f8d9b832f98081338633b454886ba9 Mon Sep 17 00:00:00 2001 From: Yukihiro Saito Date: Tue, 31 Mar 2026 13:32:42 +0900 Subject: [PATCH 3/4] Store EMA weights on CPU --- diffusion_planner/train_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffusion_planner/train_predictor.py b/diffusion_planner/train_predictor.py index 6be7a041c..017f25d50 100644 --- a/diffusion_planner/train_predictor.py +++ b/diffusion_planner/train_predictor.py @@ -278,7 +278,7 @@ def model_training(args): model_ema = ModelEma( diffusion_planner, decay=0.999, - device=args.device, + device="cpu", ) if global_rank == 0: From a213697b099cf6d96155341df5b4bc154b8d52f3 Mon Sep 17 00:00:00 2001 From: Yukihiro Saito Date: Tue, 31 Mar 2026 13:33:35 +0900 Subject: [PATCH 4/4] Add activation checkpointing to transformer blocks --- .../diffusion_planner/model/module/decoder.py | 1 + diffusion_planner/diffusion_planner/model/module/dit.py | 8 +++++++- .../diffusion_planner/model/module/encoder.py | 7 ++++++- diffusion_planner/train_predictor.py | 6 ++++++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/diffusion_planner/diffusion_planner/model/module/decoder.py b/diffusion_planner/diffusion_planner/model/module/decoder.py index 4f4077e1b..52b7c385c 100644 --- a/diffusion_planner/diffusion_planner/model/module/decoder.py +++ b/diffusion_planner/diffusion_planner/model/module/decoder.py @@ -276,6 +276,7 @@ def __init__(self, config): hidden_dim=config.hidden_dim, heads=config.num_heads, dropout=dpr, + use_activation_checkpointing=getattr(config, "use_activation_checkpointing", True), ) self.turn_indicator_predictor = nn.Linear( 2 * (self._future_len // 10) + config.hidden_dim, TURN_INDICATOR_OUTPUT_DIM diff --git a/diffusion_planner/diffusion_planner/model/module/dit.py b/diffusion_planner/diffusion_planner/model/module/dit.py index 0d63f7708..4bdac8f72 100644 --- a/diffusion_planner/diffusion_planner/model/module/dit.py +++ b/diffusion_planner/diffusion_planner/model/module/dit.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn from timm.models.layers import Mlp +from torch.utils.checkpoint import checkpoint def modulate(x, shift, scale): @@ -99,6 +100,7 @@ def __init__( heads=6, dropout=0.1, mlp_ratio=4.0, + use_activation_checkpointing=True, ): super().__init__() @@ -123,6 +125,7 @@ def __init__( [DiTBlock(hidden_dim, heads, dropout, mlp_ratio) for i in range(depth)] ) self.final_layer = FinalLayer(hidden_dim, output_dim) + self.use_activation_checkpointing = use_activation_checkpointing def forward(self, x, t, cross_c, neighbor_current_mask): """ @@ -156,7 +159,10 @@ def forward(self, x, t, cross_c, neighbor_current_mask): attn_mask[:, 1:] = neighbor_current_mask for block in self.blocks: - x = block(x, cross_c, t, attn_mask) + if self.training and self.use_activation_checkpointing: + x = checkpoint(block, x, cross_c, t, attn_mask, use_reentrant=False) + else: + x = block(x, cross_c, t, attn_mask) x = self.final_layer(x, t) # (B, P, output_dim) x = x.reshape(B, P, T, D) diff --git a/diffusion_planner/diffusion_planner/model/module/encoder.py b/diffusion_planner/diffusion_planner/model/module/encoder.py index 41d7dc266..f09e901de 100644 --- a/diffusion_planner/diffusion_planner/model/module/encoder.py +++ b/diffusion_planner/diffusion_planner/model/module/encoder.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from timm.layers import DropPath from timm.models.layers import Mlp +from torch.utils.checkpoint import checkpoint from diffusion_planner.dimensions import * from diffusion_planner.model.module.mixer import MixerBlock @@ -45,6 +46,7 @@ def __init__(self, config): self.use_ego_history = config.use_ego_history self.ego_history_dropout_rate = config.ego_history_dropout_rate self.use_turn_indicators = config.use_turn_indicators + self.use_activation_checkpointing = getattr(config, "use_activation_checkpointing", True) ego_num = 1 goal_pose_num = 1 @@ -794,6 +796,9 @@ def forward(self, x, mask): mask[:, 0] = False for b in self.blocks: - x = b(x, mask) + if self.training and self.use_activation_checkpointing: + x = checkpoint(b, x, mask, use_reentrant=False) + else: + x = b(x, mask) return self.norm(x) diff --git a/diffusion_planner/train_predictor.py b/diffusion_planner/train_predictor.py index 017f25d50..ed6ad0920 100644 --- a/diffusion_planner/train_predictor.py +++ b/diffusion_planner/train_predictor.py @@ -138,6 +138,12 @@ def get_args(): type=boolean, help="Enable bfloat16 autocast during training when supported by the GPU", ) + parser.add_argument( + "--use_activation_checkpointing", + default=True, + type=boolean, + help="Recompute encoder/decoder activations during backward to reduce GPU memory", + ) # Model parser.add_argument("--encoder_mixer_depth", type=int, default=6)