Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def __init__(self, config):
hidden_dim=config.hidden_dim,
heads=config.num_heads,
dropout=dpr,
use_activation_checkpointing=getattr(config, "use_activation_checkpointing", True),
)
self.turn_indicator_predictor = nn.Linear(
2 * (self._future_len // 10) + config.hidden_dim, TURN_INDICATOR_OUTPUT_DIM
Expand Down
18 changes: 15 additions & 3 deletions diffusion_planner/diffusion_planner/model/module/dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import torch.nn as nn
from timm.models.layers import Mlp
from torch.utils.checkpoint import checkpoint


def modulate(x, shift, scale):
Expand Down Expand Up @@ -43,13 +44,19 @@ def forward(self, x, cross_c, y, attn_mask):
x = (
x
+ gate_msa
* self.attn(modulated_x, modulated_x, modulated_x, key_padding_mask=attn_mask)[0]
* self.attn(
modulated_x,
modulated_x,
modulated_x,
key_padding_mask=attn_mask,
need_weights=False,
)[0]
)

modulated_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = x + gate_mlp * self.mlp1(modulated_x)

x = x + self.cross_attn(self.norm3(x), cross_c, cross_c)[0]
x = x + self.cross_attn(self.norm3(x), cross_c, cross_c, need_weights=False)[0]
x = x + self.mlp2(self.norm4(x))

return x
Expand Down Expand Up @@ -93,6 +100,7 @@ def __init__(
heads=6,
dropout=0.1,
mlp_ratio=4.0,
use_activation_checkpointing=True,
):
super().__init__()

Expand All @@ -117,6 +125,7 @@ def __init__(
[DiTBlock(hidden_dim, heads, dropout, mlp_ratio) for i in range(depth)]
)
self.final_layer = FinalLayer(hidden_dim, output_dim)
self.use_activation_checkpointing = use_activation_checkpointing

def forward(self, x, t, cross_c, neighbor_current_mask):
"""
Expand Down Expand Up @@ -150,7 +159,10 @@ def forward(self, x, t, cross_c, neighbor_current_mask):
attn_mask[:, 1:] = neighbor_current_mask

for block in self.blocks:
x = block(x, cross_c, t, attn_mask)
if self.training and self.use_activation_checkpointing:
x = checkpoint(block, x, cross_c, t, attn_mask, use_reentrant=False)
else:
x = block(x, cross_c, t, attn_mask)

x = self.final_layer(x, t) # (B, P, output_dim)
x = x.reshape(B, P, T, D)
Expand Down
11 changes: 9 additions & 2 deletions diffusion_planner/diffusion_planner/model/module/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.nn.functional as F
from timm.layers import DropPath
from timm.models.layers import Mlp
from torch.utils.checkpoint import checkpoint

from diffusion_planner.dimensions import *
from diffusion_planner.model.module.mixer import MixerBlock
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(self, config):
self.use_ego_history = config.use_ego_history
self.ego_history_dropout_rate = config.ego_history_dropout_rate
self.use_turn_indicators = config.use_turn_indicators
self.use_activation_checkpointing = getattr(config, "use_activation_checkpointing", True)

ego_num = 1
goal_pose_num = 1
Expand Down Expand Up @@ -315,7 +317,9 @@ def __init__(self, dim, heads, dropout):
)

def forward(self, x, mask):
x = x + self.drop_path(self.attn(self.norm1(x), x, x, key_padding_mask=mask)[0])
x = x + self.drop_path(
self.attn(self.norm1(x), x, x, key_padding_mask=mask, need_weights=False)[0]
)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x

Expand Down Expand Up @@ -792,6 +796,9 @@ def forward(self, x, mask):
mask[:, 0] = False

for b in self.blocks:
x = b(x, mask)
if self.training and self.use_activation_checkpointing:
x = checkpoint(b, x, mask, use_reentrant=False)
else:
x = b(x, mask)

return self.norm(x)
13 changes: 12 additions & 1 deletion diffusion_planner/diffusion_planner/train_epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def heading_to_cos_sin(x):

def train_epoch(data_loader, model, optimizer, args, ema, aug: StatePerturbation = None):
epoch_loss = []
use_bf16 = bool(
args.use_bf16
and str(args.device).startswith("cuda")
and torch.cuda.is_available()
and torch.cuda.is_bf16_supported()
)

model.train()

Expand Down Expand Up @@ -59,7 +65,12 @@ def train_epoch(data_loader, model, optimizer, args, ema, aug: StatePerturbation
# call the model
optimizer.zero_grad()

loss = compute_training_loss(model, inputs, (ego_future, neighbors_future, mask), args)
with torch.autocast(
device_type="cuda",
dtype=torch.bfloat16,
enabled=use_bf16,
):
loss = compute_training_loss(model, inputs, (ego_future, neighbors_future, mask), args)

loss["loss"] = (
args.alpha_neighbor_loss * loss["neighbor_prediction_loss"]
Expand Down
14 changes: 13 additions & 1 deletion diffusion_planner/train_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,18 @@ def get_args():
parser.add_argument("--device", type=str, help="run on which device", default="cuda")

parser.add_argument("--use_ema", default=True, type=boolean)
parser.add_argument(
"--use_bf16",
default=True,
type=boolean,
help="Enable bfloat16 autocast during training when supported by the GPU",
)
parser.add_argument(
"--use_activation_checkpointing",
default=True,
type=boolean,
help="Recompute encoder/decoder activations during backward to reduce GPU memory",
)

# Model
parser.add_argument("--encoder_mixer_depth", type=int, default=6)
Expand Down Expand Up @@ -272,7 +284,7 @@ def model_training(args):
model_ema = ModelEma(
diffusion_planner,
decay=0.999,
device=args.device,
device="cpu",
)

if global_rank == 0:
Expand Down