Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,11 @@ uv run torchrun --standalone --nproc_per_node=4 \
--llm qwen2.5-1.5b-instruct \
--encoder $ECG_ENCODER or $VISION_ENCODER \
--elm mlp_llava \
--distributed
--parallel_strategy ddp
```

`--parallel_strategy ddp` uses `DistributedDataParallel` (full model replica per rank). Swap in `--parallel_strategy fsdp` for FSDP2 per-parameter sharding (each rank holds a slice of the LLM's transformer blocks); use this when the LLM is too large to replicate per GPU.

For ECG Encoders, you will have to pretrain your own ECG Encoder using [ecg_nn](https://github.com/ELM-Research/ecg_nn). We plan to release pretrained encoders soon! To load in the pretrained encoder during ELM training run the following:

```bash
Expand Down Expand Up @@ -227,7 +229,7 @@ uv run torchrun --standalone --nproc_per_node=$NPROC \
--optimizer adamw \
--lr 5e-4 \
--encoder_ckpt $ENCODER_CHECKPOINT.pt \
--distributed
--parallel_strategy ddp
```

### SFT
Expand All @@ -246,7 +248,7 @@ uv run torchrun --standalone --nproc_per_node=$NPROC \
--optimizer adamw \
--lr 1e-4 \
--elm_ckpt $PRETRAIN_CKPT.pt \
--distributed
--parallel_strategy ddp
```

### RL
Expand All @@ -270,7 +272,7 @@ uv run torchrun --standalone --nproc_per_node=$NPROC \
--rl_tau_pos 1.0 \
--rl_tau_neg 1.05 \
--elm_ckpt $SFT_CKPT.pt \
--distributed
--parallel_strategy ddp
```

See `scripts/st_mem_full_training.sh` for an end-to-end pretrain → SFT → RL example.
Expand Down
2 changes: 1 addition & 1 deletion scripts/st_mem_full_training.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ COMMON_FLAGS=(
--grad_clip 1.0
--llm_input_len 2048
--num_encoder_tokens 50 \
--distributed
--parallel_strategy ddp
--system_prompt "$SYSTEM_PROMPT"
--llm qwen2.5-3b-instruct
--gradient_checkpointing
Expand Down
8 changes: 4 additions & 4 deletions scripts/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ src/main_trainer.py \
--epochs 10 \
--grad_clip 1.0 \
--num_workers 16 \
--distributed \
--parallel_strategy ddp \
--peft \
--torch_compile \
--wandb
Expand All @@ -101,7 +101,7 @@ src/main_trainer.py \
--epochs 10 \
--grad_clip 1.0 \
--num_workers 16 \
--distributed \
--parallel_strategy ddp \
--peft \
--torch_compile \
--wandb
Expand All @@ -126,7 +126,7 @@ src/main_trainer.py \
--epochs 10 \
--grad_clip 1.0 \
--num_workers 16 \
--distributed \
--parallel_strategy ddp \
--peft \
--torch_compile \
--wandb
Expand All @@ -152,7 +152,7 @@ src/main_trainer.py \
--epochs 10 \
--grad_clip 1.0 \
--num_workers 16 \
--distributed \
--parallel_strategy ddp \
--peft \
--torch_compile \
--wandb
6 changes: 3 additions & 3 deletions scripts/train2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ src/main_trainer.py \
--epochs 10 \
--grad_clip 1.0 \
--num_workers 16 \
--distributed \
--parallel_strategy ddp \
--peft \
--torch_compile \
--wandb
Expand All @@ -100,7 +100,7 @@ src/main_trainer.py \
--epochs 10 \
--grad_clip 1.0 \
--num_workers 16 \
--distributed \
--parallel_strategy ddp \
--peft \
--torch_compile \
--wandb
Expand All @@ -123,7 +123,7 @@ src/main_trainer.py \
--llm_input_len 1024 \
--epochs 10 \
--num_workers 16 \
--distributed \
--parallel_strategy ddp \
--peft \
--torch_compile \
--wandb
Expand Down
3 changes: 2 additions & 1 deletion src/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def get_args(mode: Mode) -> argparse.Namespace:
parser.add_argument("--num_workers", type=int, default=0, help="Please choose the num works for the dataloader")
parser.add_argument("--wandb", action="store_true", default=None, help="Enable logging")
parser.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
parser.add_argument("--distributed", action="store_true", default=None, help="Enable distributed training")
parser.add_argument("--parallel_strategy", type=str, default=None, choices=["ddp", "fsdp"],
help="Parallel strategy when launched via torchrun. ddp = DistributedDataParallel; fsdp = FSDP2 (per-parameter sharding). Omit for single-device.")
parser.add_argument("--torch_compile", action="store_true", default=None,
help="Torch compile the model (should really only be used during pretraining or large finetuning.)")
parser.add_argument("--gradient_checkpointing", action="store_true", default=False,
Expand Down
9 changes: 5 additions & 4 deletions src/dataloaders/build_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.utils.data import DataLoader
from collections.abc import Mapping, Sequence

from utils.gpu_manager import get_world_size, get_rank
from utils.parallel_context import get_parallel_context

from dataloaders.dataset_mixer import DatasetMixer

Expand Down Expand Up @@ -53,9 +53,10 @@ def get_torch_dataloader_sampler(
self,
torch_dataset,
):
if self.args.distributed:
sampler = DistributedSampler(torch_dataset, num_replicas=get_world_size(),
rank=get_rank(), seed=self.args.seed, shuffle=True)
if self.args.parallel_strategy:
ctx = get_parallel_context()
sampler = DistributedSampler(torch_dataset, num_replicas=ctx.dp_size,
rank=ctx.dp_rank, seed=self.args.seed, shuffle=True)
else:
sampler = None
return sampler
Expand Down
3 changes: 1 addition & 2 deletions src/elms/build_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def prepare_hf_siglip(self,):

def prepare_merl(self,):
from elms.ecg_encoders.merl.merl import MerlConfig, Merl
cfg = MerlConfig(distributed=self.args.distributed,
num_encoder_tokens=self.args.num_encoder_tokens)
cfg = MerlConfig(num_encoder_tokens=self.args.num_encoder_tokens)
model = Merl(cfg)
return {"encoder": model}

Expand Down
1 change: 0 additions & 1 deletion src/elms/ecg_encoders/merl/merl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class MerlConfig:
seq_len: int = 2500
lm: str = "ncbi/MedCPT-Query-Encoder"
resnet_type: str = "resnet101"
distributed: bool = False
spacial_dim: int = None
d_model: int = 2048
num_encoder_tokens: int = 1
Expand Down
2 changes: 1 addition & 1 deletion src/elms/ecg_encoders/st_mem/st_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def forward_encoder(self, x,):
return x

def get_encoder_embeddings(self, ecg_signal):
x_latents = self.forward_encoder(ecg_signal.to(torch.float32))
x_latents = self.forward_encoder(ecg_signal.to(next(self.parameters()).dtype))
out = rearrange(x_latents, 'b c n d -> b (c n) d')
out = out.transpose(1, 2)
out = self.avgpool(out)
Expand Down
3 changes: 3 additions & 0 deletions src/elms/llm_encoders/base_elf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def train(self, mode: bool = True):
module.train(mode if name in self.update else False)
return self

def fsdp_wrap_modules(self):
return self.llm.fsdp_wrap_modules()

def forward(self, elm_input_ids, encoder_tokenizer_out,
elm_attention_mask, elm_labels, signal_id_indices):
projected_embeds = self.get_projections(**encoder_tokenizer_out)
Expand Down
3 changes: 3 additions & 0 deletions src/elms/llm_encoders/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def train(self, mode: bool = True):
module.train(mode if name in self.update else False)
return self

def fsdp_wrap_modules(self):
return self.llm.fsdp_wrap_modules()

def forward(self, elm_input_ids, elm_attention_mask, elm_labels, signal_id_indices,
encoder_tokenizer_out):
projected_embeds = self.get_projections(encoder_tokenizer_out)
Expand Down
19 changes: 19 additions & 0 deletions src/elms/llms/_wrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Shared helpers for FSDP wrapping of LLM wrapper modules."""

from torch import nn


def get_decoder_layers(hf_model: nn.Module) -> list[nn.Module]:
"""Return the transformer decoder block ModuleList inside an HF causal LM.

Handles PEFT (PeftModel -> base_model -> model -> model.layers) and the
bare HF (model.model.layers) layouts.
"""
node = hf_model
if hasattr(node, "base_model") and hasattr(node.base_model, "model"):
node = node.base_model.model
if hasattr(node, "model") and hasattr(node.model, "layers"):
return list(node.model.layers)
if hasattr(node, "layers"):
return list(node.layers)
raise RuntimeError(f"Could not locate decoder layers inside {type(hf_model).__name__}")
7 changes: 6 additions & 1 deletion src/elms/llms/gemma2/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@ def forward(self, elm_input_ids, elm_attention_mask,
inputs_embeds = elm_inputs_embeds,
attention_mask = elm_attention_mask,
labels = elm_labels,
output_hidden_states = self.output_hidden_states)
output_hidden_states = self.output_hidden_states,
use_cache = False)

def get_llm_embeddings(self, elm_input_ids):
out = self.llm.get_input_embeddings()(elm_input_ids.to(self.llm.device))
return out

def fsdp_wrap_modules(self):
from elms.llms._wrap import get_decoder_layers
return get_decoder_layers(self.llm)

def generate(self, elm_input_ids, elm_attention_mask,
elm_inputs_embeds= None, max_new_tokens=128, **gen_kwargs):
return self.llm.generate(
Expand Down
7 changes: 6 additions & 1 deletion src/elms/llms/llama3/llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@ def forward(self, elm_input_ids, elm_attention_mask,
inputs_embeds = elm_inputs_embeds,
attention_mask = elm_attention_mask,
labels = elm_labels,
output_hidden_states = self.output_hidden_states)
output_hidden_states = self.output_hidden_states,
use_cache = False)

def get_llm_embeddings(self, elm_input_ids):
out = self.llm.get_input_embeddings()(elm_input_ids.to(self.llm.device))
return out

def fsdp_wrap_modules(self):
from elms.llms._wrap import get_decoder_layers
return get_decoder_layers(self.llm)

def generate(self, elm_input_ids, elm_attention_mask,
elm_inputs_embeds= None, max_new_tokens=128, **gen_kwargs):
return self.llm.generate(
Expand Down
7 changes: 6 additions & 1 deletion src/elms/llms/qwen25/qwen25.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@ def forward(self, elm_input_ids, elm_attention_mask,
inputs_embeds = elm_inputs_embeds,
attention_mask = elm_attention_mask,
labels = elm_labels,
output_hidden_states = self.output_hidden_states)
output_hidden_states = self.output_hidden_states,
use_cache = False)

def get_llm_embeddings(self, elm_input_ids):
out = self.llm.get_input_embeddings()(elm_input_ids.to(self.llm.device))
return out

def fsdp_wrap_modules(self):
from elms.llms._wrap import get_decoder_layers
return get_decoder_layers(self.llm)

def generate(self, elm_input_ids, elm_attention_mask,
elm_inputs_embeds= None, max_new_tokens=128, **gen_kwargs):
return self.llm.generate(
Expand Down
15 changes: 10 additions & 5 deletions src/main_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def main():
args.mode = mode
args.task = "train"

if args.distributed:
init_dist()
if args.parallel_strategy:
init_dist(args.parallel_strategy)

gc.collect()
torch.cuda.empty_cache()
Expand Down Expand Up @@ -70,12 +70,17 @@ def main():
for epoch in range(start_epoch, args.epochs):
train_result = runner(elm, optimizer, dataloader, epoch, args, checkpoint_manager)
should_stop = False
should_save = False
if checkpoint_manager and is_main():
if checkpoint_manager.save_epoch(train_result["average_loss"]):
checkpoint_manager.save_checkpoint(elm, optimizer, epoch, -1, is_best=True, prefix="epoch_")
should_save = checkpoint_manager.save_epoch(train_result["average_loss"])
if args.early_stopping and checkpoint_manager.stop_early():
print(f"Early stopping at epoch {epoch}")
should_stop = True
# Decision is rank-0-only (best_loss tracking lives there); save
# itself must be collective for FSDP get_model_state_dict.
should_save = broadcast_value(should_save, src=0)
if checkpoint_manager and should_save:
checkpoint_manager.save_checkpoint(elm, optimizer, epoch, -1, is_best=True, prefix="epoch_")
should_stop = broadcast_value(should_stop, src=0)
if should_stop:
break
Expand All @@ -84,7 +89,7 @@ def main():
with open(f"{run_folder}/DONE.txt", "w") as _:
pass
finally:
if args.distributed:
if args.parallel_strategy:
cleanup()
if is_main() and args.wandb:
cleanup_wandb()
Expand Down
Loading