Add Gemma4 12b unified model support in Fairseq2#1519

Open
YunchaoYang wants to merge 11 commits into main from yy/gemma4-12b-unified

@YunchaoYang YunchaoYang commented Jun 8, 2026

What does this PR do? Please describe:
A summary of the change or the issue that is fixed.

Fixes #{issue number}

Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

ProfAI added 10 commits June 5, 2026 21:18
* config.py: register arch('12b') and arch('12b_it'); add get_gemma4_12b_config().
  Dense Gemma 4 text decoder reusing the gemma4 code paths with:
    hidden=3840, layers=48 (5:1 sliding:full = 40+8), heads=16, kv=8,
    num_global_kv=1 (MQA), head_dim=256, global_head_dim=512,
    ffn_inner=15360, sliding_window=1024, max_seq_len=262144,
    attention_k_eq_v=True, no PLE, no KV sharing, no MoE,
    final_logit_softcapping=30.
* __init__.py: export get_gemma4_12b_config.
* interop.py: extend always-strip multimodal prefixes with 'model.vision_embedder.'
  (Unified family ships a tower-free vision embedder: LN+Dense+LN+factorized
  2D posemb+LN+RMSNorm+Linear). Audio remains conditional on audio_config.
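
The 5:1 sliding:full split quoted above can be sanity-checked with a small sketch. `Gemma4TextConfigSketch` and `layer_types` are hypothetical names, not the fairseq2 API, and placing the full-attention layer last in each group of six is an assumption: the commit message gives only the ratio and the 40+8 totals.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Gemma4TextConfigSketch:
    """Hypothetical stand-in holding a subset of the values from the commit message."""

    hidden: int = 3840
    num_layers: int = 48
    num_heads: int = 16
    num_kv_heads: int = 8
    num_global_kv_heads: int = 1  # MQA on the full-attention layers
    head_dim: int = 256
    global_head_dim: int = 512
    ffn_inner: int = 15360
    sliding_window: int = 1024
    max_seq_len: int = 262144
    # Assumption: one full-attention layer closes every group of six layers.
    pattern_period: int = 6


def layer_types(cfg: Gemma4TextConfigSketch) -> list:
    """Return 'sliding' or 'full' for each layer index."""
    return [
        "full" if (i + 1) % cfg.pattern_period == 0 else "sliding"
        for i in range(cfg.num_layers)
    ]


cfg = Gemma4TextConfigSketch()
types = layer_types(cfg)
print(types.count("sliding"), types.count("full"))
```

Under that assumed layout the counts come out to 40 sliding and 8 full layers, matching the totals in the commit message.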

Verified offline: for the google/gemma-4-12B safetensors (677 HF keys),
convert_gemma4_state_dict produces 667 keys, exactly matching the 667
keys of a freshly constructed fairseq2 12B model on the meta device
(none missing, no extras).
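
A key-parity check of this kind boils down to a set difference in both directions. The sketch below is illustrative only: `diff_state_dict_keys` is a hypothetical helper and the key names are toy values, not the real Gemma 4 parameter names.

```python
def diff_state_dict_keys(converted_keys, model_keys):
    """Return (missing, extra) key lists relative to the target model."""
    converted, expected = set(converted_keys), set(model_keys)
    missing = sorted(expected - converted)  # the model expects these, converter did not emit them
    extra = sorted(converted - expected)    # converter emitted these, the model has no slot for them
    return missing, extra


# Toy key names for illustration only.
model_keys = ["decoder.layers.0.ffn.weight", "decoder_frontend.embed.weight"]
converted_ok = list(reversed(model_keys))
converted_bad = ["decoder.layers.0.ffn.weight", "vision_embedder.proj.weight"]

print(diff_state_dict_keys(converted_ok, model_keys))
print(diff_state_dict_keys(converted_bad, model_keys))
```

Comparing key sets alone does not validate tensor shapes or values; it only confirms that the converter's naming lines up with the model.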
* assets/cards/models/gemma4.yaml: register gemma4_12b and gemma4_12b_it
  cards (model_arch=12b/12b_it, checkpoint=hg://google/gemma-4-12B[-it],
  tokenizer_family=gemma4 — the unified family reuses the same
  HuggingFace tokenizer infrastructure and vocab).
* recipes/lm/sft/configs/gemma4_12b_gsm8k.yaml: mirror of
  gemma4_e4b_gsm8k.yaml with name=gemma4_12b, max_seq_len=4096,
  bf16 FSDP. Same optimizer/scheduler as e4b. Smoke-test target.
The base google/gemma-4-12B repo does not ship chat_template.jinja
(only the -it variant does). For chat_mode=true SFT we need the chat
template, so point the tokenizer at the -it asset card. The model
weights are unchanged (the YAML's model.name=gemma4_12b is still the
base checkpoint); only the tokenizer asset differs to pick up the
chat template.
…for SFT

* assets/cards/models/gemma4.yaml: gemma4_12b and gemma4_12b_it now
  use file:///checkpoint/smallomnillm/shared/models/gemma-4-12B[-it]
  paths (user-provided shared mirror), avoiding HF Hub cache races
  and the missing chat_template.jinja issue.
* recipes/lm/sft/configs/gemma4_12b_gsm8k.yaml: set chat_mode=false.
  Google's Gemma 4 chat_template.jinja does NOT use {% generation %}
  markers, so apply_chat_template(return_assistant_tokens_mask=True)
  returns all-zero assistant_masks and lm_sft's target_mask becomes
  all-false. With chat_mode=true the SFT runs but Number of Target
  Elements = 0 every step, NLL Loss stays at 0, and no gradient
  flows. Use chat_mode=false (LM continuation on src+tgt) to validate
  the recipe wiring end-to-end with a nonzero loss signal; revisit
  once the chat template is patched upstream to include
  {% generation %} blocks around the model turn.
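
The failure mode described above can be reproduced with a toy masked loss. `masked_nll` is a hypothetical helper, not the lm_sft implementation, and the per-token NLL values are made up.

```python
def masked_nll(token_nlls, target_mask):
    """Sum per-token NLL over positions where target_mask is True."""
    if len(token_nlls) != len(target_mask):
        raise ValueError("token_nlls and target_mask must have the same length")
    num_targets = sum(1 for m in target_mask if m)
    total = sum(nll for nll, m in zip(token_nlls, target_mask) if m)
    return total, num_targets


token_nlls = [2.5, 1.0, 0.5, 3.0]  # made-up per-token losses

# Template without {% generation %} markers: the assistant mask is all zeros.
all_false_mask = [False, False, False, False]
# What a working mask would look like if the last two tokens were the response.
response_mask = [False, False, True, True]

print(masked_nll(token_nlls, all_false_mask))
print(masked_nll(token_nlls, response_mask))
```

With the all-false mask no position contributes to the sum, so the loss is zero regardless of the model's predictions and there is nothing to backpropagate.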
The base google/gemma-4-12B is at /checkpoint/fairseq2/shared/models
(downloaded by exp 42); only the -it variant lives at the user-shared
/checkpoint/smallomnillm tree. Switching the SFT yaml to fine-tune
the -it variant (standard domain-adaptation pattern) so the SFT
recipe finds its checkpoint in the user's tree.
The SFT recipe (chat_mode=false path) calls create_encoder twice per
example: mode='prompt' for the source and mode='prompt_response' for
the target. Gemma4Tokenizer previously only accepted default/prompt/
as_is and raised ValueError on 'prompt_response', blocking the SFT
recipe.

Add prompt_response: no BOS prefix (the source half already has it),
EOS suffix to terminate the target. Matches the Llama/Qwen pattern.
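
A minimal sketch of the mode dispatch, operating on token strings rather than ids. Only the `prompt_response` row (no BOS, EOS suffix) comes from the commit message; the BOS/EOS choices for `default`, `prompt`, and `as_is` are assumptions, `encode_with_mode` is a hypothetical function, and `<bos>`/`<eos>` are placeholder strings.

```python
BOS, EOS = "<bos>", "<eos>"  # placeholder token strings

# mode -> (add BOS prefix, add EOS suffix)
_MODES = {
    "default": (True, True),           # assumption
    "prompt": (True, False),           # assumption: the source half carries BOS
    "prompt_response": (False, True),  # from the commit message
    "as_is": (False, False),           # assumption
}


def encode_with_mode(tokens, mode):
    """Wrap tokens with the BOS/EOS markers selected by mode."""
    if mode not in _MODES:
        raise ValueError(f"unsupported mode: {mode!r}")
    add_bos, add_eos = _MODES[mode]
    return ([BOS] if add_bos else []) + list(tokens) + ([EOS] if add_eos else [])


source = encode_with_mode(["Q", ":"], "prompt")
target = encode_with_mode(["4", "2"], "prompt_response")
print(source + target)
```

Concatenating the two halves yields a single sequence with exactly one BOS at the start and one EOS at the end, which is why the target half must not add its own BOS.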
Recipe-level wandb recorder writes per-step training metrics (NLL Loss,
Gradient Norm, LR, throughput) into a wandb run with project=
gemma-4-fairseq2. With WANDB_MODE=offline (default in our launch env
since no API key is configured), the run materializes locally under
<output_dir>/wandb/ and can be synced later via 'wandb sync <path>'.
* audio/config.py: add audio_mode field ('conformer' | 'linear').
  'conformer' (default) preserves existing E4B / classic-Gemma 4
  behaviour (mel-spec -> subsample Conv2d -> Conformer tower ->
  embedder). 'linear' is the new Gemma 4 Unified pipeline (raw
  waveform frames of audio_samples_per_token=640 -> embedder
  RMSNorm+Linear). Other Conformer-specific fields are ignored
  in linear mode.
* config.py: add get_gemma4_unified_audio_config() (linear mode,
  output_proj_dims=640) and get_gemma4_12b_audio_config() that
  wraps the existing 12B text config with audio enabled. Register
  archs '12b_audio' and '12b_it_audio'. The original '12b' / '12b_it'
  archs stay text-only (no behaviour change for existing parity /
  MMLU / SFT runs).
* factory.py: create_audio_tower returns None when audio_mode='linear'
  (the existing Gemma4MultimodalAudioEmbedder class is already exactly
  the unified embedder: RMSNorm(no scale) + Linear from 640 to text dim,
  no other changes needed).
* model.py: Gemma4Model.forward now also handles the embedder-only
  case (tower=None, embedder!=None): feeds raw audio_features
  directly through the embedder. The conformer path is unchanged.
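
The three pieces above (config field, factory, forward) can be sketched together as follows. All names here are hypothetical stand-ins rather than the fairseq2 classes, the tower and embedder are toy callables on float lists, and typing `audio_mode` as a `Literal` with runtime validation is one possible design, not necessarily what the PR does.

```python
from dataclasses import dataclass
from typing import Callable, List, Literal, Optional, get_args

AudioMode = Literal["conformer", "linear"]
Features = List[float]


@dataclass
class AudioConfigSketch:
    audio_mode: AudioMode = "conformer"
    audio_samples_per_token: int = 640  # only meaningful in linear mode

    def __post_init__(self) -> None:
        # Literal is not enforced at runtime, so validate explicitly.
        if self.audio_mode not in get_args(AudioMode):
            raise ValueError(f"unknown audio_mode: {self.audio_mode!r}")


def create_audio_tower_sketch(
    config: Optional[AudioConfigSketch],
) -> Optional[Callable[[Features], Features]]:
    if config is None:
        return None  # no audio at all
    if config.audio_mode == "linear":
        return None  # unified family: no tower, the embedder sees raw frames
    return lambda feats: [f + 1.0 for f in feats]  # toy stand-in for the Conformer tower


def embed_audio(features, tower, embedder):
    """Run the tower when present, then the embedder."""
    hidden = features if tower is None else tower(features)
    return embedder(hidden)


double = lambda feats: [2.0 * f for f in feats]  # toy stand-in for the embedder
linear_tower = create_audio_tower_sketch(AudioConfigSketch(audio_mode="linear"))
conformer_tower = create_audio_tower_sketch(AudioConfigSketch())
print(embed_audio([1.0], linear_tower, double))     # tower skipped
print(embed_audio([1.0], conformer_tower, double))  # tower applied first
```

Because the field is always present on the dataclass with a default, a plain attribute access is enough in the factory, and an unsupported mode fails at construction time instead of silently falling through to the Conformer path.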

Sanity check: 12b_audio model has 668 keys; converter on the HF
gemma-4-12B-it safetensors produces exactly 668 keys (667 text +
audio_embedder.embedding_projection.weight) -- zero diff.
…rest of the family

User moved the 12B base safetensors from /checkpoint/fairseq2/shared/models
to /checkpoint/smallomnillm/shared/models so all Gemma 4 variants live
together (gemma-4-E*B, gemma-4-31B[-it], gemma-4-26B-A4B[-it], gemma-4-12B,
gemma-4-12B-it). Source path deleted; asset card now points at the new
location. The -it card was already on smallomnillm.
Mirrors HF's split into two classes (Gemma4MultimodalEmbedder vs
Gemma4UnifiedMultimodalEmbedder), but in a single fairseq2 class
gated by a ctor flag. When cast_input_dtype=True, forward casts
inputs_embeds to self.embedding_projection.weight.dtype before the
RMSNorm -- exactly what HF's Gemma4UnifiedMultimodalEmbedder.forward
does (modular_gemma4_unified.py:895-899). Without the cast, raw
waveform features (fp32 from the feature extractor) fed into a bf16
embedder would silently use mismatched dtypes.

Factory wires cast_input_dtype=True when audio_config.audio_mode
== 'linear' (the Gemma 4 Unified family path). The default flag is
False, so the classic E*/31B/26B-A4B Conformer audio path keeps its
bit-for-bit behaviour.

Exp 49 audio parity (bf16 cos=0.9992) passed before this change
because the parity script pre-cast audio_features to DTYPE
manually; this commit moves the cast into the embedder so any
caller that forgets to pre-cast still produces HF-identical output.
@meta-cla meta-cla Bot added the CLA Signed label Jun 8, 2026
@YunchaoYang YunchaoYang changed the title from "Yy/gemma4 12b unified" to "Add Gemma4 12b unified model support in Fairseq2" Jun 8, 2026
@YunchaoYang YunchaoYang marked this pull request as ready for review June 16, 2026 13:18
name: gemma4_12b
model_family: gemma4
model_arch: 12b
checkpoint: "file:///checkpoint/smallomnillm/shared/models/gemma-4-12B"

Let's put this under ext instead of exposing in the OS. Use hg://* model cards.

return None

# Unified family: no tower; embedder consumes raw waveform frames.
if getattr(config.audio_config, "audio_mode", "conformer") == "linear":

Why getattr? Is audio_mode a proper dataclass field, or is there a BC requirement for this field?

and ``rms_norm_eps`` are read.
"""

audio_mode: str = "conformer"

Do we want to use Literal["conformer", "linear"]?

@YunchaoYang

@claude review this PR please
