From 13f907b89aad66f30a8a664a7b92d0c24e65713b Mon Sep 17 00:00:00 2001
From: deepak-pradhan
Date: Tue, 16 Dec 2025 20:20:19 -0500
Subject: [PATCH 1/2] fix: disable xformers for custom attention bias in GRPO
 training
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

xformers with GQA (Grouped Query Attention) switches to 5D tensor format
during gradient checkpointing when requires_grad=False. The cutlass backend
doesn't support custom attention bias with 5D tensors, causing:

"RuntimeError: Bias expected in BMHK format"

Solution: Temporarily disable xformers during training to force the SDPA
path, which always uses 4D tensors and properly supports custom attention
bias for trajectory group/parent masking.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5
---
 src/art/unsloth/train.py | 48 +++++++++++++++++++++++++++++++++++-----
 1 file changed, 43 insertions(+), 5 deletions(-)

diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py
index ddacaafd..f48fe2b7 100644
--- a/src/art/unsloth/train.py
+++ b/src/art/unsloth/train.py
@@ -24,6 +24,13 @@ async def train(
     trainer: "GRPOTrainer",
     results_queue: asyncio.Queue[dict[str, float]],
 ) -> None:
+    # Disable xformers to force SDPA path for custom attention bias support
+    # xformers with GQA (5D tensors) doesn't support custom bias during gradient checkpointing
+    import unsloth.models.mistral as mistral_module
+
+    _has_xformers = getattr(mistral_module, "HAS_XFORMERS", False)
+    mistral_module.HAS_XFORMERS = False
+
     _compute_loss = trainer.compute_loss
     _log = trainer.log
     trainer.compute_loss = get_compute_loss_fn(trainer)
@@ -41,6 +48,7 @@ async def train(
     finally:
         trainer.compute_loss = _compute_loss
         trainer.log = _log
+        mistral_module.HAS_XFORMERS = _has_xformers
 
 
 def get_compute_loss_fn(trainer: "GRPOTrainer") -> Callable[..., torch.Tensor]:
@@ -97,7 +105,15 @@ def compute_loss(
         dtype_for_autocasting = torch.bfloat16
 
         batch_size, seq_len = inputs["tokens"].size()
-        attn_bias = calculate_attn_bias(
+        # Get attention head counts from model config for xformers format
+        # Training mode (requires_grad=True): 4D [B, n_heads, S, S]
+        # Inference mode (requires_grad=False): 5D [B, n_kv_heads, n_groups, S, S]
+        num_attention_heads = trainer.model.config.num_attention_heads
+        num_key_value_heads = getattr(
+            trainer.model.config, "num_key_value_heads", num_attention_heads
+        )
+        # Create base 3D mask, will be expanded in calculate_logprobs
+        attn_bias_3d = calculate_attn_bias(
             batch_size,
             seq_len,
             trainer.accelerator.device,
@@ -127,7 +143,7 @@ def compute_loss(
             dtype_for_autocasting,
             trainer,
             inputs["tokens"],
-            attn_bias,
+            attn_bias_3d,
             forward_kwargs,
             next_input_ids,
             lm_head_t,
@@ -135,6 +151,8 @@ def compute_loss(
             inference_mode=return_new_logprobs,
             no_grad=return_new_logprobs,
             reference_logprobs=False,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
         )
         if return_new_logprobs:
             return torch.nn.functional.pad(new_logprobs[:, :-1], (1, 0), value=0.0)
@@ -143,7 +161,7 @@ def compute_loss(
                 dtype_for_autocasting,
                 trainer,
                 inputs["tokens"],
-                attn_bias,
+                attn_bias_3d,
                 forward_kwargs,
                 next_input_ids,
                 lm_head_t,
@@ -151,10 +169,12 @@ def compute_loss(
                 inference_mode=True,
                 no_grad=False,
                 reference_logprobs=True,
+                num_attention_heads=num_attention_heads,
+                num_key_value_heads=num_key_value_heads,
             )
         else:
             ref_logprobs = None
-        del attn_bias
+        del attn_bias_3d
 
         loss = loss_fn(
             inputs,
@@ -204,6 +224,11 @@ def calculate_attn_bias(
     parent_ids: torch.Tensor,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    """Calculate base 3D attention bias [B, S, S] from group/parent IDs.
+
+    The bias is expanded to the appropriate format (4D or 5D) in calculate_logprobs
+    based on whether we're in inference mode or training mode.
+    """
     mask = calculate_mask(batch_size, seq_len, device, group_ids, parent_ids)
     # Use the same dtype as autocast to save memory and avoid dtype conversions
     attn_bias = torch.where(
@@ -260,9 +285,21 @@ def calculate_logprobs(
     inference_mode: bool,
     no_grad: bool,
     reference_logprobs: bool,
+    num_attention_heads: int,
+    num_key_value_heads: int,
 ) -> tuple[
     torch.Tensor, torch.Tensor
 ]:  # Returns (log_probs, entropy) both shape [B, S]
+    # Expand 3D causal_mask [B, S, S] to 4D [B, n_heads, S, S] for SDPA
+    # We disable xformers in the train() function to force the SDPA path
+    # because xformers with GQA doesn't support custom bias during gradient checkpointing
+    batch_size, seq_len, _ = causal_mask.shape
+    expanded_mask = (
+        causal_mask.unsqueeze(1)
+        .expand(batch_size, num_attention_heads, seq_len, seq_len)
+        .contiguous()
+    )
+
     with (
         torch.inference_mode() if inference_mode else nullcontext(),
         torch.no_grad() if no_grad else nullcontext(),
@@ -276,8 +313,9 @@ def calculate_logprobs(
         torch.amp.autocast_mode.autocast(device_type="cuda", dtype=dtype_for_autocast),
     ):
         hidden_states = trainer.model(  # type: ignore
-            input_ids=input_ids, causal_mask=causal_mask, **forward_kwargs
+            input_ids=input_ids, causal_mask=expanded_mask, **forward_kwargs
         ).logits  # Shape [B, S, H]
+    del expanded_mask
 
     return _calculate_logprobs(lm_head_t, hidden_states, next_input_ids, chunk_size)


From 364016212751bd5bb04f86b83fc7e88f1953b897 Mon Sep 17 00:00:00 2001
From: deepak-pradhan
Date: Wed, 17 Dec 2025 18:23:29 -0500
Subject: [PATCH 2/2] fix: add type annotations to satisfy pyright
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add type: ignore comments and explicit int() casts for model config
attributes to pass pyright type checking in CI.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5
---
 src/art/unsloth/train.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py
index f48fe2b7..edd1e2d0 100644
--- a/src/art/unsloth/train.py
+++ b/src/art/unsloth/train.py
@@ -108,9 +108,10 @@ def compute_loss(
         # Get attention head counts from model config for xformers format
         # Training mode (requires_grad=True): 4D [B, n_heads, S, S]
         # Inference mode (requires_grad=False): 5D [B, n_kv_heads, n_groups, S, S]
-        num_attention_heads = trainer.model.config.num_attention_heads
-        num_key_value_heads = getattr(
-            trainer.model.config, "num_key_value_heads", num_attention_heads
+        model_config = trainer.model.config  # type: ignore[union-attr]
+        num_attention_heads = int(model_config.num_attention_heads)  # type: ignore[union-attr]
+        num_key_value_heads = int(
+            getattr(model_config, "num_key_value_heads", num_attention_heads)
        )
         # Create base 3D mask, will be expanded in calculate_logprobs
         attn_bias_3d = calculate_attn_bias(
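The sketch below (not from `train.py`; shapes and variable names are illustrative assumptions) shows why forcing the SDPA path makes the custom trajectory mask work: `torch.nn.functional.scaled_dot_product_attention` accepts an additive float mask broadcastable to `[B, n_heads, S, S]`, so a base 3D bias `[B, S, S]` only needs to be expanded across heads, which is the expansion `calculate_logprobs` performs in patch 1.

```python
# Standalone sketch, assuming toy shapes; not code from the patches.
import torch
import torch.nn.functional as F

B, n_heads, S, head_dim = 2, 8, 16, 64
q = torch.randn(B, n_heads, S, head_dim)
k = torch.randn(B, n_heads, S, head_dim)
v = torch.randn(B, n_heads, S, head_dim)

# Additive bias: 0.0 where attention is allowed, -inf where it is masked,
# analogous to what calculate_attn_bias builds from group/parent IDs.
allowed = torch.ones(S, S).tril().bool().expand(B, S, S)
bias_3d = torch.zeros(B, S, S).masked_fill(~allowed, float("-inf"))

# Expand [B, S, S] -> [B, n_heads, S, S]; every head sees the same mask.
bias_4d = bias_3d.unsqueeze(1).expand(B, n_heads, S, S)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias_4d)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```

Under the xformers GQA path, the model instead hands 5D `[B, n_kv_heads, n_groups, S, S]` tensors to the cutlass kernels, which only accept a bias in BMHK format, hence the "RuntimeError: Bias expected in BMHK format" and the decision in patch 1 to fall back to SDPA during training.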