Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/lmms_engine/models/qwen3/qwen3_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,15 @@ def model_forward(
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
# New transformers uses config.layer_types; fall back to decoder_layer.attention_type for older versions
if hasattr(self.config, "layer_types"):
attention_type = self.config.layer_types[i]
else:
attention_type = decoder_layer.attention_type
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[attention_type],
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
Expand Down
16 changes: 14 additions & 2 deletions test/train/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import os
import shutil
import socket
import subprocess
import tempfile
import time
from functools import wraps


def find_free_port():
    """Return an OS-assigned free TCP port on localhost, as a string.

    Binding to port 0 tells the kernel to pick any currently unused port;
    the chosen port number is then read back from the bound socket. The
    port is returned as a string because it is interpolated directly into
    a torchrun command line by callers.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("", 0))
        _, port = probe.getsockname()
        return str(port)
    finally:
        probe.close()


# Copied from https://github.com/axolotl-ai-cloud/axolotl/blob/main/tests/e2e/utils.py
def with_temp_dir(test_func):
@wraps(test_func)
Expand Down Expand Up @@ -39,7 +47,7 @@ def launch_torchrun_training(
nnodes=1,
node_rank=0,
master_addr="127.0.0.1",
master_port="8000",
master_port=None,
timeout=300,
):
"""
Expand All @@ -52,7 +60,7 @@ def launch_torchrun_training(
nnodes: Number of nodes
node_rank: Rank of this node
master_addr: Master address
master_port: Master port
master_port: Master port (auto-selects a free port if not specified)
timeout: Timeout in seconds for the training process

Returns:
Expand All @@ -63,6 +71,10 @@ def launch_torchrun_training(
if nproc_per_node == 0:
nproc_per_node = 1

if master_port is None:
master_port = find_free_port()
print(f"Auto-selected free port: {master_port}")

# Build torchrun command
cmd = [
"torchrun",
Expand Down
Loading
Loading