diff --git a/miles_plugins/mbridge/xllm.py b/miles_plugins/mbridge/xllm.py index 8859e94013..7203fdcb99 100644 --- a/miles_plugins/mbridge/xllm.py +++ b/miles_plugins/mbridge/xllm.py @@ -23,6 +23,13 @@ def _has_moe(self): return (getattr(self.hf_config, "num_experts", 0) or 0) > 0 def _build_config(self): + head_dim = getattr( + self.hf_config, + "head_dim", + self.hf_config.hidden_size // self.hf_config.num_attention_heads, + ) + rope_head_dim = getattr(self.hf_config, "rope_head_dim", head_dim) + config_kwargs = dict( use_cpu_initialization=False, persist_layer_norm=True, @@ -31,6 +38,8 @@ def _build_config(self): qk_layernorm=False, add_qkv_bias=False, add_bias_linear=False, + rotary_percent=rope_head_dim / head_dim, + xllm_partial_rope_layout=rope_head_dim * 2 == head_dim, ) if self._has_moe(): diff --git a/scripts/models/xllm-375B.sh b/scripts/models/xllm-375B.sh index 652b988022..049f5070e4 100644 --- a/scripts/models/xllm-375B.sh +++ b/scripts/models/xllm-375B.sh @@ -16,6 +16,7 @@ MODEL_ARGS=( --normalization RMSNorm --position-embedding-type rope --rotary-percent 0.5 + --xllm-partial-rope-layout --rotary-base 500000 --swiglu --untie-embeddings-and-output-weights