feat: add max_tokens_per_microbatch config for token-based micro-batching #1477
erictang000 wants to merge 1 commit into NovaSky-AI:main
Conversation
Adds a new `max_tokens_per_microbatch` config option that enables token-based micro-batching instead of fixed sample-count batching. When enabled, samples are bin-packed into microbatches based on their actual token counts (from `attention_mask`), which reduces padding waste when sequence lengths vary widely.

Key changes:
- Add `max_tokens_per_microbatch` config field (default -1 = disabled)
- Add `TokenBasedBatchIterator` with a balanced bin-packing algorithm
- Add `SampleBasedBatchIterator` (refactored from the old `BatchIterator`)
- Update FSDP `Worker.forward()`, `PolicyWorkerBase.forward_backward()`, and `CriticWorkerBase.forward_backward()` to use token-based batching
- Update Megatron forward and forward_backward paths with padding to a uniform micro_batch_size (required by Megatron's pipeline schedule)
- Add comprehensive tests for both FSDP and Megatron backends

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
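As a rough illustration of what token-based bin-packing can look like, here is a minimal sketch (hypothetical, not the PR's actual implementation): samples are sorted by token count and greedily placed into the bin with the most remaining capacity, opening a new bin when the token budget would be exceeded.

```python
# Hypothetical sketch of token-based micro-batching (first-fit-decreasing
# with a "most remaining room" heuristic). Not the PR's implementation.
def pack_by_tokens(token_counts, max_tokens):
    """Group sample indices into bins whose token sums stay <= max_tokens."""
    # Sort sample indices by descending token count.
    order = sorted(range(len(token_counts)), key=lambda i: -token_counts[i])
    bins = []  # each entry: [remaining_budget, list_of_indices]
    for i in order:
        need = token_counts[i]
        # Find the bin with the most remaining room that still fits the sample.
        best = None
        for b, (room, _) in enumerate(bins):
            if room >= need and (best is None or room > bins[best][0]):
                best = b
        if best is None:
            # No bin fits: open a new one. (A sample larger than max_tokens
            # still gets its own bin here.)
            bins.append([max_tokens - need, [i]])
        else:
            bins[best][0] -= need
            bins[best][1].append(i)
    return [indices for _, indices in bins]
```

The sorted-descending pass tends to balance token counts across bins, which is the property that reduces padding waste relative to fixed sample-count microbatches.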
Code Review
This pull request introduces token-based micro-batching to the training pipeline, allowing micro-batches to be formed based on token counts rather than fixed sample sizes. This is achieved through new iterator classes and a bin-packing algorithm to balance token distribution across workers. The changes include updates to both the standard and Megatron workers to handle variable-sized micro-batches via padding and output reordering. Feedback was provided regarding a redundant check when calculating the maximum micro-batch size in the Megatron worker.
```python
if use_token_batching:
    # Pad microbatches to uniform batch size for Megatron compatibility
    max_micro_bsz = max(m["sequences"].shape[0] for m in micro_dicts) if micro_dicts else 1
```
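The reviewer's feedback about a redundant check appears to concern the `else 1` fallback here: under `use_token_batching`, `micro_dicts` comes from the bin-packer and is never empty, so the fallback is dead code. A sketch of the simplified form (hypothetical helper name):

```python
import torch

# Hypothetical simplification: micro_dicts is guaranteed non-empty on this
# path, so max() can stand alone without an `else 1` fallback.
def max_micro_batch_size(micro_dicts):
    return max(m["sequences"].shape[0] for m in micro_dicts)
```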
```python
    """Create a padding microbatch with loss_mask=0 so it doesn't affect the loss."""
    seq_len = 2
    num_actions = self.data.metadata["response_length"]
    batch_size = 1

    data = TrainingInputBatch(
        {
            "sequences": torch.randint(0, 100, (batch_size, seq_len), device="cpu"),
            "attention_mask": torch.ones((batch_size, seq_len), dtype=int, device="cpu"),
            "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"),
            "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"),
            "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
            "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
            "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"),
            # Loss mask is all zeros so padding samples don't contribute to the loss.
            "loss_mask": torch.zeros((batch_size, num_actions), dtype=int, device="cpu"),
            "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"),
        }
    )
    data.metadata = self.data.metadata
    return data
```
🔴 Padding microbatches have `seq_len=2`, causing a model crash when `num_actions > 1`

The `_create_padding_microbatch` method creates a dummy microbatch with `seq_len = 2` but inherits the real `response_length` (`num_actions`) from `self.data.metadata`. When the model processes this padding microbatch, it produces logits of shape `[1, 2, V]` and tries to extract `num_actions` action log probs (e.g., `token_logprobs[:, -4:]` for `num_actions=4` on a length-2 sequence). This produces a tensor of shape `[1, min(1, num_actions)]` instead of the expected `[1, num_actions]`, causing shape mismatches in subsequent loss computations (e.g., `action_log_probs - old_action_log_probs`, where the shapes don't match).

This triggers whenever different DP workers end up with different numbers of microbatches after bin-packing (common with variable-length sequences), causing `_num_padding_microbatches > 0` on some workers. Affected paths: `Worker.forward()`, `PolicyWorkerBase.forward_backward()`, `CriticWorkerBase.forward_backward()`, and both Megatron forward/forward_backward paths (where mixing `seq_len=2` padding dicts with `seq_len=N` real dicts also breaks Megatron's uniform `seq_length` requirement in `forward_backward_func`).
Prompt for agents
The `_create_padding_microbatch` method at `worker_utils.py:288-308` creates a padding microbatch with `seq_len=2`, which is incompatible with the model's expectations. The model needs `sequence_length >= num_actions + 1` to extract the correct number of action log probabilities.

The fix should set `seq_len` in the padding microbatch to match the real data's sequence length (`self.data['sequences'].shape[1]`) instead of hardcoding 2. This ensures:
1. FSDP models can extract the correct number of action log probs
2. Megatron's `forward_backward_func` sees a uniform `seq_length` across all microbatches
3. `postprocess_packed_seqs` / `recover_left_padding` don't fail on shape mismatches

Specifically, change line 289 from `seq_len = 2` to `seq_len = self.data['sequences'].shape[1]`. Also ensure the sequences tensor is appropriately sized (zeros or random tokens of the correct length), and that `attention_mask`, `position_ids`, etc. have the correct `seq_len` dimension.

Also check `_pad_forward_microbatch_to_size` and `_pad_microbatch_to_size` in `megatron_worker.py` - they pad the batch dimension but assume all microbatches already share the same `seq_len`, which breaks when padding microbatches have `seq_len=2`.
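A self-contained sketch of the proposed fix (hypothetical helper standing in for `_create_padding_microbatch`, using a plain dict instead of `TrainingInputBatch` and showing only the shape-relevant fields): the padding batch takes its `seq_len` from the real data rather than hardcoding 2.

```python
import torch

# Sketch of the proposed fix: the padding microbatch's seq_len matches the
# real batch, so the model keeps enough positions to slice out
# response_length action log probs. Hypothetical stand-in for the PR's
# _create_padding_microbatch; only shape-relevant fields are shown.
def make_padding_microbatch(real_batch, response_length):
    seq_len = real_batch["sequences"].shape[1]  # was hardcoded to 2
    bsz = 1
    return {
        "sequences": torch.zeros((bsz, seq_len), dtype=torch.long),
        "attention_mask": torch.ones((bsz, seq_len), dtype=torch.long),
        # All-zero loss mask: the padding sample contributes nothing to the loss.
        "loss_mask": torch.zeros((bsz, response_length), dtype=torch.long),
        "response_mask": torch.ones((bsz, response_length), dtype=torch.long),
    }
```

With a matching `seq_len`, FSDP can slice the full `response_length` log probs and Megatron's `forward_backward_func` sees uniform sequence lengths across microbatches.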
```python
def _create_microbatch_from_indices(self, indices: List[int]) -> TrainingInputBatch:
    """Create a TrainingInputBatch from a list of sample indices."""
    indices_tensor = torch.tensor(indices, dtype=torch.long, device="cpu")
    selected_data = {}
    for key, value in self.data.items():
        if value is None:
            selected_data[key] = None
        else:
            selected_data[key] = value[indices_tensor]
    microbatch = TrainingInputBatch(selected_data)
    microbatch.metadata = self.data.metadata
    return microbatch
```
🟡 `_create_microbatch_from_indices` fails with `TensorList` values (multi-modal data)

`_create_microbatch_from_indices` uses `value[indices_tensor]`, where `indices_tensor` is a `torch.Tensor`. For regular `torch.Tensor` values this works (fancy indexing), but `TrainingInputBatch` can also contain `TensorList` values (used for `pixel_values` and `image_grid_thw` in multi-modal training). `TensorList.__getitem__` at `training_batch.py:71-74` only handles `slice` and `int` index types; a `torch.Tensor` index falls through to `self.tensors[index]`, which raises `TypeError` because Python lists don't support tensor indexing. This means token-based batching (`max_tokens_per_microbatch > 0`) is broken for any multi-modal (vision-language) training run.
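A minimal repro of the failure mode, using a plain Python list as a stand-in for `TensorList.tensors`: a multi-element `torch.Tensor` index is rejected by `list.__getitem__`, while per-element integer indexing (as in the suggestion below) works.

```python
import torch

# Stand-in for TensorList.tensors: a plain Python list of tensors.
tensors = [torch.ones(2), torch.ones(3)]
idx = torch.tensor([0, 1], dtype=torch.long)

try:
    _ = tensors[idx]  # what value[indices_tensor] hits for TensorList values
    failed = False
except TypeError:
    # Python lists don't support tensor (fancy) indexing.
    failed = True

# Per-element integer indexing is the workaround the suggested change uses.
selected = [tensors[int(i)] for i in idx]
```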
Suggested change:

```diff
 def _create_microbatch_from_indices(self, indices: List[int]) -> TrainingInputBatch:
     """Create a TrainingInputBatch from a list of sample indices."""
-    indices_tensor = torch.tensor(indices, dtype=torch.long, device="cpu")
     selected_data = {}
     for key, value in self.data.items():
         if value is None:
             selected_data[key] = None
+        elif isinstance(value, TensorList):
+            selected_data[key] = TensorList([value[i] for i in indices])
         else:
+            indices_tensor = torch.tensor(indices, dtype=torch.long, device="cpu")
             selected_data[key] = value[indices_tensor]
     microbatch = TrainingInputBatch(selected_data)
     microbatch.metadata = self.data.metadata
     return microbatch
```