Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 85 additions & 20 deletions renderers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,17 @@ def _model_has_vision_config(model_name: str) -> bool:
}


# Tokenizer repos to use when a canonical model repo is gated but an
# audited unrestricted mirror ships byte-identical tokenizer files and
# chat_template. The returned tokenizer keeps the caller's original
# ``name_or_path`` so exact-match renderer resolution still uses
# ``MODEL_RENDERER_MAP``.
TOKENIZER_SOURCE_OVERRIDES: dict[str, str] = {
"meta-llama/Llama-3.2-1B-Instruct": "unsloth/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct": "unsloth/Llama-3.2-3B-Instruct",
}


# Models for which ``fastokens`` is known to diverge from vanilla
# ``transformers.AutoTokenizer`` and therefore must NOT be patched.
# Empirical audit ran each entry of ``MODEL_RENDERER_MAP`` through both
Expand All @@ -1175,6 +1186,42 @@ def _model_has_vision_config(model_name: str) -> bool:
_FASTOKENS_ANNOUNCED = False


def _tokenizer_source_for(model_name_or_path: str) -> str:
return TOKENIZER_SOURCE_OVERRIDES.get(model_name_or_path, model_name_or_path)


def _tokenizer_load_kwargs(model_name_or_path: str) -> dict[str, Any]:
revision = TRUSTED_REVISIONS.get(model_name_or_path)
if revision is not None:
return {"trust_remote_code": True, "revision": revision}
return {"trust_remote_code": False}


def _preserve_requested_tokenizer_name(
tokenizer,
*,
requested_name_or_path: str,
loaded_name_or_path: str,
):
if requested_name_or_path == loaded_name_or_path:
return tokenizer

try:
tokenizer.name_or_path = requested_name_or_path
except Exception:
init_kwargs = getattr(tokenizer, "init_kwargs", None)
if isinstance(init_kwargs, dict):
init_kwargs["name_or_path"] = requested_name_or_path

if getattr(tokenizer, "name_or_path", "") != requested_name_or_path:
raise RuntimeError(
f"Loaded tokenizer for {requested_name_or_path!r} from "
f"{loaded_name_or_path!r}, but could not preserve the requested "
"name_or_path for renderer auto-resolution."
)
return tokenizer


def _patched_load(model_name_or_path: str, **kwargs):
"""Run ``AutoTokenizer.from_pretrained`` with fastokens patched in
process-locally — patch around the load, unpatch right after.
Expand Down Expand Up @@ -1312,29 +1359,41 @@ def load_tokenizer(
validation for configs with nested ``rope_parameters``), we fall
back to loading the repo's self-contained ``tokenizer.json``
directly — see ``_load_tokenizer_via_auto``.
"""
kwargs: dict[str, Any] = {}
revision = TRUSTED_REVISIONS.get(model_name_or_path)
if revision is not None:
kwargs = {"trust_remote_code": True, "revision": revision}
else:
kwargs = {"trust_remote_code": False}

if not use_fastokens or model_name_or_path in FASTOKENS_INCOMPATIBLE:
return _load_tokenizer_via_auto(model_name_or_path, **kwargs)
Canonical Meta Llama-3.2 Instruct repos are gated on HuggingFace. For
those exact IDs we load tokenizer files from the audited unrestricted
``unsloth`` mirrors instead, then restore ``tokenizer.name_or_path`` to
the requested Meta ID so auto-resolution still selects ``Llama3Renderer``.
"""
load_name_or_path = _tokenizer_source_for(model_name_or_path)
kwargs = _tokenizer_load_kwargs(load_name_or_path)

if not use_fastokens or load_name_or_path in FASTOKENS_INCOMPATIBLE:
tok = _load_tokenizer_via_auto(load_name_or_path, **kwargs)
return _preserve_requested_tokenizer_name(
tok,
requested_name_or_path=model_name_or_path,
loaded_name_or_path=load_name_or_path,
)

try:
return _patched_load(model_name_or_path, **kwargs)
tok = _patched_load(load_name_or_path, **kwargs)
except Exception as exc:
logger.info(
"fastokens could not load %r (%s: %s); falling back to vanilla "
"AutoTokenizer. Add this model to FASTOKENS_INCOMPATIBLE in "
"renderers.base to suppress the retry.",
model_name_or_path,
load_name_or_path,
type(exc).__name__,
str(exc)[:160],
)
return _load_tokenizer_via_auto(model_name_or_path, **kwargs)
tok = _load_tokenizer_via_auto(load_name_or_path, **kwargs)

return _preserve_requested_tokenizer_name(
tok,
requested_name_or_path=model_name_or_path,
loaded_name_or_path=load_name_or_path,
)


def _populate_registry():
Expand Down Expand Up @@ -1702,12 +1761,8 @@ def _get_offset_tokenizer(tokenizer):
if cached is not None:
return cached

kwargs: dict[str, Any] = {}
revision = TRUSTED_REVISIONS.get(name_or_path)
if revision is not None:
kwargs = {"trust_remote_code": True, "revision": revision}
else:
kwargs = {"trust_remote_code": False}
load_name_or_path = _tokenizer_source_for(name_or_path)
kwargs = _tokenizer_load_kwargs(load_name_or_path)

def _has_offsets(tok) -> bool:
if not getattr(tok, "is_fast", False):
Expand All @@ -1727,7 +1782,12 @@ def _has_offsets(tok) -> bool:
# off — serialized against pool patch/unpatch via ``_FASTOKENS_PATCH_LOCK``
# so no concurrent window can swap the shim back in mid-load — then
# restore the prior patch state. Never cache a non-offset tokenizer.
offset_tok = _load_tokenizer_via_auto(name_or_path, **kwargs)
offset_tok = _load_tokenizer_via_auto(load_name_or_path, **kwargs)
offset_tok = _preserve_requested_tokenizer_name(
offset_tok,
requested_name_or_path=name_or_path,
loaded_name_or_path=load_name_or_path,
)
if not _has_offsets(offset_tok):
import fastokens

Expand All @@ -1737,7 +1797,12 @@ def _has_offsets(tok) -> bool:
with contextlib.redirect_stdout(io.StringIO()):
fastokens.unpatch_transformers()
try:
offset_tok = _load_tokenizer_via_auto(name_or_path, **kwargs)
offset_tok = _load_tokenizer_via_auto(load_name_or_path, **kwargs)
offset_tok = _preserve_requested_tokenizer_name(
offset_tok,
requested_name_or_path=name_or_path,
loaded_name_or_path=load_name_or_path,
)
finally:
if was_patched:
with contextlib.redirect_stdout(io.StringIO()):
Expand Down
11 changes: 5 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@
# there's just no byte-output to parity-check against. Split-specific
# parity (V3 bare prompt vs R1 <think>+history-strip) is covered in
# tests/test_deepseek_r1.py.
# Llama-3 loads via the unrestricted unsloth mirror (byte-identical
# chat template) so CI needs no Meta-gated HF token. Pinned to the
# explicit "llama-3" config because the mirror name isn't in
# MODEL_RENDERER_MAP (so "auto" would fall back to DefaultRenderer).
("unsloth/Llama-3.2-1B-Instruct", "llama-3"),
# Llama-3 uses the canonical Meta ID for renderer auto-resolution, while
# load_tokenizer fetches the tokenizer/chat_template from the unrestricted
# unsloth mirror so CI needs no Meta-gated HF token.
("meta-llama/Llama-3.2-1B-Instruct", "auto"),
("openai/gpt-oss-20b", "gpt-oss"),
("Qwen/Qwen2.5-0.5B-Instruct", "default"),
]
Expand Down Expand Up @@ -139,7 +138,7 @@ def _skip_gpt_oss_for_hf_parity_tests(request):
def _skip_llama_for_hf_parity_tests(request):
callspec = getattr(request.node, "callspec", None)
model_name = callspec.params.get("model_name") if callspec else None
if model_name != "unsloth/Llama-3.2-1B-Instruct":
if model_name != "meta-llama/Llama-3.2-1B-Instruct":
return
test_file = os.path.basename(str(request.node.fspath))
if test_file in _LLAMA_HF_PARITY_TEST_FILES:
Expand Down
18 changes: 13 additions & 5 deletions tests/test_llama_3.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Llama-3 renderer coverage.

Covers ``Llama3Renderer`` and the ``meta-llama/Llama-3.2-{1B,3B}-Instruct``
entries in ``MODEL_RENDERER_MAP``. Tokenizers are loaded via the
unrestricted ``unsloth/Llama-3.2-{1B,3B}-Instruct`` mirrors (verified
byte-identical chat templates) so CI doesn't need an HF token with Meta
license access.
entries in ``MODEL_RENDERER_MAP``. ``load_tokenizer`` uses the
unrestricted ``unsloth/Llama-3.2-{1B,3B}-Instruct`` mirrors underneath
(verified byte-identical chat templates) so CI doesn't need an HF token
with Meta license access.
"""

from __future__ import annotations
Expand Down Expand Up @@ -34,7 +34,7 @@
@pytest.fixture(scope="module", params=_MODEL_PAIRS, ids=[m for m, _ in _MODEL_PAIRS])
def llama_pair(request):
canonical, mirror = request.param
tok = load_tokenizer(mirror)
tok = load_tokenizer(canonical)
renderer = Llama3Renderer(tok, Llama3RendererConfig(date_string=_PINNED_DATE))
return canonical, mirror, tok, renderer

Expand All @@ -58,6 +58,14 @@ def test_create_renderer_via_explicit_config(llama_pair):
assert isinstance(r, Llama3Renderer)


def test_create_renderer_auto_resolves_after_mirror_load(llama_pair):
"""``load_tokenizer(canonical_meta_id)`` loads from the unrestricted
mirror but preserves the canonical name needed for auto-resolution."""
canonical, _, tok, _ = llama_pair
assert tok.name_or_path == canonical
assert isinstance(create_renderer(tok), Llama3Renderer)


# ---------------------------------------------------------------------------
# Constructor contract
# ---------------------------------------------------------------------------
Expand Down
68 changes: 67 additions & 1 deletion tests/test_load_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from __future__ import annotations

import re
from types import SimpleNamespace
from unittest.mock import patch

from renderers.base import TRUSTED_REVISIONS, load_tokenizer
from renderers import base
from renderers.base import TOKENIZER_SOURCE_OVERRIDES, TRUSTED_REVISIONS, load_tokenizer


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -70,6 +72,23 @@ def test_kimi_loads_with_pinned_revision(mock_from_pretrained):
}


@patch("transformers.AutoTokenizer.from_pretrained")
def test_meta_llama_loads_tokenizer_from_unsloth_mirror(mock_from_pretrained):
"""Canonical Meta Llama repos are gated; load their tokenizer/chat
template from the audited unrestricted mirror while preserving the
canonical name for renderer auto-resolution."""
canonical = "meta-llama/Llama-3.2-1B-Instruct"
mirror = "unsloth/Llama-3.2-1B-Instruct"
mock_from_pretrained.return_value = SimpleNamespace(name_or_path=mirror)

tok = load_tokenizer(canonical, use_fastokens=False)

args, kwargs = mock_from_pretrained.call_args
assert args == (mirror,)
assert kwargs == {"trust_remote_code": False}
assert tok.name_or_path == canonical


@patch("transformers.AutoTokenizer.from_pretrained")
def test_unknown_path_falls_through_to_no_remote_code(mock_from_pretrained):
"""Unknown / fine-tuned model paths — including ``moonshotai/Kimi-K2*``
Expand All @@ -92,6 +111,53 @@ def test_unknown_path_falls_through_to_no_remote_code(mock_from_pretrained):
)


def test_tokenizer_source_overrides_are_exact_llama_mirrors():
"""Mirror overrides are intentionally narrow: only verified
byte-identical Llama tokenizer/template mirrors should live here."""
assert TOKENIZER_SOURCE_OVERRIDES == {
"meta-llama/Llama-3.2-1B-Instruct": "unsloth/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct": "unsloth/Llama-3.2-3B-Instruct",
}


def test_offset_tokenizer_uses_unsloth_mirror_for_meta_llama(monkeypatch):
"""Offset-tokenizer reloads must use the same unrestricted source
override, otherwise Llama rendering can hit the gated Meta repo after
the initial tokenizer load succeeds."""

class _NoOffsets:
name_or_path = "meta-llama/Llama-3.2-1B-Instruct"

def __call__(self, *args, **kwargs):
raise NotImplementedError("fastokens shim has no offsets")

class _OffsetTokenizer:
is_fast = True

def __init__(self, name_or_path: str):
self.name_or_path = name_or_path

def __call__(self, *args, **kwargs):
return {"offset_mapping": [(0, 1)]}

calls = []

def _fake_load(name_or_path, **kwargs):
calls.append((name_or_path, kwargs))
return _OffsetTokenizer(name_or_path)

base._offset_tokenizers.clear()
monkeypatch.setattr(base, "_load_tokenizer_via_auto", _fake_load)

try:
tok = base._get_offset_tokenizer(_NoOffsets())
finally:
base._offset_tokenizers.clear()

assert calls == [("unsloth/Llama-3.2-1B-Instruct", {"trust_remote_code": False})]
assert tok.name_or_path == "meta-llama/Llama-3.2-1B-Instruct"


# ---------------------------------------------------------------------------
# Smoke: real tokenizer loads behave as expected
# ---------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions tests/test_preserve_thinking.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def _make(tokenizer, renderer_name, **flags):
"poolside/Laguna-XS.2",
# Llama-3 has no reasoning channel at all — preserve flags can't add
# or drop anything, so they're pure no-ops.
"meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct",
"unsloth/Llama-3.2-1B-Instruct",
}

Expand Down Expand Up @@ -324,6 +326,8 @@ def test_preserve_btc_on_live_cycle_matches_all(
"Qwen/Qwen3-VL-30B-A3B-Instruct",
# Llama-3 ships no <think> rendering path, so reasoning_content never
# surfaces in the output regardless of the preserve flags.
"meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-3B-Instruct",
"unsloth/Llama-3.2-1B-Instruct",
}

Expand Down
Loading