ppo: add entropy cost annealing (linear + cosine schedules)#660

Open
PhysicistJohn wants to merge 1 commit into google:main from PhysicistJohn:feature/ppo-entropy-annealing

Conversation

@PhysicistJohn

Summary

Adds entropy cost annealing to Brax PPO via two new parameters on ppo.train():

entropy_cost_end: Optional[float] = None
entropy_schedule: str = 'linear'   # or 'cosine'

When entropy_cost_end is set, the coefficient decays from entropy_cost to entropy_cost_end over the full training budget, following the selected schedule. The default (entropy_cost_end=None) keeps the existing constant-cost behaviour, so all existing code is unaffected.
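Concretely, the decay could look like the following pure-Python sketch. This is an illustrative mirror of the described behaviour, not the PR's actual code; the real implementation would use jnp ops inside the JIT graph, and the function name and total_env_steps parameter are assumptions.

```python
import math
from typing import Optional

def entropy_cost_at(env_steps: int,
                    total_env_steps: int,
                    entropy_cost: float,
                    entropy_cost_end: Optional[float] = None,
                    entropy_schedule: str = 'linear') -> float:
    """Entropy coefficient at a given point in training (illustrative)."""
    if entropy_cost_end is None:
        return entropy_cost  # existing constant-cost behaviour
    frac = min(env_steps / total_env_steps, 1.0)
    if entropy_schedule == 'linear':
        decay = 1.0 - frac
    elif entropy_schedule == 'cosine':
        decay = 0.5 * (1.0 + math.cos(math.pi * frac))
    else:
        raise ValueError(f'unknown entropy_schedule: {entropy_schedule}')
    # Interpolate between the start and end coefficients.
    return entropy_cost_end + (entropy_cost - entropy_cost_end) * decay
```

Both schedules start at entropy_cost and end at entropy_cost_end; cosine simply front-loads the high-entropy phase and flattens out near the end of training.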

Design notes

  • Schedule is computed inside the existing JIT graph using training_state.env_steps — no per-epoch recompilation.
  • entropy_schedule is a static Python string; JAX traces only the selected branch.
  • entropy_cost is threaded as a keyword arg through the existing call chain (training_step → sgd_step → minibatch_step → loss_and_pgrad_fn). loss_and_pgrad is already a pure *args, **kwargs pass-through so no signature conflicts arise.
  • Current coefficient is logged to TensorBoard as training/entropy_cost.
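The kwargs pass-through point can be illustrated with a minimal sketch. The loss body and wrapper below are stand-ins, not Brax's actual definitions; they only show why adding a new keyword argument needs no signature changes along the chain.

```python
def compute_ppo_loss(params, data, *, entropy_cost: float = 1e-2):
    # Stand-in loss: policy-gradient term minus an entropy bonus weighted
    # by the (now schedule-driven) coefficient.
    return data['pg_loss'] - entropy_cost * data['entropy']

def loss_and_pgrad_fn(loss_fn):
    # Pure *args/**kwargs pass-through, as in the existing call chain;
    # the real version would also compute gradients.
    def wrapped(*args, **kwargs):
        return loss_fn(*args, **kwargs)
    return wrapped

step_loss = loss_and_pgrad_fn(compute_ppo_loss)
```

Because the wrapper forwards keywords blindly, entropy_cost=... threaded from training_step reaches the loss without touching any intermediate signature.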

Motivation

A fixed high entropy cost promotes exploration early in training but prevents the policy from committing to precise actions later. Annealing allows warm exploration followed by policy consolidation without separate training phases — particularly useful for long runs (1B+ steps) where behaviour changes significantly over time.

@google-cla

google-cla bot commented Feb 26, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@PhysicistJohn PhysicistJohn marked this pull request as ready for review February 26, 2026 23:12