10 changes: 7 additions & 3 deletions .github/workflows/gpu_skyrl.yaml
@@ -29,6 +29,9 @@ concurrency:
jobs:
skyrl_gpu_tests:
runs-on: ubuntu-latest
env:
ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }}
ANYSCALE_HOST: https://console.anyscale.com
defaults:
run:
shell: bash
@@ -47,10 +50,11 @@ jobs:
activate-environment: true
- name: Install dependencies
run: uv pip install anyscale==0.24.79 typer==0.9.0
- name: Skip GPU tests when Anyscale credentials are unavailable
if: ${{ env.ANYSCALE_CLI_TOKEN == '' }}
run: echo "Skipping GPU tests because ANYSCALE_CLI_TOKEN is unavailable in this workflow context."
- name: GPU tests
env:
ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }}
ANYSCALE_HOST: https://console.anyscale.com
if: ${{ env.ANYSCALE_CLI_TOKEN != '' }}
run: |
anyscale job submit -f ci/anyscale_gpu_ci.yaml --timeout 10000
anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-tx-gpu-ci --timeout 10000
6 changes: 4 additions & 2 deletions docs/content/docs/configuration/config.mdx
@@ -373,7 +373,7 @@ algorithm:
# this adds training batch level normalization to advantages
advantage_batch_normalize: false
value_head_prefix: "value_head"
policy_loss_type: "regular" # "regular", "dual_clip", "gspo", "clip_cov", "kl_cov" or customizable with PolicyLossRegistry
policy_loss_type: "regular" # "regular", "dual_clip", "gspo", "gmpo", "sapo", "clip_cov", "kl_cov", "cispo" or customizable with PolicyLossRegistry
loss_reduction: "token_mean" # "token_mean", "sequence_mean", "seq_mean_token_sum_norm"
grpo_norm_by_std: true # set to false to disable normalization by std in GRPO (used in Dr. GRPO)
zero_variance_filter: false # set to true to loss mask out prompts with zero variance rewards. only applicable when rewards are response-level.
@@ -485,15 +485,17 @@ algorithm:
- `regular`: Vanilla PPO loss with token-level importance sampling
- `dual_clip`: Dual clip PPO loss proposed in [this paper](https://arxiv.org/pdf/1912.09729)
- `gspo`: [Group Sequence Policy Optimization](https://arxiv.org/abs/2507.18071) with sequence-level importance sampling for improved training stability. Implements the "GSPO-token" variant from the paper.
- `gmpo`: [Geometric-Mean Policy Optimization](https://arxiv.org/abs/2507.20673), a sequence-level objective intended to be used with `loss_reduction="sequence_mean"` and typically paired with `advantage_estimator="grpo"`.
- `clip_cov`: Clip-Cov combines standard PPO clipping with covariance-based correction masking for improved stability. Based on [this paper](https://arxiv.org/abs/2505.22617).
- `kl_cov`: KL-Cov applies KL regularization to tokens selected based on covariance values. Based on [this paper](https://arxiv.org/abs/2505.22617).
- `cispo`: Clipped Importance Sampling Weight Policy Optimization (CISPO) proposed in [MiniMax-M1](https://arxiv.org/abs/2506.13585).
- `sapo`: Soft Adaptive Policy Optimization (SAPO) as proposed in [this paper](https://arxiv.org/pdf/2511.20347).
- Custom policy losses can be registered with the `PolicyLossRegistry`
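
As a rough illustrative sketch (not SkyRL's actual implementation; tensor shapes and values are made up), the sequence-level geometric-mean ratio that `gmpo`-style objectives use relates to the token-level ratios of the `regular` PPO loss as follows:

```python
import torch

# Hypothetical log-probs for one 3-token response (values are made up).
log_probs = torch.tensor([-0.9, -1.1, -1.0])
old_log_probs = torch.tensor([-1.0, -1.0, -1.0])

# Token-level importance ratios, as used by the "regular" PPO loss.
token_ratios = torch.exp(log_probs - old_log_probs)

# Sequence-level ratio: exp of the mean log-ratio, i.e. the geometric
# mean of the token ratios, which damps the effect of outlier tokens.
seq_ratio = torch.exp((log_probs - old_log_probs).mean())

# Sanity check: exp(mean(log r_t)) equals the geometric mean of r_t.
assert torch.isclose(seq_ratio, token_ratios.prod() ** (1.0 / 3.0))
```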

- `algorithm.loss_reduction`: Type of loss reduction to use. Options include:

- `token_mean`: computes average loss over all valid tokens in the batch. Used in [DAPO](https://dapo-sia.github.io/).
- `sequence_mean`: computes per-sequence avg token loss, then averages over the batch.
- `sequence_mean`: computes the per-sequence average token loss, then averages over the batch. This is the required reduction mode for `gmpo`.
- `seq_mean_token_sum_norm`: computes the sum of token losses for each sequence, normalizes by `max_seq_len`, and then averages over the batch. This is used in [Dr. GRPO](https://arxiv.org/abs/2503.20783). If `algorithm.max_seq_len` is not explicitly set, it defaults to `generator.max_input_length + generator.sampling_params.max_generate_length`.
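
A minimal sketch of the three reduction modes on toy tensors (shapes and values are made up; this is not the trainer's code):

```python
import torch

# Hypothetical per-token losses and loss mask for a batch of 2 sequences.
token_loss = torch.tensor([[0.2, 0.4, 0.0],
                           [0.6, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 0.0],
                     [1.0, 0.0, 0.0]])
max_seq_len = 4  # stands in for algorithm.max_seq_len

# token_mean: average over every valid token in the batch.
token_mean = (token_loss * mask).sum() / mask.sum()

# sequence_mean: per-sequence token average, then average over the batch.
sequence_mean = ((token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()

# seq_mean_token_sum_norm: per-sequence token sum normalized by max_seq_len,
# then averaged over the batch (Dr. GRPO style).
seq_mean_token_sum_norm = ((token_loss * mask).sum(-1) / max_seq_len).mean()
```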

- `algorithm.grpo_norm_by_std`: Whether to normalize advantages by the standard deviation in GRPO. This is set to `false` in [Dr. GRPO](https://arxiv.org/abs/2503.20783).
2 changes: 1 addition & 1 deletion docs/content/docs/tinker/architecture.mdx
@@ -142,7 +142,7 @@ The following loss functions are validated through the Tinker API:
| `cross_entropy` | Standard next-token prediction loss | Supervised fine-tuning |
| `importance_sampling` | Off-policy policy gradient: `-(exp(logp - old_logp) * advantage)` | RL training (GRPO, REINFORCE) |

SkyRL-Train's `PolicyLossRegistry` also contains additional loss functions (`regular`, `dual_clip`, `gspo`, `sapo`, `cispo`, `clip_cov`, `kl_cov`) used by SkyRL's native trainer. These are not yet wired through the Tinker data conversion path, which does not currently populate the required `advantages` and `old_log_probs` fields in the training batch for these loss types.
SkyRL-Train's `PolicyLossRegistry` also contains additional loss functions (`regular`, `dual_clip`, `gspo`, `gmpo`, `sapo`, `cispo`, `clip_cov`, `kl_cov`) used by SkyRL's native trainer. These are not yet wired through the Tinker data conversion path, which does not currently populate the required `advantages` and `old_log_probs` fields in the training batch for these loss types.
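
For concreteness, the `importance_sampling` formula from the table above can be evaluated on a single hypothetical token (scalar values are made up):

```python
import math

# Assumed scalar values for one token.
logp, old_logp, advantage = -0.9, -1.0, 2.0

ratio = math.exp(logp - old_logp)  # importance weight exp(logp - old_logp)
loss = -(ratio * advantage)        # off-policy policy-gradient loss

# A positive advantage with a positive ratio yields a negative
# (reinforcing) loss, pushing the policy toward this token.
```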

## Concurrency Model

3 changes: 1 addition & 2 deletions docs/content/docs/tinker/limitations.mdx
@@ -32,5 +32,4 @@ KL penalty (`kl_penalty_coef > 0`) is not yet supported. This requires prompt lo

### RL Loss Functions

Only `cross_entropy` and `importance_sampling` are currently wired through the Tinker data conversion path. SkyRL's `PolicyLossRegistry` contains implementations for PPO (`regular`), `cispo`, and others, but these are not yet validated through the Tinker API.

Only `cross_entropy` and `importance_sampling` are currently wired through the Tinker data conversion path. SkyRL's `PolicyLossRegistry` contains implementations for PPO (`regular`), `gmpo`, `cispo`, and others, but these are not yet validated through the Tinker API.
1 change: 1 addition & 0 deletions examples/train/sft/README.md
@@ -27,4 +27,5 @@ The `loss_fn` parameter supports:
| `cross_entropy` | Supervised fine-tuning |
| `regular` / `ppo` | PPO with clipping |
| `gspo` | Group Sequence Policy Optimization |
| `gmpo` | Geometric-Mean Policy Optimization |
| ... | See `PolicyLossRegistry` for all options |
65 changes: 65 additions & 0 deletions skyrl/backends/skyrl_train/utils/ppo_utils.py
@@ -444,6 +444,7 @@ class PolicyLossType(StrEnum):
REGULAR = "regular"
DUAL_CLIP = "dual_clip"
GSPO = "gspo"
GMPO = "gmpo"
CISPO = "cispo"
ROLLOUT_IS = "rollout_is"
CLIP_COV = "clip_cov"
@@ -477,6 +478,7 @@ def repopulate_registry(cls):
"regular": [PolicyLossType.REGULAR, ppo_policy_loss],
"dual_clip": [PolicyLossType.DUAL_CLIP, ppo_policy_loss],
"gspo": [PolicyLossType.GSPO, gspo_policy_loss],
"gmpo": [PolicyLossType.GMPO, gmpo_policy_loss],
"clip_cov": [PolicyLossType.CLIP_COV, compute_policy_loss_clip_cov],
"kl_cov": [PolicyLossType.KL_COV, compute_policy_loss_kl_cov],
"sapo": [PolicyLossType.SAPO, sapo_policy_loss],
@@ -704,6 +706,69 @@ def gspo_policy_loss(
return loss, loss_metrics


@register_policy_loss(PolicyLossType.GMPO)
def gmpo_policy_loss(
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
config: AlgorithmConfig,
loss_mask: Optional[torch.Tensor] = None,
rollout_logprobs: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict[str, float]]:
"""
GMPO (Geometric-Mean Policy Optimization) policy loss function.

GMPO is a sequence-level objective. In SkyRL, ``sequence_mean`` reduction is
implemented by pre-scaling token advantages in the trainer. Summing the masked
advantages here therefore recovers one sequence weight per response, which lets
us compute the GMPO objective without changing the trainer's reduction pipeline.
"""
if config.loss_reduction != "sequence_mean":
raise ValueError(
'GMPO requires `trainer.algorithm.loss_reduction="sequence_mean"` because it is a sequence-level objective.'
)

off_policy_correction = config.off_policy_correction
uses_off_policy_correction = (
off_policy_correction.tis_ratio_type is not None
or off_policy_correction.sequence_mask_metric is not None
or off_policy_correction.outlier_token_is_threshold_low is not None
or off_policy_correction.outlier_token_is_threshold_high is not None
or off_policy_correction.token_mask_is_threshold_low is not None
or off_policy_correction.token_mask_is_threshold_high is not None
)
if uses_off_policy_correction:
raise NotImplementedError("GMPO does not support `trainer.algorithm.off_policy_correction` yet.")

loss_mask = torch.ones_like(log_probs) if loss_mask is None else loss_mask.to(log_probs.dtype)

log_ratio = log_probs - old_log_probs
clipped_log_ratio = torch.clamp(log_ratio, -config.eps_clip_low, config.eps_clip_high)

sign = torch.sign(advantages)
effective_log_ratio = sign * torch.minimum(sign * log_ratio, sign * clipped_log_ratio)

seq_len = loss_mask.sum(dim=-1).clamp(min=1.0)
seq_log_ratio = (effective_log_ratio * loss_mask).sum(dim=-1) / seq_len
seq_ratio = safe_exp_delta(seq_log_ratio, clip=20.0, out_dtype=log_probs.dtype)

# For sequence_mean reduction, advantages have already been pre-scaled by the
# trainer. Summing across valid tokens recovers one sequence weight per sample.
seq_advantage = (advantages * loss_mask).sum(dim=-1)
loss = (-seq_advantage * seq_ratio).sum()

clipped = torch.ne(log_ratio, clipped_log_ratio)
clip_ratio = masked_mean((clipped & (advantages > 0)).float(), loss_mask).mean().detach().item()
clip_ratio_lower = masked_mean((clipped & (advantages < 0)).float(), loss_mask).mean().detach().item()
Reviewer comment on lines +760 to +762 (Contributor, severity: medium):

The `clip_ratio` and `clip_ratio_lower` metrics are slightly inaccurate as currently implemented. `clipped` is defined as any log-ratio outside the `[-eps_clip_low, eps_clip_high]` range, but in PPO-style clipping (which GMPO follows here), a token is only considered "clipped" if the clipping actually modifies the objective. For example, if `advantages > 0` and `log_ratio < -eps_clip_low`, the objective uses the original `log_ratio` (via the minimum logic), so it isn't actually clipped.

A more accurate way to compute these metrics is to check where `effective_log_ratio` differs from the original `log_ratio`.

Suggested change:
    is_clipped = torch.ne(effective_log_ratio, log_ratio)
    clip_ratio = masked_mean((is_clipped & (advantages > 0)).float(), loss_mask).mean().detach().item()
    clip_ratio_lower = masked_mean((is_clipped & (advantages < 0)).float(), loss_mask).mean().detach().item()

seq_ratio_mean = seq_ratio.mean().detach().item()

return loss, {
"clip_ratio": clip_ratio,
"clip_ratio_lower": clip_ratio_lower,
"seq_ratio_mean": seq_ratio_mean,
}


@register_policy_loss(PolicyLossType.CISPO)
def compute_policy_loss_cispo(
log_probs: torch.Tensor,
2 changes: 1 addition & 1 deletion skyrl/train/config/config.py
@@ -346,7 +346,7 @@ class AlgorithmConfig(BaseConfig):
advantage_batch_normalize: bool = False
value_head_prefix: str = "value_head"
policy_loss_type: str = "regular"
"""``"regular"``, ``"dual_clip"``, ``"gspo"``, ``"clip_cov"``, ``"kl_cov"``, or custom via ``PolicyLossRegistry``."""
"""``"regular"``, ``"dual_clip"``, ``"gspo"``, ``"gmpo"``, ``"sapo"``, ``"clip_cov"``, ``"kl_cov"``, ``"cispo"``, or custom via ``PolicyLossRegistry``."""
loss_reduction: str = "token_mean"
"""``"token_mean"``, ``"sequence_mean"``, or ``"seq_mean_token_sum_norm"``."""
grpo_norm_by_std: bool = True
4 changes: 2 additions & 2 deletions skyrl/train/config/ppo_base_config.yaml
@@ -109,7 +109,7 @@ trainer:
# this adds training batch level normalization to advantages
advantage_batch_normalize: false
value_head_prefix: "value_head"
policy_loss_type: "regular" # "regular", "dual_clip", "gspo", "clip_cov", "kl_cov", or customizable with PolicyLossRegistry
policy_loss_type: "regular" # "regular", "dual_clip", "gspo", "gmpo", "sapo", "clip_cov", "kl_cov", "cispo", or customizable with PolicyLossRegistry
loss_reduction: "token_mean" # "token_mean", "sequence_mean", "seq_mean_token_sum_norm"
grpo_norm_by_std: true # set to false to disable normalization by std in GRPO
zero_variance_filter: false # set to true to loss mask out prompts with zero variance rewards. only applicable when rewards are response-level.
@@ -385,4 +385,4 @@ generator:
environment:
env_class: "gsm8k"
# NOTE: environment specific defaults for environment.skyrl_gym are set at the following path:
# skyrl_gym: config/skyrl_gym_config/default.yaml
# skyrl_gym: config/skyrl_gym_config/default.yaml
19 changes: 19 additions & 0 deletions skyrl/train/utils/utils.py
@@ -293,12 +293,31 @@ def validate_cfg(cfg: SkyRLTrainConfig):
cfg.trainer.algorithm.off_policy_correction.tis_ratio_type = "token"
cfg.trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high = cfg.trainer.algorithm.tis_imp_ratio_cap

if cfg.trainer.algorithm.policy_loss_type == "gmpo":
assert (
cfg.trainer.algorithm.loss_reduction == "sequence_mean"
), "GMPO requires `trainer.algorithm.loss_reduction` to be `sequence_mean`."
Reviewer comment on lines +296 to +299 (Contributor, severity: medium):

For consistency with the check inside `gmpo_policy_loss` (and the `NotImplementedError` raised just below), it's better to raise a `ValueError` here instead of using an `assert`. `AssertionError` is typically reserved for internal logic bugs rather than user configuration errors.

Suggested change:
    if cfg.trainer.algorithm.policy_loss_type == "gmpo" and cfg.trainer.algorithm.loss_reduction != "sequence_mean":
        raise ValueError("GMPO requires `trainer.algorithm.loss_reduction` to be `sequence_mean`.")

# off_policy_correction config validation
off_policy_correction = cfg.trainer.algorithm.off_policy_correction
tis_ratio_type = off_policy_correction.tis_ratio_type
sequence_mask_metric = off_policy_correction.sequence_mask_metric

uses_off_policy_correction = tis_ratio_type is not None or sequence_mask_metric is not None
gmpo_uses_off_policy_correction = uses_off_policy_correction or any(
threshold is not None
for threshold in (
off_policy_correction.outlier_token_is_threshold_low,
off_policy_correction.outlier_token_is_threshold_high,
off_policy_correction.token_mask_is_threshold_low,
off_policy_correction.token_mask_is_threshold_high,
)
)

if cfg.trainer.algorithm.policy_loss_type == "gmpo" and gmpo_uses_off_policy_correction:
raise NotImplementedError(
"GMPO does not support `trainer.algorithm.off_policy_correction`; use `sequence_mean` without rollout correction."
)

if uses_off_policy_correction:
# Validate tis_ratio_type
7 changes: 7 additions & 0 deletions tests/backends/skyrl_train/utils/test_ppo_utils.py
@@ -423,6 +423,13 @@ def decorated_policy_loss(log_probs, old_log_probs, advantages, config, loss_mas
PolicyLossRegistry.unregister("test_policy_decorator")


def test_builtin_gmpo_policy_loss_registered():
"""GMPO should be available as a built-in policy loss."""

assert "gmpo" in PolicyLossRegistry.list_available()
assert callable(PolicyLossRegistry.get("gmpo"))


def test_registry_cross_ray_process():
"""Test that registry works with Ray and that functions can be retrieved and called from different processes"""
try:
132 changes: 132 additions & 0 deletions tests/train/algorithms/test_losses.py
Expand Up @@ -284,6 +284,138 @@ def test_gspo_importance_sampling_levels():
), f"GSPO should have uniform importance weights within sequence {seq_idx}"


def test_gmpo_policy_loss_sequence_level_objective():
"""GMPO should compute a sequence-level clipped geometric-mean objective."""

device = "cpu"

# These advantages are already scaled for sequence_mean reduction:
# raw sequence advantages would be [2.0, -4.0] with batch_size=2 and seq_len=2.
advantages = torch.tensor(
[
[0.5, 0.5, 0.0],
[-1.0, -1.0, 0.0],
],
device=device,
)
old_log_probs = torch.full_like(advantages, -1.0)
log_probs = torch.tensor(
[
[-0.2, -1.4, 5.0],
[-1.6, -0.3, 5.0],
],
device=device,
)
loss_mask = torch.tensor(
[
[1.0, 1.0, 0.0],
[1.0, 1.0, 0.0],
],
device=device,
)

config = AlgorithmConfig(
eps_clip_low=0.2,
eps_clip_high=0.2,
policy_loss_type="gmpo",
loss_reduction="sequence_mean",
max_seq_len=4,
off_policy_correction=NULL_OFF_POLICY_CORR,
)

loss_fn = PolicyLossRegistry.get("gmpo")
actual_loss, loss_metrics = loss_fn(log_probs, old_log_probs, advantages, config, loss_mask)

log_ratio = log_probs - old_log_probs
clipped_log_ratio = torch.clamp(log_ratio, -0.2, 0.2)
sign = torch.sign(advantages)
effective_log_ratio = sign * torch.minimum(sign * log_ratio, sign * clipped_log_ratio)
seq_len = loss_mask.sum(dim=-1).clamp(min=1.0)
expected_seq_ratio = torch.exp((effective_log_ratio * loss_mask).sum(dim=-1) / seq_len)
expected_seq_advantage = (advantages * loss_mask).sum(dim=-1)
expected_loss = (-expected_seq_advantage * expected_seq_ratio).sum()

torch.testing.assert_close(actual_loss, expected_loss, rtol=1e-5, atol=1e-8)
assert loss_metrics["clip_ratio"] == pytest.approx(0.5, abs=1e-6)
assert loss_metrics["clip_ratio_lower"] == pytest.approx(0.5, abs=1e-6)
assert loss_metrics["seq_ratio_mean"] == pytest.approx(expected_seq_ratio.mean().item(), abs=1e-6)


def test_gmpo_policy_loss_all_masked_sequence_is_finite():
"""GMPO should handle sequences with zero valid tokens without NaNs."""

device = "cpu"
advantages = torch.tensor([[0.5, 0.5], [-1.0, -1.0]], device=device)
old_log_probs = torch.full_like(advantages, -1.0)
log_probs = torch.tensor([[-0.8, -1.2], [-0.4, -1.6]], device=device)
loss_mask = torch.tensor([[1.0, 1.0], [0.0, 0.0]], device=device)

config = AlgorithmConfig(
eps_clip_low=0.2,
eps_clip_high=0.2,
policy_loss_type="gmpo",
loss_reduction="sequence_mean",
max_seq_len=4,
off_policy_correction=NULL_OFF_POLICY_CORR,
)

loss_fn = PolicyLossRegistry.get("gmpo")
loss, loss_metrics = loss_fn(log_probs, old_log_probs, advantages, config, loss_mask)

assert torch.isfinite(loss)
assert torch.isfinite(torch.tensor(loss_metrics["seq_ratio_mean"]))


def test_gmpo_policy_loss_differs_from_regular_ppo():
"""GMPO should behave differently from tokenwise PPO on high-variance sequences."""

device = "cpu"
advantages = torch.tensor(
[
[0.5, 0.5, 0.0],
[0.5, 0.5, 0.0],
],
device=device,
)
old_log_probs = torch.full_like(advantages, -1.0)
log_probs = torch.tensor(
[
[-0.2, -2.0, -1.0],
[-2.2, 0.4, -1.0],
],
device=device,
)
loss_mask = torch.tensor(
[
[1.0, 1.0, 0.0],
[1.0, 1.0, 0.0],
],
device=device,
)

gmpo_config = AlgorithmConfig(
eps_clip_low=0.2,
eps_clip_high=0.2,
policy_loss_type="gmpo",
loss_reduction="sequence_mean",
max_seq_len=4,
off_policy_correction=NULL_OFF_POLICY_CORR,
)
regular_config = AlgorithmConfig(
eps_clip_low=0.2,
eps_clip_high=0.2,
policy_loss_type="regular",
loss_reduction="sequence_mean",
max_seq_len=4,
off_policy_correction=NULL_OFF_POLICY_CORR,
)

gmpo_loss, _ = PolicyLossRegistry.get("gmpo")(log_probs, old_log_probs, advantages, gmpo_config, loss_mask)
regular_loss, _ = PolicyLossRegistry.get("regular")(log_probs, old_log_probs, advantages, regular_config, loss_mask)

assert not torch.allclose(gmpo_loss, regular_loss, rtol=1e-3)


def test_clip_cov_policy_loss():
"""Tests Clip-Cov policy loss function with covariance-based correction."""
