Add Gemma4 12b unified model support in Fairseq2#1519
Open
YunchaoYang wants to merge 11 commits into
Open
Conversation
added 10 commits
June 5, 2026 21:18
* config.py: register arch('12b') and arch('12b_it'); add get_gemma4_12b_config().
Dense Gemma 4 text decoder reusing the gemma4 code paths with:
hidden=3840, layers=48 (5:1 sliding:full = 40+8), heads=16, kv=8,
num_global_kv=1 (MQA), head_dim=256, global_head_dim=512,
ffn_inner=15360, sliding_window=1024, max_seq_len=262144,
attention_k_eq_v=True, no PLE, no KV sharing, no MoE,
final_logit_softcapping=30.
* __init__.py: export get_gemma4_12b_config.
* interop.py: extend always-strip multimodal prefixes with 'model.vision_embedder.'
(Unified family ships a tower-free vision embedder: LN+Dense+LN+factorized
2D posemb+LN+RMSNorm+Linear). Audio remains conditional on audio_config.
Verifies (offline): for google/gemma-4-12B safetensors (677 HF keys),
convert_gemma4_state_dict produces 667 keys, matching the 667 keys of
the freshly-constructed fairseq2 12B model on meta device exactly (no
missing, no extras).
* assets/cards/models/gemma4.yaml: register gemma4_12b and gemma4_12b_it cards (model_arch=12b/12b_it, checkpoint=hg://google/gemma-4-12B[-it], tokenizer_family=gemma4 — the unified family reuses the same HuggingFace tokenizer infrastructure and vocab). * recipes/lm/sft/configs/gemma4_12b_gsm8k.yaml: mirror of gemma4_e4b_gsm8k.yaml with name=gemma4_12b, max_seq_len=4096, bf16 FSDP. Same optimizer/scheduler as e4b. Smoke-test target.
The base google/gemma-4-12B repo does not ship chat_template.jinja (only the -it variant does). For chat_mode=true SFT we need the chat template, so point the tokenizer at the -it asset card. The model weights are unchanged (the YAML's model.name=gemma4_12b is still the base checkpoint); only the tokenizer asset differs to pick up the chat template.
…for SFT
* assets/cards/models/gemma4.yaml: gemma4_12b and gemma4_12b_it now
use file:///checkpoint/smallomnillm/shared/models/gemma-4-12B[-it]
paths (user-provided shared mirror), avoiding HF Hub cache races
and the missing chat_template.jinja issue.
* recipes/lm/sft/configs/gemma4_12b_gsm8k.yaml: set chat_mode=false.
Google's Gemma 4 chat_template.jinja does NOT use {% generation %}
markers, so apply_chat_template(return_assistant_tokens_mask=True)
returns all-zero assistant_masks and lm_sft's target_mask becomes
all-false. With chat_mode=true the SFT runs but Number of Target
Elements = 0 every step, NLL Loss stays at 0, and no gradient
flows. Use chat_mode=false (LM continuation on src+tgt) to validate
the recipe wiring end-to-end with a nonzero loss signal; revisit
once the chat template is patched upstream to include
{% generation %} blocks around the model turn.
The base google/gemma-4-12B is at /checkpoint/fairseq2/shared/models (downloaded by exp 42); only the -it variant lives at the user-shared /checkpoint/smallomnillm tree. Switching the SFT yaml to fine-tune the -it variant (standard domain-adaptation pattern) so the SFT recipe finds its checkpoint in the user's tree.
The SFT recipe (chat_mode=false path) calls create_encoder twice per example: mode='prompt' for the source and mode='prompt_response' for the target. Gemma4Tokenizer previously only accepted default/prompt/ as_is and raised ValueError on 'prompt_response', blocking the SFT recipe. Add prompt_response: no BOS prefix (the source half already has it), EOS suffix to terminate the target. Matches the Llama/Qwen pattern.
Recipe-level wandb recorder writes per-step training metrics (NLL Loss, Gradient Norm, LR, throughput) into a wandb run with project= gemma-4-fairseq2. With WANDB_MODE=offline (default in our launch env since no API key is configured), the run materializes locally under <output_dir>/wandb/ and can be synced later via 'wandb sync <path>'.
* audio/config.py: add audio_mode field ('conformer' | 'linear').
'conformer' (default) preserves existing E4B / classic-Gemma 4
behaviour (mel-spec -> subsample Conv2d -> Conformer tower ->
embedder). 'linear' is the new Gemma 4 Unified pipeline (raw
waveform frames of audio_samples_per_token=640 -> embedder
RMSNorm+Linear). Other Conformer-specific fields are ignored
in linear mode.
* config.py: add get_gemma4_unified_audio_config() (linear mode,
output_proj_dims=640) and get_gemma4_12b_audio_config() that
wraps the existing 12B text config with audio enabled. Register
archs '12b_audio' and '12b_it_audio'. The original '12b' / '12b_it'
archs stay text-only (no behaviour change for existing parity /
MMLU / SFT runs).
* factory.py: create_audio_tower returns None when audio_mode='linear'
(the existing Gemma4MultimodalAudioEmbedder class is already exactly
the unified embedder: RMSNorm(no scale) + Linear from 640 to text dim,
no other changes needed).
* model.py: Gemma4Model.forward now also handles the embedder-only
case (tower=None, embedder!=None): feeds raw audio_features
directly through the embedder. The conformer path is unchanged.
Sanity check: 12b_audio model has 668 keys; converter on the HF
gemma-4-12B-it safetensors produces exactly 668 keys (667 text +
audio_embedder.embedding_projection.weight) -- zero diff.
…rest of the family User moved the 12B base safetensors from /checkpoint/fairseq2/shared/models to /checkpoint/smallomnillm/shared/models so all Gemma 4 variants live together (gemma-4-E*B, gemma-4-31B[-it], gemma-4-26B-A4B[-it], gemma-4-12B, gemma-4-12B-it). Source path deleted; asset card now points at the new location. The -it card was already on smallomnillm.
Mirrors HF's split into two classes (Gemma4MultimodalEmbedder vs Gemma4UnifiedMultimodalEmbedder), but in a single fairseq2 class gated by a ctor flag. When cast_input_dtype=True, forward casts inputs_embeds to self.embedding_projection.weight.dtype before the RMSNorm -- exactly what HF's Gemma4UnifiedMultimodalEmbedder.forward does (modular_gemma4_unified.py:895-899). Without the cast, raw waveform features (fp32 from the feature extractor) fed into a bf16 embedder would silently use mismatched dtypes. Factory wires cast_input_dtype=True when audio_config.audio_mode == 'linear' (the Gemma 4 Unified family path). The default flag is False, so the classic E*/31B/26B-A4B Conformer audio path keeps its bit-for-bit behaviour. Exp 49 audio parity (bf16 cos=0.9992) passed before this change because the parity script pre-cast audio_features to DTYPE manually; this commit moves the cast into the embedder so any caller that forgets to pre-cast still produces HF-identical output.
zyaoj
requested changes
Jun 17, 2026
| name: gemma4_12b | ||
| model_family: gemma4 | ||
| model_arch: 12b | ||
| checkpoint: "file:///checkpoint/smallomnillm/shared/models/gemma-4-12B" |
Contributor
There was a problem hiding this comment.
Let's put this under ext instead of exposing in the OS. Use hg://* model cards.
| return None | ||
|
|
||
| # Unified family: no tower; embedder consumes raw waveform frames. | ||
| if getattr(config.audio_config, "audio_mode", "conformer") == "linear": |
Contributor
There was a problem hiding this comment.
why getattr? is audio_mode a proper dataclass or there's a BC requirement for this field?
| and ``rms_norm_eps`` are read. | ||
| """ | ||
|
|
||
| audio_mode: str = "conformer" |
Contributor
There was a problem hiding this comment.
do we want to use Literal["conformer", "linear"]?
Contributor
Author
|
@claude review this PR please |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do? Please describe:
A summary of the change or the issue that is fixed.
Fixes #{issue number}
Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.
Check list: