refactor: use jnp.take_along_axis for return_last_only indexing #588

Open
stanley1208 wants to merge 1 commit into google-deepmind:main from stanley1208:refactor/use-take-along-axis

@stanley1208

Summary

Resolves TODO(epot): Use jnp.take_along_axis in both Transformer.__call__ and Gemma3nTransformer.__call__.

The return_last_only code path previously used manual fancy indexing with jnp.arange to select the last non-padded token per batch element. This PR replaces it with jnp.take_along_axis, which is cleaner and avoids constructing an explicit jnp.arange batch-index array.

Changes

  • gemma/gm/nn/_transformer.py (line ~253): replaced x[jnp.arange(len(x)), idx, ...] with jnp.take_along_axis(x, idx[:, None, None], axis=1), followed by a squeeze to drop the resulting length-1 sequence axis
  • gemma/gm/nn/gemma3n/_transformer.py (line ~323): same change
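
For reviewers, here is a minimal sketch of why the two forms should agree. It is not the code from the PR: the function names, the (batch, length, features) layout of x, and the toy shapes are assumptions made for illustration, and idx is assumed to hold in-bounds positions of shape (batch,).

```python
import jax.numpy as jnp


def last_token_fancy(x, idx):
    """Old form: pair every batch row with its own sequence index."""
    return x[jnp.arange(len(x)), idx, ...]


def last_token_take(x, idx):
    """New form: gather along the sequence axis, then drop that axis."""
    # idx has shape (B,). Expanding it to (B, 1, 1) lets it broadcast
    # over the feature axis, so the gather returns shape (B, 1, D).
    gathered = jnp.take_along_axis(x, idx[:, None, None], axis=1)
    return jnp.squeeze(gathered, axis=1)  # shape (B, D)


# Hypothetical toy input: batch of 2, sequence length 4, feature dim 3.
x = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
idx = jnp.array([1, 3])  # assumed last non-padded position per row

out_fancy = last_token_fancy(x, idx)
out_take = last_token_take(x, idx)
```

On this toy input both functions return an array of shape (2, 3) holding x[0, 1] and x[1, 3]. Squeezing with an explicit axis=1 is a choice made in this sketch so that a batch of size 1 keeps its batch axis.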

Test plan

  • test_prefill passes (exercises return_last_only=True via the standard Transformer)
  • test_sampler passes (end-to-end sampling using prefill)
  • test_last_only passes (directly tests return_last_only on Gemma3n)
  • All 9 relevant tests pass

Replace manual fancy indexing with jnp.take_along_axis in Transformer and Gemma3nTransformer, as requested by TODO(epot). Cleaner and avoids constructing an arange index array.