diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py
index ddacaafd..edd1e2d0 100644
--- a/src/art/unsloth/train.py
+++ b/src/art/unsloth/train.py
@@ -24,6 +24,13 @@ async def train(
     trainer: "GRPOTrainer",
     results_queue: asyncio.Queue[dict[str, float]],
 ) -> None:
+    # Disable xformers to force SDPA path for custom attention bias support
+    # xformers with GQA (5D tensors) doesn't support custom bias during gradient checkpointing
+    import unsloth.models.mistral as mistral_module
+
+    _has_xformers = getattr(mistral_module, "HAS_XFORMERS", False)
+    mistral_module.HAS_XFORMERS = False
+
     _compute_loss = trainer.compute_loss
     _log = trainer.log
     trainer.compute_loss = get_compute_loss_fn(trainer)
@@ -41,6 +48,7 @@ async def train(
     finally:
         trainer.compute_loss = _compute_loss
         trainer.log = _log
+        mistral_module.HAS_XFORMERS = _has_xformers
 
 
 def get_compute_loss_fn(trainer: "GRPOTrainer") -> Callable[..., torch.Tensor]:
@@ -97,7 +105,16 @@ def compute_loss(
     dtype_for_autocasting = torch.bfloat16
 
     batch_size, seq_len = inputs["tokens"].size()
-    attn_bias = calculate_attn_bias(
+    # Get attention head counts from model config for xformers format
+    # Training mode (requires_grad=True): 4D [B, n_heads, S, S]
+    # Inference mode (requires_grad=False): 5D [B, n_kv_heads, n_groups, S, S]
+    model_config = trainer.model.config  # type: ignore[union-attr]
+    num_attention_heads = int(model_config.num_attention_heads)  # type: ignore[union-attr]
+    num_key_value_heads = int(
+        getattr(model_config, "num_key_value_heads", num_attention_heads)
+    )
+    # Create base 3D mask, will be expanded in calculate_logprobs
+    attn_bias_3d = calculate_attn_bias(
         batch_size,
         seq_len,
         trainer.accelerator.device,
@@ -127,7 +144,7 @@ def compute_loss(
         dtype_for_autocasting,
         trainer,
         inputs["tokens"],
-        attn_bias,
+        attn_bias_3d,
         forward_kwargs,
         next_input_ids,
         lm_head_t,
@@ -135,6 +152,8 @@ def compute_loss(
         inference_mode=return_new_logprobs,
         no_grad=return_new_logprobs,
         reference_logprobs=False,
+        num_attention_heads=num_attention_heads,
+        num_key_value_heads=num_key_value_heads,
     )
     if return_new_logprobs:
         return torch.nn.functional.pad(new_logprobs[:, :-1], (1, 0), value=0.0)
@@ -143,7 +162,7 @@ def compute_loss(
             dtype_for_autocasting,
             trainer,
             inputs["tokens"],
-            attn_bias,
+            attn_bias_3d,
             forward_kwargs,
             next_input_ids,
             lm_head_t,
@@ -151,10 +170,12 @@ def compute_loss(
             inference_mode=True,
             no_grad=False,
             reference_logprobs=True,
+            num_attention_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
         )
     else:
         ref_logprobs = None
-    del attn_bias
+    del attn_bias_3d
 
     loss = loss_fn(
         inputs,
@@ -204,6 +225,11 @@ def calculate_attn_bias(
     parent_ids: torch.Tensor,
     dtype: torch.dtype,
 ) -> torch.Tensor:
+    """Calculate base 3D attention bias [B, S, S] from group/parent IDs.
+
+    The bias is expanded to the appropriate format (4D or 5D) in calculate_logprobs
+    based on whether we're in inference mode or training mode.
+    """
     mask = calculate_mask(batch_size, seq_len, device, group_ids, parent_ids)
     # Use the same dtype as autocast to save memory and avoid dtype conversions
     attn_bias = torch.where(
@@ -260,9 +286,21 @@ def calculate_logprobs(
     inference_mode: bool,
     no_grad: bool,
     reference_logprobs: bool,
+    num_attention_heads: int,
+    num_key_value_heads: int,
 ) -> tuple[
     torch.Tensor, torch.Tensor
 ]:  # Returns (log_probs, entropy) both shape [B, S]
+    # Expand 3D causal_mask [B, S, S] to 4D [B, n_heads, S, S] for SDPA
+    # We disable xformers in the train() function to force the SDPA path
+    # because xformers with GQA doesn't support custom bias during gradient checkpointing
+    batch_size, seq_len, _ = causal_mask.shape
+    expanded_mask = (
+        causal_mask.unsqueeze(1)
+        .expand(batch_size, num_attention_heads, seq_len, seq_len)
+        .contiguous()
+    )
+
     with (
         torch.inference_mode() if inference_mode else nullcontext(),
         torch.no_grad() if no_grad else nullcontext(),
@@ -276,8 +314,9 @@ def calculate_logprobs(
         torch.amp.autocast_mode.autocast(device_type="cuda", dtype=dtype_for_autocast),
     ):
         hidden_states = trainer.model(  # type: ignore
-            input_ids=input_ids, causal_mask=causal_mask, **forward_kwargs
+            input_ids=input_ids, causal_mask=expanded_mask, **forward_kwargs
         ).logits  # Shape [B, S, H]
+    del expanded_mask
     return _calculate_logprobs(lm_head_t, hidden_states, next_input_ids, chunk_size)
 