
feat: add max_tokens_per_microbatch config for token-based micro-batching#1477

Open
erictang000 wants to merge 1 commit into NovaSky-AI:main from erictang000:max_tokens_per_microbatch

Conversation


erictang000 (Collaborator) commented Apr 8, 2026

Adds a new max_tokens_per_microbatch config option that enables token-based micro-batching instead of fixed sample-count batching. When enabled, samples are bin-packed into microbatches based on their actual token counts (from attention_mask), which reduces padding waste when sequence lengths vary widely.

Key changes:

  • Add max_tokens_per_microbatch config field (default -1 = disabled)
  • Add TokenBasedBatchIterator with balanced bin-packing algorithm
  • Add SampleBasedBatchIterator (refactored from old BatchIterator)
  • Update FSDP Worker.forward(), PolicyWorkerBase.forward_backward(), and CriticWorkerBase.forward_backward() to use token-based batching
  • Update Megatron forward and forward_backward paths with padding to uniform micro_batch_size (required by Megatron's pipeline schedule)
  • Add comprehensive tests for both FSDP and Megatron backends
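The bin-packing step described above can be sketched as a first-fit-decreasing greedy heuristic. The function name and exact strategy here are illustrative assumptions, not the PR's actual `TokenBasedBatchIterator` implementation:

```python
from typing import List

def pack_microbatches(token_counts: List[int], max_tokens: int) -> List[List[int]]:
    """Greedy first-fit-decreasing: group sample indices into microbatches
    so each microbatch's total token count stays under max_tokens."""
    # Place longest samples first; short ones then fill the gaps.
    order = sorted(range(len(token_counts)), key=lambda i: token_counts[i], reverse=True)
    bins: List[List[int]] = []   # sample indices per microbatch
    loads: List[int] = []        # running token total per microbatch
    for i in order:
        # Prefer the least-loaded bin that still has room, to keep bins balanced.
        fits = [b for b in range(len(bins)) if loads[b] + token_counts[i] <= max_tokens]
        if fits:
            best = min(fits, key=lambda b: loads[b])
            bins[best].append(i)
            loads[best] += token_counts[i]
        else:
            # No existing microbatch has room: open a new one.
            bins.append([i])
            loads.append(token_counts[i])
    return bins
```

Note that a single sample longer than `max_tokens` still gets its own microbatch in this sketch; a real implementation would also need a policy for that case (error out, truncate, or allow oversize bins).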


feat: add max_tokens_per_microbatch config for token-based micro-batching

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

gemini-code-assist bot left a comment


Code Review

This pull request introduces token-based micro-batching to the training pipeline, allowing micro-batches to be formed based on token counts rather than fixed sample sizes. This is achieved through new iterator classes and a bin-packing algorithm to balance token distribution across workers. The changes include updates to both the standard and Megatron workers to handle variable-sized micro-batches via padding and output reordering. Feedback was provided regarding a redundant check when calculating the maximum micro-batch size in the Megatron worker.


```python
if use_token_batching:
    # Pad microbatches to uniform batch size for Megatron compatibility
    max_micro_bsz = max(m["sequences"].shape[0] for m in micro_dicts) if micro_dicts else 1
```

Severity: medium

The expression `max(m["sequences"].shape[0] for m in micro_dicts) if micro_dicts else 1` is redundant if `micro_dicts` is guaranteed to be non-empty at this point. If it can be empty, handle that case explicitly rather than silently falling back to a micro-batch size of 1 when no data is present.
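One way to act on this suggestion is to fail fast on an empty list instead of defaulting. A minimal sketch, with a hypothetical helper name (`max_micro_batch_size` is not in the PR):

```python
def max_micro_batch_size(micro_dicts):
    """Hypothetical helper: size the uniform micro-batch, failing fast
    instead of silently returning 1 when there is no data."""
    if not micro_dicts:
        raise ValueError("expected at least one microbatch when token batching is enabled")
    return max(m["sequences"].shape[0] for m in micro_dicts)
```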


devin-ai-integration bot left a comment


Devin Review found 2 potential issues.

View 7 additional findings in Devin Review.


Comment on lines +288 to +308
```python
"""Create a padding microbatch with loss_mask=0 so it doesn't affect the loss."""
seq_len = 2
num_actions = self.data.metadata["response_length"]
batch_size = 1

data = TrainingInputBatch(
    {
        "sequences": torch.randint(0, 100, (batch_size, seq_len), device="cpu"),
        "attention_mask": torch.ones((batch_size, seq_len), dtype=int, device="cpu"),
        "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"),
        "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"),
        "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
        "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
        "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"),
        # Loss mask is all zeros so padding samples don't contribute to the loss.
        "loss_mask": torch.zeros((batch_size, num_actions), dtype=int, device="cpu"),
        "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"),
    }
)
data.metadata = self.data.metadata
return data
```

🔴 Padding microbatches have seq_len=2, causing model crash when num_actions > 1

The _create_padding_microbatch method creates a dummy microbatch with seq_len = 2 but inherits the real response_length (num_actions) from self.data.metadata. When the model processes this padding microbatch, it produces logits of shape [1, 2, V] and tries to extract num_actions action log probs (e.g., token_logprobs[:, -4:] for num_actions=4 on a length-2 sequence). This produces a tensor of shape [1, min(1, num_actions)] instead of the expected [1, num_actions], causing shape mismatches in subsequent loss computations (e.g., action_log_probs - old_action_log_probs where shapes don't match).

This triggers whenever different DP workers end up with different numbers of microbatches after bin-packing (common with variable-length sequences), causing _num_padding_microbatches > 0 on some workers. Affected paths: Worker.forward(), PolicyWorkerBase.forward_backward(), CriticWorkerBase.forward_backward(), and both Megatron forward/forward_backward paths (where mixing seq_len=2 padding dicts with seq_len=N real dicts also breaks Megatron's uniform seq_length requirement in forward_backward_func).

Prompt for agents
The _create_padding_microbatch method at worker_utils.py:288-308 creates a padding microbatch with seq_len=2, which is incompatible with the model's expectations. The model needs sequence_length >= num_actions + 1 to extract the correct number of action log probabilities.

The fix should set seq_len in the padding microbatch to match the real data's sequence length (self.data['sequences'].shape[1]) instead of hardcoding 2. This ensures:
1. FSDP models can extract the correct number of action log probs
2. Megatron's forward_backward_func sees uniform seq_length across all microbatches
3. postprocess_packed_seqs / recover_left_padding don't fail on shape mismatches

Specifically, change line 289 from `seq_len = 2` to `seq_len = self.data['sequences'].shape[1]`. Also ensure the sequences tensor is appropriately sized (zeros or random tokens of the correct length), and that attention_mask, position_ids etc. have the correct seq_len dimension.

Also check _pad_forward_microbatch_to_size and _pad_microbatch_to_size in megatron_worker.py - they pad the batch dimension but assume all microbatches already share the same seq_len, which breaks when padding microbatches have seq_len=2.
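A minimal sketch of the suggested fix, deriving the padding microbatch's sequence length from the real batch instead of hardcoding 2. The helper name and the reduced set of fields are illustrative, not the PR's actual `_create_padding_microbatch`:

```python
import torch

def make_padding_microbatch(real_sequences: torch.Tensor, num_actions: int) -> dict:
    """Hypothetical sketch: build a dummy microbatch whose seq_len matches
    the real data, so downstream log-prob extraction sees consistent shapes."""
    batch_size = 1
    seq_len = real_sequences.shape[1]  # match the real batch, not a hardcoded 2
    return {
        "sequences": torch.zeros((batch_size, seq_len), dtype=torch.long),
        "attention_mask": torch.ones((batch_size, seq_len), dtype=torch.long),
        # All-zero loss_mask keeps the padding sample out of the loss.
        "loss_mask": torch.zeros((batch_size, num_actions), dtype=torch.long),
        "response_mask": torch.ones((batch_size, num_actions), dtype=torch.long),
    }
```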


Comment on lines +274 to +285
```python
def _create_microbatch_from_indices(self, indices: List[int]) -> TrainingInputBatch:
    """Create a TrainingInputBatch from a list of sample indices."""
    indices_tensor = torch.tensor(indices, dtype=torch.long, device="cpu")
    selected_data = {}
    for key, value in self.data.items():
        if value is None:
            selected_data[key] = None
        else:
            selected_data[key] = value[indices_tensor]
    microbatch = TrainingInputBatch(selected_data)
    microbatch.metadata = self.data.metadata
    return microbatch
```

🟡 _create_microbatch_from_indices fails with TensorList values (multi-modal data)

_create_microbatch_from_indices uses value[indices_tensor] where indices_tensor is a torch.Tensor. For regular torch.Tensor values this works (fancy indexing), but TrainingInputBatch can also contain TensorList values (used for pixel_values and image_grid_thw in multi-modal training). TensorList.__getitem__ at training_batch.py:71-74 only handles slice and int index types — a torch.Tensor index falls through to self.tensors[index] which raises TypeError because Python lists don't support tensor indexing. This means token-based batching (max_tokens_per_microbatch > 0) is broken for any multi-modal (vision-language) training run.

Suggested change (original implementation, then the proposed replacement):

```python
def _create_microbatch_from_indices(self, indices: List[int]) -> TrainingInputBatch:
    """Create a TrainingInputBatch from a list of sample indices."""
    indices_tensor = torch.tensor(indices, dtype=torch.long, device="cpu")
    selected_data = {}
    for key, value in self.data.items():
        if value is None:
            selected_data[key] = None
        else:
            selected_data[key] = value[indices_tensor]
    microbatch = TrainingInputBatch(selected_data)
    microbatch.metadata = self.data.metadata
    return microbatch
```

```python
def _create_microbatch_from_indices(self, indices: List[int]) -> TrainingInputBatch:
    """Create a TrainingInputBatch from a list of sample indices."""
    selected_data = {}
    for key, value in self.data.items():
        if value is None:
            selected_data[key] = None
        elif isinstance(value, TensorList):
            # TensorList only supports int/slice indexing, so select per index.
            selected_data[key] = TensorList([value[i] for i in indices])
        else:
            indices_tensor = torch.tensor(indices, dtype=torch.long, device="cpu")
            selected_data[key] = value[indices_tensor]
    microbatch = TrainingInputBatch(selected_data)
    microbatch.metadata = self.data.metadata
    return microbatch
```

