ppo: add entropy cost annealing (linear + cosine schedules)#660
Open
PhysicistJohn wants to merge 1 commit into google:main from
Conversation
Adds two new parameters to `ppo.train()`:
```python
entropy_cost_end: Optional[float] = None
entropy_schedule: str = 'linear'  # or 'cosine'
```
When `entropy_cost_end` is set, the entropy coefficient decays from
`entropy_cost` (start) to `entropy_cost_end` over the full training
budget, following the selected schedule. Default is unchanged —
`entropy_cost_end=None` keeps the existing constant-cost behaviour.
The schedule is computed inside the existing JIT graph using
`training_state.env_steps`, so there is no per-epoch recompilation.
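A minimal pure-Python sketch of the schedule described above. The actual change runs the same arithmetic on `jnp` scalars inside the JIT-compiled training step; the function and argument names here are illustrative, not the PR's exact code:

```python
import math
from typing import Optional

def entropy_cost_schedule(env_steps: int,
                          total_env_steps: int,
                          entropy_cost: float,
                          entropy_cost_end: Optional[float],
                          entropy_schedule: str = 'linear') -> float:
    # Mirrors the in-graph computation: a pure function of env_steps,
    # so no per-epoch recompilation is needed.
    if entropy_cost_end is None:
        return entropy_cost  # existing constant-cost behaviour
    frac = min(env_steps / total_env_steps, 1.0)
    if entropy_schedule == 'linear':
        return entropy_cost + (entropy_cost_end - entropy_cost) * frac
    elif entropy_schedule == 'cosine':
        # Half-cosine decay from entropy_cost down to entropy_cost_end.
        return entropy_cost_end + (entropy_cost - entropy_cost_end) * 0.5 * (
            1.0 + math.cos(math.pi * frac))
    raise ValueError(f'unknown entropy_schedule: {entropy_schedule!r}')
```

Both schedules start at `entropy_cost` (frac = 0) and end at `entropy_cost_end` (frac = 1); the cosine variant decays slowly at first and fastest mid-run.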
The `if/else` on `entropy_schedule` is a Python-level (static) branch,
so JAX traces only the selected path. `entropy_cost` is now threaded
as a keyword argument through `training_step → sgd_step →
minibatch_step → loss_and_pgrad_fn`, which is already a pure
`*args, **kwargs` pass-through, so no signature conflicts arise.
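To illustrate the pass-through claim, here is a toy version of that chain. The function names mirror the PR description, but the bodies are stand-ins (the real loss is the PPO loss, not this arithmetic):

```python
def ppo_loss(params, data, entropy_cost=1e-2):
    # Stand-in loss; only the keyword plumbing is the point here.
    return params * data - entropy_cost

def loss_and_pgrad(loss_fn):
    # Pure pass-through wrapper, as in the existing code: it forwards
    # *args/**kwargs untouched, so a new `entropy_cost` keyword threads
    # through without any signature change.
    def wrapped(*args, **kwargs):
        return loss_fn(*args, **kwargs)
    return wrapped

def minibatch_step(data, entropy_cost):
    return loss_and_pgrad(ppo_loss)(2.0, data, entropy_cost=entropy_cost)

def sgd_step(data, entropy_cost):
    return minibatch_step(data, entropy_cost=entropy_cost)

def training_step(data, entropy_cost):
    # The per-step scheduled coefficient enters at the top and reaches
    # the loss unchanged.
    return sgd_step(data, entropy_cost=entropy_cost)
```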
The current coefficient is logged to TensorBoard as `training/entropy_cost`.
Motivation: a fixed high entropy cost promotes exploration early in
training but prevents the policy from committing to precise actions
later. Annealing entropy over the run allows warm exploration followed
by policy consolidation without requiring separate training phases,
which is particularly useful for long runs (1B+ steps) where behaviour
changes significantly over time.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.