
Expose n_ctx override in HookedTransformer.from_pretrained (issue #1006) #1204

Merged

jlarson4 merged 5 commits into TransformerLensOrg:dev from brainsnog:feat/n-ctx-override on Mar 16, 2026

Conversation

@brainsnog

Summary

Fixes #1006.

get_pretrained_model_config already accepted n_ctx as a parameter,
but it was invisible to users — silently falling through **kwargs with
no type hint, no IDE autocomplete, and no docstring entry on
from_pretrained. Users had to edit source files directly to change
context length.

This PR surfaces it as an explicit, documented parameter.

Changes

transformer_lens/HookedTransformer.py

  • Added n_ctx: Optional[int] = None to from_pretrained signature
  • Added n_ctx=n_ctx to the explicit call to get_pretrained_model_config
    (previously relied on silent kwargs passthrough)
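A minimal sketch of the changed call chain, using hypothetical simplified stand-ins rather than the real TransformerLens functions (which take many more parameters):

```python
from typing import Optional

# Hypothetical stand-ins for illustration only; the real functions live in
# transformer_lens and return a full HookedTransformerConfig / model.

def get_pretrained_model_config(model_name: str, n_ctx: Optional[int] = None, **kwargs) -> dict:
    cfg = {"model_name": model_name, "n_ctx": 1024}  # assumed trained default
    if n_ctx is not None:
        cfg["n_ctx"] = n_ctx
    return cfg

def from_pretrained(model_name: str, n_ctx: Optional[int] = None, **kwargs) -> dict:
    # Before this PR, n_ctx reached the config only via **kwargs; now it is
    # a named, type-hinted parameter forwarded explicitly.
    return get_pretrained_model_config(model_name, n_ctx=n_ctx, **kwargs)
```

Because the parameter is now in the signature, IDEs autocomplete it and the docstring can document it.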

transformer_lens/loading_from_pretrained.py

  • Extended the existing n_ctx override block to emit a logging.warning
    when the requested value exceeds the model's trained default, informing
    users of potential reliability and memory implications
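The extended override block might look roughly like this (an assumed shape, not the exact TransformerLens code):

```python
import logging

def apply_n_ctx_override(cfg: dict, n_ctx: int) -> dict:
    # Sketch of the override block: warn before growing the context window
    # past the length the model was trained on.
    if n_ctx > cfg["n_ctx"]:
        logging.warning(
            "Requested n_ctx (%d) exceeds the trained default (%d); "
            "outputs beyond the trained length may be unreliable and "
            "attention memory grows with context length.",
            n_ctx,
            cfg["n_ctx"],
        )
    cfg["n_ctx"] = n_ctx
    return cfg
```

Reducing n_ctx stays silent, since it only trims positional state and memory use.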

tests/unit/test_loading_from_pretrained_utilities.py

  • Added test_n_ctx_override_reduces_context — verifies override works
    when reducing below default
  • Added test_n_ctx_override_larger_than_default_warns — verifies warning
    fires with correct message when exceeding default

Both tests use GPT-2 Small and run on CPU — no GPU required.
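The two tests could be sketched as follows, against a hypothetical minimal loader so the example is self-contained (the real tests exercise the actual `get_pretrained_model_config`):

```python
import logging

# Hypothetical minimal loader standing in for
# transformer_lens.loading_from_pretrained.get_pretrained_model_config.
def get_pretrained_model_config(model_name, n_ctx=None):
    cfg = {"n_ctx": 1024}  # GPT-2 Small's trained context length
    if n_ctx is not None:
        if n_ctx > cfg["n_ctx"]:
            logging.warning("n_ctx %d exceeds trained default %d", n_ctx, cfg["n_ctx"])
        cfg["n_ctx"] = n_ctx
    return cfg

def test_n_ctx_override_reduces_context():
    cfg = get_pretrained_model_config("gpt2", n_ctx=256)
    assert cfg["n_ctx"] == 256

def test_n_ctx_override_larger_than_default_warns(caplog):
    # caplog is pytest's built-in log-capture fixture.
    with caplog.at_level(logging.WARNING):
        cfg = get_pretrained_model_config("gpt2", n_ctx=2048)
    assert cfg["n_ctx"] == 2048
    assert "exceeds trained default" in caplog.text
```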

Usage

# Reduce context for memory efficiency
model = HookedTransformer.from_pretrained("gpt2", n_ctx=256)

# Extend context beyond conservative default (e.g. LLaMA 3.2 defaults to 2048)
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.2-3B", n_ctx=32768)

…nsformerLensOrg#1006)

get_pretrained_model_config already supported n_ctx as a parameter but
it was invisible to users — falling through **kwargs with no type hint,
no docstring entry, and no IDE support.

This commit surfaces it as an explicit parameter on from_pretrained,
adds a warning when n_ctx exceeds the model's trained default, and
adds two regression tests using GPT-2 (runs on CPU, no GPU required).

Users can now do:
    model = HookedTransformer.from_pretrained('gpt2', n_ctx=256)
    model = HookedTransformer.from_pretrained('meta-llama/Llama-3.2-3B', n_ctx=32768)

Fixes TransformerLensOrg#1006
@brainsnog brainsnog marked this pull request as ready for review March 15, 2026 16:32
@jlarson4 jlarson4 changed the base branch from main to dev March 16, 2026 13:12
@jlarson4 jlarson4 merged commit 0199ef8 into TransformerLensOrg:dev Mar 16, 2026
13 checks passed


Successfully merging this pull request may close these issues.

[Proposal] Allow overriding config.n_ctx at model initialization

2 participants