diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 0eceb23..4b22bec 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -443,7 +443,8 @@ def _set_state_dict(self, to_mcore: bool, *, offset: float = 0, - is_expert: bool = False): + is_expert: bool = False, + _check_mg_param: bool = True): if '.' in mg_key: module_key, param_key = mg_key.rsplit('.', 1) else: @@ -491,7 +492,11 @@ def _set_state_dict(self, else: mg_param = deep_getattr(sub_module, param_key) if to_mcore: - assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' + if mg_param is None: + if _check_mg_param: + raise ValueError(f'mg_module: {mg_module}, mg_key: {mg_key}') + else: + return hf_weight = hf_state_dict[hf_key].load() if module_key in { 'embedding.word_embeddings', 'output_layer' @@ -1618,13 +1623,16 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict + def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): + lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore): if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + self._set_word_embeddings(mg_model, hf_state_dict, to_mcore) if self.is_multimodal: for prefix, mg_prefix in self.module_mapping.items(): mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}') diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 869f876..f6968a2 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -153,6 +153,8 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: n_shared_experts = res.pop('n_shared_experts') elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: res['rotary_interleaved'] = True + elif hf_model_type in {'gemma4'}: + res['qk_layernorm'] = True elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True res['bias_dropout_fusion'] = False diff --git a/src/mcore_bridge/model/constant.py b/src/mcore_bridge/model/constant.py index 58b0ee3..9b8dc1b 100644 --- a/src/mcore_bridge/model/constant.py +++ b/src/mcore_bridge/model/constant.py @@ -29,6 +29,7 @@ class MLLMModelType: glm4v_moe = 'glm4v_moe' kimi_vl = 'kimi_vl' llama4 = 'llama4' + gemma4 = 'gemma4' kimi_k25 = 'kimi_k25' diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 13117e6..8a5856c 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -105,9 +105,7 @@ def __init__( for i in range(len(self.decoder.layers)): if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'): del self.decoder.layers[i].self_attention.rotary_pos_emb - self.attention_scaling = 1. 
-        new_inv_freq, self.attention_scaling = get_rope_inv_freq(config)
-        self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+        self._set_inv_freq()
         if self.config.task_type == 'seq_cls' and self.post_process:
             self.output_layer = OutputLayerLinear(
                 config.hidden_size,
@@ -217,7 +215,35 @@ def _preprocess(
         if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
             # fix LoRA incompatibility with gradient checkpointing
             decoder_input = decoder_input.requires_grad_(True)
+        rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
+            decoder_input, position_ids, packed_seq_params=packed_seq_params, inference_context=inference_context)
+        if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
+                                   or self.config.flash_decode) and rotary_pos_cos is not None
+                and inference_context.is_static_batching()):
+            current_batch_size = input_ids.shape[0]
+            sequence_len_offset = torch.tensor(
+                [inference_context.sequence_len_offset] * current_batch_size,
+                dtype=torch.int32,
+                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
+            )
+        else:
+            sequence_len_offset = None
+
+        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+        # reference held by this caller function, enabling early garbage collection for
+        # inference. Skip wrapping if decoder_input is logged after decoder completion.
+        if in_inference_mode and not has_config_logger_enabled(self.config):
+            decoder_input = WrappedTensor(decoder_input)
+
+        return (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset)
+
+    def _set_inv_freq(self):
+        self.attention_scaling = 1.
+        new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config)
+        self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+
+    def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
         # Rotary positional embeddings (embedding is None for PP intermediate devices)
         rotary_pos_emb = None
         rotary_pos_cos = None
@@ -252,26 +278,7 @@ def _preprocess(
                 rotary_seq_len,
                 packed_seq=packed_seq,
             )
-
-        if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
-                                   or self.config.flash_decode) and rotary_pos_cos is not None
-                and inference_context.is_static_batching()):
-            current_batch_size = input_ids.shape[0]
-            sequence_len_offset = torch.tensor(
-                [inference_context.sequence_len_offset] * current_batch_size,
-                dtype=torch.int32,
-                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
-            )
-        else:
-            sequence_len_offset = None
-
-        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
-        # reference held by this caller function, enabling early garbage collection for
-        # inference. Skip wrapping if decoder_input is logged after decoder completion.
-        if in_inference_mode and not has_config_logger_enabled(self.config):
-            decoder_input = WrappedTensor(decoder_input)
-
-        return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
+        return rotary_pos_emb, rotary_pos_cos, rotary_pos_sin

     # Code borrowed from NVIDIA/Megatron-LM
     def forward(
@@ -302,7 +309,7 @@ def forward(
         if self.config.position_embedding_type == 'mrope' and position_ids.ndim == 2:  # qwen3_asr
             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
         inference_context = deprecate_inference_params(inference_context, inference_params)
-
+        # There is a difference in whether rotary_pos_emb can be fused between the decoder and MTP.
         decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
             self._preprocess(
                 input_ids=input_ids,
@@ -315,7 +322,11 @@ def forward(
         packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
         if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
             assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
-            decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]
+            if isinstance(rotary_pos_emb, dict):
+                decoder_rotary_pos_emb = {
+                    k: v[position_ids[0]] for k, v in rotary_pos_emb.items()}
+            else:
+                decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

         mtp_decoder_input = decoder_input
         if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
diff --git a/src/mcore_bridge/model/mm_gpts/__init__.py b/src/mcore_bridge/model/mm_gpts/__init__.py
index 9009edb..b862ec6 100644
--- a/src/mcore_bridge/model/mm_gpts/__init__.py
+++ b/src/mcore_bridge/model/mm_gpts/__init__.py
@@ -1,2 +1,2 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
-from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_asr, qwen3_omni, qwen3_vl
+from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_asr, qwen3_omni, qwen3_vl
diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py
new file mode 100644
index 0000000..e9b9ece
--- /dev/null
+++ b/src/mcore_bridge/model/mm_gpts/gemma4.py
@@ -0,0 +1,450 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import copy +import math +import torch +import torch.distributed as dist +from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm, TERowParallelLinear +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.tensor_parallel import VocabParallelEmbedding +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.spec_utils import build_module +from transformers import AutoModel, PretrainedConfig +from typing import Optional + +from mcore_bridge.bridge import MultimodalGPTBridge +from mcore_bridge.config import ModelConfig + +from ..constant import ModelType +from ..gpt_model import GPTModel +from ..mm_gpt_model import MultimodalGPTModel +from ..modules import CustomTransformerBlock, CustomTransformerLayer +from ..register import ModelLoader, ModelMeta, register_model +from ..rope import get_rope_inv_freq +from .utils import HuggingFaceVit + + +class Gemma4VNorm(torch.nn.Module): + """RMSNorm without learnable scale, mirroring HF `Gemma4RMSNorm(with_scale=False)`.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + return (x * torch.rsqrt(variance + self.eps)).to(orig_dtype) + + +class Gemma4Vit(HuggingFaceVit): + module_mapping = { + 'model.vision_tower': 'vision_tower', + 'model.embed_vision': 'embed_vision', + 'model.audio_tower': 'audio_tower', + 'model.embed_audio': 'embed_audio', + } + _vision_tower = ['vision_tower', 'audio_tower'] + _aligner = ['embed_vision', 'embed_audio'] + + def prepare_model(self, hf_config: PretrainedConfig): + from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4MultimodalEmbedder + self.vision_tower = AutoModel.from_config(hf_config.vision_config) + dtype = self.vision_tower.dtype + self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None + self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config).to(dtype) + self.embed_audio = ( + Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype) + if hf_config.audio_config is not None else None) + self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False) + self.model_cls = Gemma4Model + + def get_inputs_embeds(self, inputs_embeds, **kwargs): + input_ids = kwargs.get('input_ids') + inputs_embeds *= self.embed_scale.to(inputs_embeds.dtype) + + hf_config = self.hf_config + input_ids = kwargs.get('input_ids') + pixel_values = kwargs.get('pixel_values') + pixel_values_videos = kwargs.get('pixel_values_videos') + input_features = kwargs.get('input_features') + input_features_mask = kwargs.get('input_features_mask') + image_position_ids = kwargs.get('image_position_ids') + video_position_ids = kwargs.get('video_position_ids') + + image_mask = input_ids == hf_config.image_token_id + video_mask = input_ids == hf_config.video_token_id + audio_mask = input_ids == hf_config.audio_token_id + multimodal_mask = image_mask | video_mask | audio_mask + llm_input_ids = input_ids.clone() + llm_input_ids[multimodal_mask] = 
hf_config.text_config.pad_token_id + + if pixel_values is not None: + with self.patch_hf_config(): + image_features = self.model_cls.get_image_features( + self, pixel_values, image_position_ids, return_dict=True).pooler_output + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask_e = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(image_mask_e, image_features) + + if pixel_values_videos is not None: + with self.patch_hf_config(): + video_features = self.get_video_features( + pixel_values_videos, video_position_ids, return_dict=True).pooler_output + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features) + + if (input_features is not None and input_features_mask is not None and self.audio_tower is not None): + with self.patch_hf_config(): + audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True) + audio_features = audio_output.pooler_output + audio_features = audio_features[audio_output.attention_mask] + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features) + return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids} + + +class Gemma4SelfAttention(SelfAttention): + + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + *args, + **kwargs, + ): + text_config = config.hf_config.text_config + layer_idx = layer_number - 1 + + # Layer type / sliding attention + self.layer_type = text_config.layer_types[layer_idx] + self.is_sliding = self.layer_type == 'sliding_attention' + self.sliding_window = text_config.sliding_window if self.is_sliding else None + + # Head dim: global layers may use a different head dim than sliding ones + self.head_dim = ( + text_config.global_head_dim + if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) + + # Alternative attention (k == v) for global layers when `attention_k_eq_v` is set + self.use_alternative_attention = (getattr(text_config, 'attention_k_eq_v', False) and not self.is_sliding) + num_key_value_heads = ( + text_config.num_global_key_value_heads + if self.use_alternative_attention else text_config.num_key_value_heads) + self.num_key_value_groups = text_config.num_attention_heads // num_key_value_heads + + # Shared KV across the trailing layers + num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) + first_kv_shared_layer_idx = text_config.num_hidden_layers - num_kv_shared_layers + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] + if self.is_kv_shared_layer: + # For shared layers, reuse KV from the last non-shared layer of the same type + self.kv_shared_layer_index = (len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) + self.store_full_length_kv = False + else: + self.kv_shared_layer_index = None + # Non-shared layers that are the last of their type in `prev_layers` must keep full KV + self.store_full_length_kv = ( + self.layer_type in prev_layers + and layer_idx == len(prev_layers) - 1 - 
prev_layers[::-1].index(self.layer_type)) + + orig_kv_channels = config.kv_channels + orig_num_query_groups = config.num_query_groups + orig_k_layernorm = submodules.k_layernorm + config.kv_channels = self.head_dim + config.num_query_groups = num_key_value_heads + if self.is_kv_shared_layer: + submodules.k_layernorm = IdentityOp + try: + super().__init__(config, submodules, layer_number, *args, **kwargs) + finally: + config.kv_channels = orig_kv_channels + config.num_query_groups = orig_num_query_groups + submodules.k_layernorm = orig_k_layernorm + + if self.is_kv_shared_layer: + self.linear_qkv_out_dim = self.query_projection_size + self.linear_qkv = submodules.linear_qkv( + self.config.hidden_size, + self.linear_qkv_out_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + tp_group=self.pg_collection.tp, + ) + + self.v_norm = ( + Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) if not self.is_kv_shared_layer else None) + + +class Gemma4MLP(MLP): + + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + *args, + **kwargs, + ): + self.layer_number = layer_number + text_config = config.hf_config.text_config + self.enable_moe_block = text_config.enable_moe_block + first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers + is_kv_shared_layer = layer_number > first_kv_shared_layer_idx > 0 + use_double_wide_mlp = text_config.use_double_wide_mlp and is_kv_shared_layer + ffn_hidden_size = config.ffn_hidden_size + config.ffn_hidden_size = config.ffn_hidden_size * (2 if use_double_wide_mlp else 1) + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + config.ffn_hidden_size = ffn_hidden_size + + +class Gemma4Bridge(MultimodalGPTBridge): + hf_post_attention_layernorm = 'pre_feedforward_layernorm' + additional_dim0_keys = {'per_layer_input_gate', 'per_layer_model_projection'} + additional_dim1_keys = {'per_layer_projection'} + + def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): + self._set_state_dict( + mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore, _check_mg_param=False) + self._set_state_dict( + mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore, _check_mg_param=False) + + def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): + is_kv_shared_layer = False if mg_attn is None else mg_attn.is_kv_shared_layer + is_kv_shared_layer = torch.tensor([is_kv_shared_layer], dtype=torch.bool, device='cuda') + if self.pp_size > 1: + dist.all_reduce(is_kv_shared_layer, group=self.pp_group, op=dist.ReduceOp.MAX) + is_kv_shared_layer = is_kv_shared_layer.item() + if is_kv_shared_layer: + self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore) + return hf_state_dict + else: + return super()._set_qkv(mg_attn, hf_state_dict, to_mcore) + + def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' 
+ if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)) + for key in [ + 'post_attention_layernorm', 'post_feedforward_layernorm', 'per_layer_input_gate', + 'per_layer_projection', 'post_per_layer_input_norm' + ]: + self._set_state_dict( + mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore, _check_mg_param=False) + if to_mcore: + hf_state_dict = {} + else: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): + lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + for key in ['embed_tokens_per_layer', 'per_layer_model_projection', 'per_layer_projection_norm']: + self._set_state_dict(lm_model, f'{key}.weight', hf_state_dict, f'model.language_model.{key}.weight', + to_mcore) + + +class Gemma4TextGPTModel(GPTModel): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + text_config = self.config.hf_config.text_config + self.text_config = text_config + self.unique_layer_types = set(text_config.layer_types) + + self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) + if self.hidden_size_per_layer_input and self.pre_process: + num_layers = text_config.num_hidden_layers + hidden_size = text_config.hidden_size + total_dim = num_layers * self.hidden_size_per_layer_input + tp_size = self.config.tensor_model_parallel_size + padded_vocab_size_per_layer = math.ceil(text_config.vocab_size_per_layer_input / tp_size) * tp_size + self.embed_tokens_per_layer = VocabParallelEmbedding( + num_embeddings=padded_vocab_size_per_layer, + embedding_dim=total_dim, + init_method=self.config.init_method, + config=self.config, + tp_group=self.pg_collection.tp, + ) + self.embed_tokens_per_layer_scale = self.hidden_size_per_layer_input**0.5 + self.per_layer_input_scale = 2.0**-0.5 + self.per_layer_model_projection = build_module( + TEColumnParallelLinear, + hidden_size, + total_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_model_projection', + tp_group=self.pg_collection.tp, + ) + self.per_layer_model_projection_scale = hidden_size**-0.5 + self.per_layer_projection_norm = build_module( + TENorm, + hidden_size=self.hidden_size_per_layer_input, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + + def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None): + rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, decoder_input, + self.config, packed_seq_params) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + full_rotary_pos_emb = self.full_rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + rotary_pos_emb = {'sliding_attention': rotary_pos_emb, 'full_attention': full_rotary_pos_emb} + return rotary_pos_emb, None, None + + def _set_inv_freq(self): + rope_scaling = self.config.rope_scaling + self.config.rope_scaling = 
rope_scaling['sliding_attention'] + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config) + assert attention_scaling == 1, 'not support' + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + # full + self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb) + self.config.rope_scaling = rope_scaling['full_attention'] + kwargs = {} + if self.config.rope_scaling['rope_type'] == 'proportional': + kwargs['head_dim_key'] = 'global_head_dim' + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs) + assert attention_scaling == 1, 'not support' + self.full_rotary_pos_emb.inv_freq = new_inv_freq + self.attention_scaling = attention_scaling + + self.config.rope_scaling = rope_scaling + + def forward(self, *args, **kwargs): + extra_block_kwargs = kwargs.pop('extra_block_kwargs', None) or {} + llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) + if self.hidden_size_per_layer_input and self.pre_process: + per_layer_inputs = (self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale).reshape( + *llm_input_ids.shape, + self.text_config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + extra_block_kwargs['per_layer_inputs'] = per_layer_inputs + kwargs['extra_block_kwargs'] = extra_block_kwargs + return super().forward(*args, **kwargs) + + +class Gemma4TransformerLayer(CustomTransformerLayer): + + def __init__(self, config, submodules, *args, **kwargs): + super().__init__(config, submodules, *args, **kwargs) + text_config = config.hf_config.text_config + hidden_size = self.config.hidden_size + eps = self.config.layernorm_epsilon + + self.post_attention_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.post_feedforward_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + self.register_buffer('layer_scalar', torch.ones(1)) + + self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) + if self.hidden_size_per_layer_input: + from transformers.activations import ACT2FN + self.act_fn = ACT2FN[text_config.hidden_activation] + self.per_layer_input_gate = build_module( + TEColumnParallelLinear, + hidden_size, + self.hidden_size_per_layer_input, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_input_gate', + tp_group=self.pg_collection.tp, + ) + self.per_layer_projection = build_module( + TERowParallelLinear, + self.hidden_size_per_layer_input, + hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=False, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_projection', + tp_group=self.pg_collection.tp, + ) + self.post_per_layer_input_norm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + self.enable_moe_block = getattr(text_config, 'enable_moe_block', False) + if self.enable_moe_block: + self.post_feedforward_layernorm_1 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.post_feedforward_layernorm_2 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.pre_feedforward_layernorm_2 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + def forward(self, *args, **kwargs): + per_layer_input = kwargs.pop('per_layer_input', None) + output, context = 
super().forward(*args, **kwargs)
+        return output, context
+
+
+class Gemma4GPTModel(MultimodalGPTModel):
+    language_model_cls = Gemma4TextGPTModel
+
+
+class Gemma4TransformerBlock(CustomTransformerBlock):
+
+    def _layer_forward(self, layer, hidden_states, **kwargs):
+        layer_number = layer.layer_number - 1
+        per_layer_inputs = kwargs.pop('per_layer_inputs', None)
+        kwargs['per_layer_input'] = None if per_layer_inputs is None else per_layer_inputs[:, :, layer_number]
+        return super()._layer_forward(layer, hidden_states, **kwargs)
+
+
+class Gemma4Loader(ModelLoader):
+    model_cls = Gemma4GPTModel
+    transformer_block = Gemma4TransformerBlock
+
+    def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
+        layer_specs = get_gpt_decoder_block_spec(
+            self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage)
+        for layer_spec in layer_specs.layer_specs:
+            layer_spec.submodules.self_attention.module = Gemma4SelfAttention
+            layer_spec.submodules.mlp.module = Gemma4MLP
+        return layer_specs
+
+    def _set_custom_layer(self, transformer_layer_spec):
+        for layer_spec in transformer_layer_spec.layer_specs:
+            layer_spec.module = Gemma4TransformerLayer
+
+
+register_model(
+    ModelMeta(
+        ModelType.gemma4,
+        ['gemma4'],
+        bridge_cls=Gemma4Bridge,
+        visual_cls=Gemma4Vit,
+        loader=Gemma4Loader,
+    ))
diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py
index 21b8cfa..2520bd8 100644
--- a/src/mcore_bridge/model/modules/transformer_layer.py
+++ b/src/mcore_bridge/model/modules/transformer_layer.py
@@ -216,6 +216,7 @@ def _build_mlp(self, mlp_spec):
         additional_mlp_kwargs = {}
         # import here to avoid circular import
         from mcore_bridge.model.gpts.glm4 import Glm4MLP
+        from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP

         # MLP expects tp_group but MoELayer expects pg_collection to be passed in.
         # We can change MLP to accept pg_collection but it makes the logic implicit
@@ -230,6 +231,8 @@ def _build_mlp(self, mlp_spec):
         elif mlp_spec.module in (MLP, Glm4MLP):
             assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer'
             additional_mlp_kwargs['tp_group'] = pg_collection.tp
+        elif mlp_spec.module == Gemma4MLP:
+            additional_mlp_kwargs['layer_number'] = self.layer_number
         elif TEFusedMLP is not None and mlp_spec.module == TEFusedMLP:
             assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer'
             additional_mlp_kwargs['tp_group'] = pg_collection.tp
diff --git a/tests/test_mllm.py b/tests/test_mllm.py
index 1fe103e..5eb07d2 100644
--- a/tests/test_mllm.py
+++ b/tests/test_mllm.py
@@ -116,6 +116,10 @@ def test_qwen3_asr():
     _test_model('Qwen/Qwen3-ASR-1.7B')


+def test_gemma4():
+    _test_model('google/gemma-4-E2B-it')
+
+
 if __name__ == '__main__':
     # test_qwen2_5_vl()
     # test_qwen2_vl()
@@ -136,4 +140,5 @@ def test_qwen3_asr():
     # test_llama4()
     # test_qwen3_5()
     # test_llava_onevision1_5()
-    test_qwen3_asr()
+    # test_qwen3_asr()
+    test_gemma4()
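
The dict-valued rotary embedding is the one structural change this PR makes to the shared GPTModel.forward: Gemma4TextGPTModel._get_rotary_pos_emb returns separate embeddings for 'sliding_attention' and 'full_attention' layers, and the packed-sequence path has to re-index whichever form it receives. A minimal, standalone sketch of that re-indexing with made-up shapes (the real tensors come from Megatron's RotaryEmbedding):

```python
import torch

# Illustrative shapes only; Megatron's RotaryEmbedding produces [seq, 1, 1, dim] tensors.
seq_len, head_dim = 16, 8
rotary_pos_emb = {
    'sliding_attention': torch.randn(seq_len, 1, 1, head_dim),
    'full_attention': torch.randn(seq_len, 1, 1, head_dim),
}
# thd-packed sequences: position ids restart at 0 for every packed sub-sequence.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]])

if isinstance(rotary_pos_emb, dict):
    # Build a new dict so the original (still needed later, e.g. by MTP) is left untouched.
    decoder_rotary_pos_emb = {k: v[position_ids[0]] for k, v in rotary_pos_emb.items()}
else:
    decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

assert decoder_rotary_pos_emb['sliding_attention'].shape[0] == position_ids.shape[1]
```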
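
Gemma4SelfAttention's trickiest bookkeeping is KV sharing: the trailing num_kv_shared_layers layers build only a query projection and reuse K/V from the last non-shared layer of the same attention type, which in turn must keep its full-length KV. A standalone sketch of that index arithmetic, using an invented layer_types list rather than a real gemma-4 config:

```python
# Illustrative config values; a real text_config supplies layer_types and num_kv_shared_layers.
layer_types = ['sliding_attention', 'sliding_attention', 'full_attention',
               'sliding_attention', 'sliding_attention', 'full_attention']
num_kv_shared_layers = 2
first_kv_shared_layer_idx = len(layer_types) - num_kv_shared_layers
prev_layers = layer_types[:first_kv_shared_layer_idx]

for layer_idx, layer_type in enumerate(layer_types):
    is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
    if is_kv_shared_layer:
        # Shared layers point back at the last non-shared layer of the same type.
        kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(layer_type)
        store_full_length_kv = False
    else:
        kv_shared_layer_index = None
        # The last non-shared layer of each type keeps its full-length KV
        # so the trailing shared layers can read it.
        store_full_length_kv = (
            layer_type in prev_layers
            and layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(layer_type))
    print(layer_idx, layer_type, is_kv_shared_layer, kv_shared_layer_index, store_full_length_kv)
```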
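
The per-layer input path added for Gemma4 (Gemma4TextGPTModel.forward feeding Gemma4TransformerBlock._layer_forward) does one flat embedding lookup for all decoder layers, scales it by sqrt(hidden_size_per_layer_input), and reshapes so that dim 2 indexes the layer; each layer then consumes only its own slice. A shape-only sketch with invented sizes:

```python
import torch

# Invented sizes for illustration; the real values come from the text_config.
batch, seq = 2, 8
num_hidden_layers = 4
hidden_size_per_layer_input = 16
vocab_size_per_layer_input = 100

embed_tokens_per_layer = torch.nn.Embedding(
    vocab_size_per_layer_input, num_hidden_layers * hidden_size_per_layer_input)
llm_input_ids = torch.randint(0, vocab_size_per_layer_input, (batch, seq))

# One lookup produces the inputs for every layer, then dim 2 becomes the layer index.
per_layer_inputs = (embed_tokens_per_layer(llm_input_ids)
                    * hidden_size_per_layer_input**0.5).reshape(
                        batch, seq, num_hidden_layers, hidden_size_per_layer_input)

# The transformer block hands each layer only its slice.
for layer_number in range(num_hidden_layers):
    per_layer_input = per_layer_inputs[:, :, layer_number]
    assert per_layer_input.shape == (batch, seq, hidden_size_per_layer_input)
```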
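
Both Gemma4SelfAttention (kv_channels, num_query_groups) and Gemma4MLP (ffn_hidden_size) inject per-layer shapes by temporarily mutating the shared config around super().__init__ and restoring it in a finally block, so layers built later from the same config object are unaffected. A toy sketch of the pattern; DummyConfig and DummyMLP are stand-ins, not classes from the codebase:

```python
from dataclasses import dataclass


@dataclass
class DummyConfig:
    ffn_hidden_size: int = 1024


class DummyMLP:
    def __init__(self, config):
        # The base class reads whatever width the config holds at construction time.
        self.width = config.ffn_hidden_size


class DoubleWideMLP(DummyMLP):
    def __init__(self, config, use_double_wide_mlp: bool):
        ffn_hidden_size = config.ffn_hidden_size
        config.ffn_hidden_size = ffn_hidden_size * (2 if use_double_wide_mlp else 1)
        try:
            super().__init__(config)
        finally:
            # Restore so subsequently built layers see the original value.
            config.ffn_hidden_size = ffn_hidden_size


cfg = DummyConfig()
wide = DoubleWideMLP(cfg, use_double_wide_mlp=True)
assert wide.width == 2048 and cfg.ffn_hidden_size == 1024
```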