src/art/unsloth/train.py: 49 changes (44 additions, 5 deletions)
@@ -24,6 +24,13 @@ async def train(
trainer: "GRPOTrainer",
results_queue: asyncio.Queue[dict[str, float]],
) -> None:
# Disable xformers to force SDPA path for custom attention bias support
# xformers with GQA (5D tensors) doesn't support custom bias during gradient checkpointing
import unsloth.models.mistral as mistral_module

_has_xformers = getattr(mistral_module, "HAS_XFORMERS", False)
mistral_module.HAS_XFORMERS = False

_compute_loss = trainer.compute_loss
_log = trainer.log
trainer.compute_loss = get_compute_loss_fn(trainer)
@@ -41,6 +48,7 @@ async def train(
finally:
trainer.compute_loss = _compute_loss
trainer.log = _log
mistral_module.HAS_XFORMERS = _has_xformers


def get_compute_loss_fn(trainer: "GRPOTrainer") -> Callable[..., torch.Tensor]:
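For readers skimming the diff, the two hunks above amount to the following toggle-and-restore pattern around the training loop. This is a distilled sketch of what the added lines do, not additional behavior:

```python
import unsloth.models.mistral as mistral_module

# Remember the original flag so it can be restored even if training raises.
_has_xformers = getattr(mistral_module, "HAS_XFORMERS", False)
mistral_module.HAS_XFORMERS = False  # force the SDPA path so a custom attention bias is honored
try:
    ...  # run the training loop
finally:
    mistral_module.HAS_XFORMERS = _has_xformers
```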
@@ -97,7 +105,16 @@ def compute_loss(
dtype_for_autocasting = torch.bfloat16

batch_size, seq_len = inputs["tokens"].size()
attn_bias = calculate_attn_bias(
# Get attention head counts from the model config so the attention bias can be
# expanded to the expected layout later:
# Training mode (requires_grad=True): 4D [B, n_heads, S, S] (SDPA)
# Inference mode (requires_grad=False): 5D [B, n_kv_heads, n_groups, S, S] (xformers GQA)
model_config = trainer.model.config # type: ignore[union-attr]
num_attention_heads = int(model_config.num_attention_heads) # type: ignore[union-attr]
num_key_value_heads = int(
getattr(model_config, "num_key_value_heads", num_attention_heads)
)
# Create the base 3D mask; it will be expanded in calculate_logprobs
attn_bias_3d = calculate_attn_bias(
batch_size,
seq_len,
trainer.accelerator.device,
@@ -127,14 +144,16 @@ def compute_loss(
dtype_for_autocasting,
trainer,
inputs["tokens"],
attn_bias,
attn_bias_3d,
forward_kwargs,
next_input_ids,
lm_head_t,
chunk_size=chunk_size,
inference_mode=return_new_logprobs,
no_grad=return_new_logprobs,
reference_logprobs=False,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
)
if return_new_logprobs:
return torch.nn.functional.pad(new_logprobs[:, :-1], (1, 0), value=0.0)
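The early return above drops the final logprob and left-pads a zero, which appears to realign next-token logprobs with their token positions. A toy illustration (shapes made up for the example) of what `pad(x[:, :-1], (1, 0), value=0.0)` does to a [B, S] tensor:

```python
import torch

logprobs = torch.arange(6.0).reshape(2, 3)  # toy [B, S] = [2, 3] tensor
shifted = torch.nn.functional.pad(logprobs[:, :-1], (1, 0), value=0.0)
# The last column is dropped and a 0.0 is inserted at position 0 of each row,
# shifting every value one position to the right:
# tensor([[0., 0., 1.],
#         [0., 3., 4.]])
```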
@@ -143,18 +162,20 @@
dtype_for_autocasting,
trainer,
inputs["tokens"],
attn_bias,
attn_bias_3d,
forward_kwargs,
next_input_ids,
lm_head_t,
chunk_size=chunk_size,
inference_mode=True,
no_grad=False,
reference_logprobs=True,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
)
else:
ref_logprobs = None
del attn_bias
del attn_bias_3d

loss = loss_fn(
inputs,
@@ -204,6 +225,11 @@ def calculate_attn_bias(
parent_ids: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""Calculate base 3D attention bias [B, S, S] from group/parent IDs.

The bias is expanded to the appropriate format (4D or 5D) in calculate_logprobs
based on whether we're in inference mode or training mode.
"""
mask = calculate_mask(batch_size, seq_len, device, group_ids, parent_ids)
# Use the same dtype as autocast to save memory and avoid dtype conversions
attn_bias = torch.where(
@@ -260,9 +286,21 @@ def calculate_logprobs(
inference_mode: bool,
no_grad: bool,
reference_logprobs: bool,
num_attention_heads: int,
num_key_value_heads: int,
) -> tuple[
torch.Tensor, torch.Tensor
]: # Returns (log_probs, entropy) both shape [B, S]
# Expand 3D causal_mask [B, S, S] to 4D [B, n_heads, S, S] for SDPA
# We disable xformers in the train() function to force the SDPA path
# because xformers with GQA doesn't support custom bias during gradient checkpointing
batch_size, seq_len, _ = causal_mask.shape
expanded_mask = (
causal_mask.unsqueeze(1)
.expand(batch_size, num_attention_heads, seq_len, seq_len)
.contiguous()
)

with (
torch.inference_mode() if inference_mode else nullcontext(),
torch.no_grad() if no_grad else nullcontext(),
@@ -276,8 +314,9 @@
torch.amp.autocast_mode.autocast(device_type="cuda", dtype=dtype_for_autocast),
):
hidden_states = trainer.model( # type: ignore
input_ids=input_ids, causal_mask=causal_mask, **forward_kwargs
input_ids=input_ids, causal_mask=expanded_mask, **forward_kwargs
).logits # Shape [B, S, H]
del expanded_mask
return _calculate_logprobs(lm_head_t, hidden_states, next_input_ids, chunk_size)


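Putting the pieces together, here is a minimal, self-contained sketch of the bias path the comments describe: a boolean mask becomes an additive 3D bias via `torch.where`, is expanded to the 4D [B, n_heads, S, S] layout, and is consumed by SDPA. The mask contents, sizes, and dtype below are toy placeholders, not the actual `calculate_mask` logic; the 5D expansion at the end only illustrates the xformers GQA layout mentioned in the comments.

```python
import torch
import torch.nn.functional as F

B, S, D = 2, 6, 16
n_heads, n_kv_heads = 8, 2  # toy head counts; n_groups = n_heads // n_kv_heads

# Stand-in for calculate_mask's output: True where attention is allowed.
allowed = torch.ones(S, S).tril().bool().expand(B, S, S)

# Base 3D additive bias [B, S, S]: 0 where allowed, -inf elsewhere
# (float32 here; the real code uses the autocast dtype, e.g. bfloat16).
attn_bias_3d = torch.where(allowed, torch.tensor(0.0), torch.tensor(float("-inf")))

# SDPA path used during training: expand to 4D [B, n_heads, S, S].
expanded_mask = attn_bias_3d.unsqueeze(1).expand(B, n_heads, S, S).contiguous()

q = torch.randn(B, n_heads, S, D)
k = torch.randn(B, n_heads, S, D)
v = torch.randn(B, n_heads, S, D)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=expanded_mask)
print(out.shape)  # torch.Size([2, 8, 6, 16])

# The 5D xformers GQA layout [B, n_kv_heads, n_groups, S, S] would instead be:
bias_5d = (
    attn_bias_3d.unsqueeze(1)
    .unsqueeze(2)
    .expand(B, n_kv_heads, n_heads // n_kv_heads, S, S)
)
print(bias_5d.shape)  # torch.Size([2, 2, 4, 6, 6])
```

Since every query position can attend to at least itself, no row of the bias is entirely -inf, so the softmax inside SDPA stays finite.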