[skyrl-train] feat: add native GMPO policy loss with validation and tests#1449
taivu1998 wants to merge 3 commits into NovaSky-AI:main from
Conversation
Code Review
This pull request implements Geometric-Mean Policy Optimization (GMPO) as a new sequence-level policy loss function and updates documentation for several other loss types. The GMPO implementation requires `sequence_mean` reduction and does not yet support off-policy corrections. Feedback suggests refining the `clip_ratio` metrics to count only cases where clipping actually modifies the objective, and raising `ValueError` instead of using `assert` for configuration validation.
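For context on the algorithm itself: GMPO's defining property is that it aggregates per-token importance ratios with a geometric rather than arithmetic mean, which is equivalent to exponentiating the mean of the per-token log-ratios. A minimal scalar sketch (illustrative numbers, not this PR's tensor implementation):

```python
import math

def geometric_mean_ratio(log_ratios):
    # exp(mean of log-ratios) == (product of ratios) ** (1 / n)
    return math.exp(sum(log_ratios) / len(log_ratios))

log_ratios = [0.1, -0.2, 0.3]
geo = geometric_mean_ratio(log_ratios)
arith = sum(math.exp(x) for x in log_ratios) / len(log_ratios)

# By AM-GM the geometric mean never exceeds the arithmetic mean,
# which damps the influence of outlier token ratios.
assert geo <= arith
```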
```python
clipped = torch.ne(log_ratio, clipped_log_ratio)
clip_ratio = masked_mean((clipped & (advantages > 0)).float(), loss_mask).mean().detach().item()
clip_ratio_lower = masked_mean((clipped & (advantages < 0)).float(), loss_mask).mean().detach().item()
```
The `clip_ratio` and `clip_ratio_lower` metrics are slightly inaccurate as currently implemented. `clipped` is defined as any log-ratio outside the `[-eps_clip_low, eps_clip_high]` range, but in PPO-style clipping (which GMPO follows here), a token is only considered "clipped" if the clipping actually modifies the objective. For example, if `advantages > 0` and `log_ratio < -eps_clip_low`, the objective uses the original `log_ratio` (via the minimum logic), so it isn't actually clipped.
A more accurate way to compute these metrics is to check where `effective_log_ratio` differs from the original `log_ratio`.
```diff
- clipped = torch.ne(log_ratio, clipped_log_ratio)
- clip_ratio = masked_mean((clipped & (advantages > 0)).float(), loss_mask).mean().detach().item()
- clip_ratio_lower = masked_mean((clipped & (advantages < 0)).float(), loss_mask).mean().detach().item()
+ is_clipped = torch.ne(effective_log_ratio, log_ratio)
+ clip_ratio = masked_mean((is_clipped & (advantages > 0)).float(), loss_mask).mean().detach().item()
+ clip_ratio_lower = masked_mean((is_clipped & (advantages < 0)).float(), loss_mask).mean().detach().item()
```
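The distinction can be sketched with a scalar version of the PPO-style min/max-with-clamp logic described above (illustrative values and function names, not this PR's exact tensor code):

```python
def effective_log_ratio(log_ratio, adv, eps_low, eps_high):
    # PPO-style pessimistic clipping on log-ratios: clamp, then take the
    # min (for positive advantage) or max (for negative advantage) with
    # the original value, so clipping only ever reduces the objective.
    clipped = max(-eps_low, min(log_ratio, eps_high))
    return min(log_ratio, clipped) if adv > 0 else max(log_ratio, clipped)

# Positive advantage, log_ratio below the clip range:
lr, adv, eps_low, eps_high = -0.5, 1.0, 0.2, 0.27
clamped = max(-eps_low, min(lr, eps_high))             # -0.2 (outside the range)
eff = effective_log_ratio(lr, adv, eps_low, eps_high)  # -0.5 (min keeps the original)

naive_metric = lr != clamped   # True: the old check counts this token as clipped
accurate_metric = eff != lr    # False: the objective was not actually modified
```

Counting only `eff != lr` makes the metric report clipping exactly when the objective changes.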
```python
if cfg.trainer.algorithm.policy_loss_type == "gmpo":
    assert (
        cfg.trainer.algorithm.loss_reduction == "sequence_mean"
    ), "GMPO requires `trainer.algorithm.loss_reduction` to be `sequence_mean`."
```
For consistency with the check inside `gmpo_policy_loss` (and the `NotImplementedError` raised just below), it's better to raise a `ValueError` here instead of using an `assert`. `AssertionError` is typically reserved for internal logic bugs rather than user configuration errors.
```python
if cfg.trainer.algorithm.policy_loss_type == "gmpo" and cfg.trainer.algorithm.loss_reduction != "sequence_mean":
    raise ValueError("GMPO requires `trainer.algorithm.loss_reduction` to be `sequence_mean`.")
```
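Beyond convention, there is a concrete reason to prefer `ValueError` here: `assert` statements are stripped entirely when Python runs with `-O`, so an assert-based check would silently pass on a bad config. A minimal sketch (illustrative function name and invalid value):

```python
def validate_gmpo_reduction(loss_reduction: str) -> None:
    # An explicit exception survives `python -O`, unlike an assert.
    if loss_reduction != "sequence_mean":
        raise ValueError(
            "GMPO requires `trainer.algorithm.loss_reduction` to be `sequence_mean`."
        )

try:
    validate_gmpo_reduction("token_mean")  # hypothetical invalid value
except ValueError as e:
    print(f"rejected: {e}")

validate_gmpo_reduction("sequence_mean")  # valid config passes silently
```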
Summary
Related to #834.
This draft PR adds native GMPO support to `skyrl-train` as a built-in `policy_loss_type`. The implementation is intentionally narrow and trainer-native:
- `gmpo` registered in `PolicyLossRegistry`
- `gmpo_policy_loss(...)` in `ppo_utils.py`
- `loss_reduction="sequence_mean"` required for GMPO
- no `off_policy_correction` for GMPO in v1

What Changed
Core implementation
- `PolicyLossType.GMPO` and `gmpo` in the built-in policy loss registry
- `skyrl/backends/skyrl_train/utils/ppo_utils.py`

Validation
- `validate_cfg(...)` guards for unsupported GMPO combinations:
  - `loss_reduction` must be `sequence_mean`
  - `off_policy_correction` is not supported with GMPO yet

Tests
- `tests/train/algorithms/test_losses.py`
- `tests/backends/skyrl_train/utils/test_ppo_utils.py`
- `tests/train/test_trainer.py`

Docs
- `gmpo`

Design Notes
This branch of SkyRL applies loss-reduction semantics by pre-scaling advantages in the trainer and then summing policy losses in the worker path. The GMPO implementation is written to match that existing contract rather than introducing a separate reduction path.
That keeps the change small and avoids broader trainer/runtime edits.
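The pre-scaling contract described above can be illustrated with plain numbers (a sketch, not the trainer's actual code): scaling each token's advantage-weighted loss by `1/sequence_length` on the trainer side makes a plain worker-side sum reproduce the per-sequence mean.

```python
# One sequence's per-token policy losses (illustrative values):
token_losses = [0.3, 0.6, 0.9]

# "sequence_mean" computed directly:
seq_mean = sum(token_losses) / len(token_losses)

# Trainer-side pre-scaling (1/len) followed by a plain worker-side sum:
scale = 1.0 / len(token_losses)
prescaled_sum = sum(loss * scale for loss in token_losses)

assert abs(seq_mean - prescaled_sum) < 1e-12
```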
Verification
Completed locally:
- `python3 -m py_compile skyrl/backends/skyrl_train/utils/ppo_utils.py skyrl/train/utils/utils.py skyrl/train/config/config.py tests/train/algorithms/test_losses.py tests/backends/skyrl_train/utils/test_ppo_utils.py tests/train/test_trainer.py`
- `git diff --check`
- unit tests for `gmpo_policy_loss(...)` covering:

Not completed locally:
- the `pytest` target set

Reason:
- `uv run` was blocked in this environment by offline dependency fetches for `flash-attn`
- the `pytest` environment does not include `ray`
Potential follow-up work, not included here:
- `geo_mean` alias for verl naming compatibility