
[skyrl-train] feat: add native GMPO policy loss with validation and tests#1449

Open
taivu1998 wants to merge 3 commits into NovaSky-AI:main from taivu1998:tdv/issue-834-gmpo

Conversation

@taivu1998

@taivu1998 taivu1998 commented Apr 2, 2026

Summary

Related to #834.

This draft PR adds native GMPO support to skyrl-train as a built-in policy_loss_type.

The implementation is intentionally narrow and trainer-native:

  • adds gmpo to PolicyLossRegistry
  • implements a sequence-level gmpo_policy_loss(...) in ppo_utils.py
  • enforces loss_reduction="sequence_mean" for GMPO
  • explicitly rejects off_policy_correction for GMPO in v1
  • updates config/docs surfaces that enumerate built-in policy losses
  • adds unit coverage for loss math, registry availability, and config validation

What Changed

Core implementation

  • added PolicyLossType.GMPO
  • registered gmpo in the built-in policy loss registry
  • implemented the GMPO objective in skyrl/backends/skyrl_train/utils/ppo_utils.py
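For reference, GMPO's sequence-level objective replaces PPO's per-token arithmetic objective with a geometric mean of importance ratios, which is equivalent to exponentiating the mean of the (clipped) per-token log-ratios. A minimal pure-Python sketch of that math, with a hypothetical function name and a simplified clipping scheme (this is not the repo's actual `gmpo_policy_loss(...)`):

```python
import math

def gmpo_sequence_loss(old_logprobs, logprobs, advantage,
                       eps_low=0.2, eps_high=0.2):
    """Illustrative GMPO-style sequence loss (name and signature hypothetical).

    The sequence objective is exp(mean_t log r_t) * A, where
    r_t = pi(a_t)/pi_old(a_t); clipping is applied in log space.
    """
    log_ratios = [lp - olp for lp, olp in zip(logprobs, old_logprobs)]
    # Clip each token's log-ratio to [-eps_low, eps_high].
    clipped = [min(max(lr, -eps_low), eps_high) for lr in log_ratios]
    mean_lr = sum(log_ratios) / len(log_ratios)
    mean_clipped = sum(clipped) / len(clipped)
    # PPO-style pessimism: keep whichever term gives the smaller objective.
    obj = min(math.exp(mean_lr) * advantage, math.exp(mean_clipped) * advantage)
    return -obj  # minimize the negative objective
```

Because the aggregation happens over the whole sequence before the loss is formed, the reduction is inherently per-sequence, which is why the implementation pins `loss_reduction="sequence_mean"`.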

Validation

  • added validate_cfg(...) guards for unsupported GMPO combinations:
    • loss_reduction must be sequence_mean
    • off_policy_correction is not supported with GMPO yet
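The guard logic amounts to two early rejections on the parsed config. A flattened sketch, with a hypothetical function name and signature (the real checks live in `validate_cfg(...)` and read the nested trainer config):

```python
def validate_gmpo_cfg(policy_loss_type, loss_reduction, off_policy_correction):
    """Sketch of the GMPO config guards described above (names hypothetical)."""
    if policy_loss_type != "gmpo":
        return  # guards only apply to GMPO
    if loss_reduction != "sequence_mean":
        raise ValueError(
            "GMPO requires `trainer.algorithm.loss_reduction` to be `sequence_mean`."
        )
    if off_policy_correction:
        raise NotImplementedError(
            "off_policy_correction is not supported with GMPO yet."
        )
```

Rejecting the unsupported combinations at config time keeps the failure mode explicit instead of silently producing a loss with the wrong reduction semantics.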

Tests

  • added GMPO math tests in tests/train/algorithms/test_losses.py
  • added built-in registry coverage in tests/backends/skyrl_train/utils/test_ppo_utils.py
  • added config validation coverage in tests/train/test_trainer.py

Docs

  • updated configuration and trainer-loss references to include gmpo
  • clarified that GMPO is a native trainer capability, not a Tinker-integrated loss path

Design Notes

This branch of SkyRL applies loss-reduction semantics by pre-scaling advantages in the trainer and then summing policy losses in the worker path. The GMPO implementation is written to match that existing contract rather than introducing a separate reduction path.

That keeps the change small and avoids broader trainer/runtime edits.
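Concretely, under this contract a sequence-mean reduction can be realized by dividing each token's advantage by its sequence's number of valid tokens in the trainer, so the worker only ever needs to sum per-token losses. A toy illustration of that pre-scaling (function name hypothetical, nested lists standing in for tensors):

```python
def prescale_advantages(advantages, loss_mask):
    """Scale per-token advantages by 1/num_valid_tokens per sequence,
    so a plain sum over tokens in the worker equals a per-sequence mean
    (toy version of the contract described above)."""
    scaled = []
    for adv_seq, mask_seq in zip(advantages, loss_mask):
        n = max(sum(mask_seq), 1)  # avoid division by zero on empty masks
        scaled.append([a * m / n for a, m in zip(adv_seq, mask_seq)])
    return scaled

# A worker that just sums token terms now yields the sequence mean:
adv = prescale_advantages([[2.0, 2.0, 2.0, 2.0]], [[1, 1, 1, 1]])
sequence_loss = sum(adv[0])  # 2.0, the mean of the original advantages
```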

Verification

Completed locally:

  • python3 -m py_compile skyrl/backends/skyrl_train/utils/ppo_utils.py skyrl/train/utils/utils.py skyrl/train/config/config.py tests/train/algorithms/test_losses.py tests/backends/skyrl_train/utils/test_ppo_utils.py tests/train/test_trainer.py
  • git diff --check
  • direct source-level smoke check for gmpo_policy_loss(...) covering:
    • loss math on a deterministic fixture
    • invalid reduction rejection
    • off-policy-correction rejection
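A deterministic loss-math fixture for GMPO leans on the identity that a geometric mean equals the exponential of the mean log. A standalone version of that check, independent of the repo's actual test fixtures:

```python
import math

def geometric_mean(ratios):
    # Computed in log space, matching how a GMPO-style loss
    # aggregates per-token importance ratios.
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))

# Deterministic fixtures with exactly known geometric means.
assert math.isclose(geometric_mean([1.0, 4.0]), 2.0)
assert math.isclose(geometric_mean([2.0, 2.0, 2.0]), 2.0)
```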

Not completed locally:

  • full pytest target set

Reason:

  • uv run was blocked in this environment because the flash-attn dependency fetch requires network access
  • the fallback system pytest environment does not include ray

Follow-ups

Potential follow-up work, not included here:

  • optional geo_mean alias for verl naming compatibility
  • a dedicated GMPO example script
  • sequence-aware off-policy correction for GMPO


@taivu1998 taivu1998 closed this Apr 3, 2026
@taivu1998 taivu1998 reopened this Apr 4, 2026
@taivu1998 taivu1998 marked this pull request as ready for review April 5, 2026 00:07

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements Geometric-Mean Policy Optimization (GMPO) as a new sequence-level policy loss function and updates documentation for several other loss types. The GMPO implementation requires sequence_mean reduction and does not yet support off-policy corrections. Feedback suggests refining the clip_ratio metrics to accurately detect when clipping modifies the objective and using ValueError instead of assert for configuration validation to improve error handling.

Comment on lines +760 to +762:

    clipped = torch.ne(log_ratio, clipped_log_ratio)
    clip_ratio = masked_mean((clipped & (advantages > 0)).float(), loss_mask).mean().detach().item()
    clip_ratio_lower = masked_mean((clipped & (advantages < 0)).float(), loss_mask).mean().detach().item()


Severity: medium

The clip_ratio and clip_ratio_lower metrics are slightly inaccurate as currently implemented. clipped is defined as any log-ratio outside the [-eps_clip_low, eps_clip_high] range, but in PPO-style clipping (which GMPO follows here), a token is only considered "clipped" if the clipping actually modifies the objective. For example, if advantages > 0 and log_ratio < -eps_clip_low, the objective uses the original log_ratio (via the minimum logic), so it isn't actually clipped.

A more accurate way to compute these metrics is to check where effective_log_ratio differs from the original log_ratio.

Suggested change:

    - clipped = torch.ne(log_ratio, clipped_log_ratio)
    - clip_ratio = masked_mean((clipped & (advantages > 0)).float(), loss_mask).mean().detach().item()
    - clip_ratio_lower = masked_mean((clipped & (advantages < 0)).float(), loss_mask).mean().detach().item()
    + is_clipped = torch.ne(effective_log_ratio, log_ratio)
    + clip_ratio = masked_mean((is_clipped & (advantages > 0)).float(), loss_mask).mean().detach().item()
    + clip_ratio_lower = masked_mean((is_clipped & (advantages < 0)).float(), loss_mask).mean().detach().item()

Comment on lines +296 to +299:

    if cfg.trainer.algorithm.policy_loss_type == "gmpo":
        assert (
            cfg.trainer.algorithm.loss_reduction == "sequence_mean"
        ), "GMPO requires `trainer.algorithm.loss_reduction` to be `sequence_mean`."


Severity: medium

For consistency with the check inside gmpo_policy_loss (and the NotImplementedError raised just below), it's better to raise a ValueError here instead of using an assert. AssertionError is typically reserved for internal logic bugs rather than user configuration errors.

    if cfg.trainer.algorithm.policy_loss_type == "gmpo" and cfg.trainer.algorithm.loss_reduction != "sequence_mean":
        raise ValueError("GMPO requires `trainer.algorithm.loss_reduction` to be `sequence_mean`.")


@devin-ai-integration devin-ai-integration bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 4 additional findings.


