diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index c380328..34961d0 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -40,9 +40,12 @@ class GPTBridge: hf_o_proj_key = 'o_proj' hf_attn_prefix = 'self_attn' hf_mlp_prefix = 'mlp' + hf_post_attention_layernorm = 'post_attention_layernorm' hf_gate_key = 'gate.weight' hf_shared_expert_key = None hf_expert_bias_key = 'gate.e_score_correction_bias' + additional_dim0_keys = set() + additional_dim1_keys = set() def __init__(self, config: ModelConfig): self.config = config @@ -124,11 +127,11 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: 'linear_kv_up_proj', # mtp 'eh_proj', - } + } | self.additional_dim0_keys if self.config.task_type in {'causal_lm', 'generative_reranker'}: dim0_keys.add('output_layer') # RowLinear - dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'} + dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'} | self.additional_dim1_keys if 'lora_A' not in mg_key and 'lora_B' not in mg_key: key, suffix = mg_key.rsplit('.', 2)[-2:] if suffix == 'layer_norm_weight': @@ -1587,13 +1590,13 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool hf_state_dict.update( self._set_moe_state( mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp)) - self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', - to_mcore) + self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, + f'{self.hf_post_attention_layernorm}.weight', to_mcore) else: hf_state_dict.update( self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, - 'post_attention_layernorm.weight', to_mcore) + f'{self.hf_post_attention_layernorm}.weight', to_mcore) return hf_state_dict def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py index c03f803..7830215 100644 --- a/src/mcore_bridge/model/gpts/minimax_m2.py +++ b/src/mcore_bridge/model/gpts/minimax_m2.py @@ -27,9 +27,11 @@ def __init__( k_layernorm = submodules.k_layernorm submodules.q_layernorm = IdentityOp submodules.k_layernorm = IdentityOp - super().__init__(config, submodules, *args, **kwargs) - submodules.q_layernorm = q_layernorm - submodules.k_layernorm = k_layernorm + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + submodules.q_layernorm = q_layernorm + submodules.k_layernorm = k_layernorm self.q_norm = build_module( submodules.q_layernorm, hidden_size=self.hidden_size_per_attention_head * config.num_attention_heads, diff --git a/src/mcore_bridge/model/gpts/qwen3_next.py b/src/mcore_bridge/model/gpts/qwen3_next.py index 316a7f0..1bf4b2d 100644 --- a/src/mcore_bridge/model/gpts/qwen3_next.py +++ b/src/mcore_bridge/model/gpts/qwen3_next.py @@ -517,8 +517,8 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool mg_mlp = None if mg_layer is None else mg_layer.mlp hf_state_dict.update( self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp)) - self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', - to_mcore) + self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', 
hf_state_dict, + f'{self.hf_post_attention_layernorm}.weight', to_mcore) return hf_state_dict def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_vl.py b/src/mcore_bridge/model/mm_gpts/qwen3_vl.py index 92a90d8..26dfb85 100644 --- a/src/mcore_bridge/model/mm_gpts/qwen3_vl.py +++ b/src/mcore_bridge/model/mm_gpts/qwen3_vl.py @@ -1,299 +1,38 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import torch -from contextlib import nullcontext -from megatron.core import parallel_state, tensor_parallel -from megatron.core.enums import Fp8Recipe -from megatron.core.fp8_utils import get_fp8_context -from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.models.gpt import gpt_model -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor -from typing import List, Optional, Union +from megatron.core import parallel_state from mcore_bridge.bridge import MultimodalGPTBridge from mcore_bridge.utils import split_cp_inputs from ..constant import ModelType +from ..modules import CustomTransformerBlock from ..register import ModelLoader, ModelMeta, register_model from .utils import HuggingFaceVit -te_checkpoint = None -try: - import transformer_engine.pytorch as te # pylint: disable=unused-import +class Qwen3VLTransformerBlock(CustomTransformerBlock): - HAVE_TE = True -except ImportError: - HAVE_TE = False + def _layer_forward(self, layer, hidden_states, **kwargs): + deepstack_visual_embeds = kwargs.pop('deepstack_visual_embeds', None) + visual_pos_masks = kwargs.pop('visual_pos_masks', None) + hidden_states, context = super()._layer_forward(layer, hidden_states, **kwargs) + layer_number = layer.layer_number - 1 + if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_number], + ) + return hidden_states, context -if HAVE_TE: - from megatron.core.extensions.transformer_engine import te_checkpoint - - -class Qwen3VLTransformerBlock(gpt_model.TransformerBlock): - # Code borrowed from NVIDIA/Megatron-LM - - def _checkpointed_forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - context: torch.Tensor, - context_mask: torch.Tensor, - rotary_pos_emb: torch.Tensor, - attention_bias: torch.Tensor, - packed_seq_params: PackedSeqParams, - use_inner_fp8_context: bool, - # args for deepstack - visual_pos_masks: Optional[torch.Tensor] = None, - deepstack_visual_embeds: Optional[List[torch.Tensor]] = None, - ): - """Forward method with activation checkpointing.""" - - def custom(start: int, end: int): - - def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb, visual_pos_masks, - deepstack_visual_embeds): - for index in range(start, end): - layer = self._get_layer(index) - inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - - 1) if use_inner_fp8_context else nullcontext()) - with inner_fp8_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - inference_context=None, - packed_seq_params=packed_seq_params, - ) - # add visual features to the hidden states of first several layers - layer_number = 
layer.layer_number - 1 - if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): - hidden_states = self._deepstack_process( - hidden_states, - visual_pos_masks, - deepstack_visual_embeds[layer_number], - ) - return hidden_states, context - - return custom_forward - - def checkpoint_handler(forward_func): - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - if self.config.fp8: - return te_checkpoint( - forward_func, - self.config.distribute_saved_activations, - tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - visual_pos_masks, - deepstack_visual_embeds, - ) - else: - return tensor_parallel.checkpoint( - forward_func, - self.config.distribute_saved_activations, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - visual_pos_masks, - deepstack_visual_embeds, - ) - - if self.config.recompute_method == 'uniform': - # Uniformly divide the total number of Transformer layers and checkpoint - # the input activation of each divided chunk. - # A method to further reduce memory usage reducing checkpoints. - layer_idx = 0 - while layer_idx < self.num_layers_per_pipeline_rank: - hidden_states, context = checkpoint_handler( - custom(layer_idx, layer_idx + self.config.recompute_num_layers)) - - layer_idx += self.config.recompute_num_layers - - elif self.config.recompute_method == 'block': - # Checkpoint the input activation of only a set number of individual - # Transformer layers and skip the rest. - # A method fully use the device memory removing redundant re-computation. - recompute_skip_num_layers = 0 - for layer_idx in range(self.num_layers_per_pipeline_rank): - # Skip recomputation when input grad computation is not needed. - # Need to have at least one input tensor with gradient computation - # for re-enterant autograd engine. - if self.config.fp8 and not hidden_states.requires_grad: - recompute_skip_num_layers += 1 - if (layer_idx >= recompute_skip_num_layers - and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers): - hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) - else: - hidden_states, context = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, context, - context_mask, rotary_pos_emb, - visual_pos_masks, deepstack_visual_embeds) - else: - raise ValueError('Invalid activation recompute method.') - - return hidden_states - - def forward( - self, - hidden_states: Union[torch.Tensor, WrappedTensor], - attention_mask: Optional[torch.Tensor], - context: Optional[torch.Tensor] = None, - context_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - rotary_pos_cos: Optional[torch.Tensor] = None, - rotary_pos_sin: Optional[torch.Tensor] = None, - attention_bias: Optional[torch.Tensor] = None, - inference_context: Optional[BaseInferenceContext] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - sequence_len_offset: Optional[torch.Tensor] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - # args for deepstack - visual_pos_masks: Optional[torch.Tensor] = None, - deepstack_visual_embeds: Optional[List[torch.Tensor]] = None, - ): - """ - Perform the forward pass through the transformer block. - This method handles the core computation of the transformer, including - self-attention, optional cross-attention, and feed-forward operations. 
- Args: - hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] - where s is the sequence length, b is the batch size, and h is the hidden size. - Can be passed as a WrappedTensor during inference to avoid an obsolete - reference in the calling function. - attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking - self-attention. - context (Tensor, optional): Context tensor for cross-attention. - context_mask (Tensor, optional): Mask for cross-attention context - rotary_pos_emb (Tensor, optional): Rotary positional embeddings. - attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable - to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. - Used as an alternative to apply attention mask for TE cuDNN attention. - inference_context (BaseInferenceContext, optional): Parameters for inference-time - optimizations. - packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence - processing. - Returns: - Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape - [s, b, h], and optionally the updated context tensor if cross-attention is used. - """ + def forward(self, *args, **kwargs): + deepstack_visual_embeds = kwargs.get('deepstack_visual_embeds') if deepstack_visual_embeds is not None: assert len(deepstack_visual_embeds) <= len( self.layers), (f'len(deepstack_visual_embeds): {len(deepstack_visual_embeds)}, ' f'len(self.layers): {len(self.layers)}.') - inference_context = deprecate_inference_params(inference_context, inference_params) - - # Delete the obsolete reference to the initial input tensor if necessary - if isinstance(hidden_states, WrappedTensor): - hidden_states = hidden_states.unwrap() - - if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Viewless tensor. - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. - # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - if self.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), - # otherwise do nothing extra at the outer level - # if we are using other fp8 recipes, then the context manager enter&exit are free - # we can wrap fp8_context within the for loop over layers, so that we can fine-grained - # control which layer will be fp8 or bf16 - use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed - use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed - outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() - - with rng_context, outer_fp8_context: - # Forward pass. 
- if self.config.recompute_granularity == 'full' and self.training: - hidden_states = self._checkpointed_forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - use_inner_fp8_context=use_inner_fp8_context, - visual_pos_masks=visual_pos_masks, - deepstack_visual_embeds=deepstack_visual_embeds, - ) - else: - for l_no, layer in enumerate(self.layers): - inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - - 1) if use_inner_fp8_context else nullcontext()) - with self.offload_context, inner_fp8_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - inference_context=inference_context, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ) - # add visual features to the hidden states of first several layers - layer_number = layer.layer_number - 1 - if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): - hidden_states = self._deepstack_process( - hidden_states, - visual_pos_masks, - deepstack_visual_embeds[layer_number], - ) - - if (torch.is_grad_enabled() and self.config.cpu_offloading - and self.group_prefetch_offload_commit_async is not None): - hidden_states = self.group_prefetch_offload_commit_async(hidden_states) - - # Final layer norm. - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - # TENorm produces a "viewed" tensor. This will result in schedule.py's - # deallocate_output_tensor() throwing an error, so a viewless tensor is - # created to prevent this. - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - # If this TransformerBlock is empty, input and output hidden states will be the same node - # on the computational graph and will lead to unexpected errors in pipeline schedules. 
- if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: - hidden_states = hidden_states.clone() - - return hidden_states + return super().forward(*args, **kwargs) def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor): @@ -427,16 +166,7 @@ def _get_inputs_embeds(self, inputs_embeds, inputs, visual, hf_config): class Qwen3VLLoader(ModelLoader): - - def _patch_transformer_block(self): - if hasattr(gpt_model, 'OriginTransformerBlock'): - return - gpt_model.OriginTransformerBlock = gpt_model.TransformerBlock - gpt_model.TransformerBlock = Qwen3VLTransformerBlock - - def __init__(self, config): - super().__init__(config) - self._patch_transformer_block() + transformer_block = Qwen3VLTransformerBlock register_model( diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py index eff1bd6..885b05b 100644 --- a/src/mcore_bridge/model/modules/__init__.py +++ b/src/mcore_bridge/model/modules/__init__.py @@ -2,4 +2,5 @@ from .gated_delta_net import GatedDeltaNet from .gated_self_attention import GatedSelfAttention from .mtp_layer import MultiTokenPredictionLayer +from .transformer_block import CustomTransformerBlock from .transformer_layer import CustomTransformerLayer diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index ef150b7..5fcb167 100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -96,10 +96,13 @@ def __init__(self, config: ModelConfig, submodules: 'GatedDeltaNetSubmodules', * submodules.in_proj = IdentityOp if 'cp_comm_type' not in inspect.signature(_GatedDeltaNet).parameters: kwargs.pop('cp_comm_type', None) - super().__init__(config, submodules, *args, **kwargs) + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + if config.linear_decoupled_in_proj: + submodules.in_proj = in_proj if not config.linear_decoupled_in_proj: return - submodules.in_proj = in_proj self.in_proj_qkvz_dim = self.qk_dim * 2 + self.v_dim * 2 self.in_proj_ba_dim = self.num_value_heads * 2 del self.in_proj diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py index 537abf3..5398b71 100644 --- a/src/mcore_bridge/model/modules/mtp_layer.py +++ b/src/mcore_bridge/model/modules/mtp_layer.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import torch import transformer_engine from contextlib import nullcontext @@ -29,11 +30,14 @@ def __init__(self, config: ModelConfig, submodules, *args, **kwargs): if config.fp8_param: eh_proj = submodules.eh_proj submodules.eh_proj = IdentityOp - super().__init__(config, submodules, *args, **kwargs) + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + if config.fp8_param: + submodules.eh_proj = eh_proj self.tp_group = getattr(self, 'tp_group', None) if not config.fp8_param: return - submodules.eh_proj = eh_proj fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False) with fp8_context: self.eh_proj = build_module( diff --git a/src/mcore_bridge/model/modules/transformer_block.py b/src/mcore_bridge/model/modules/transformer_block.py new file mode 100644 index 0000000..0f7f735 --- /dev/null +++ b/src/mcore_bridge/model/modules/transformer_block.py @@ -0,0 +1,402 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
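# --- Illustrative sketch (editorial, not part of the patch): the swap-and-restore pattern
# introduced above for GatedDeltaNet, MultiTokenPredictionLayer and the MiniMax-M2 attention
# wraps `super().__init__` in try/finally so a temporarily nulled-out submodule spec is
# restored even when the parent constructor raises. All names below (_Spec, _Parent, _Child)
# are placeholders for this sketch only, not identifiers from this repo.
class _Spec:
    def __init__(self):
        self.proj = 'real_proj_spec'

class _Parent:
    def __init__(self, spec):
        assert spec.proj is None  # parent must not build `proj` itself

class _Child(_Parent):
    def __init__(self, spec):
        proj, spec.proj = spec.proj, None  # hide the spec from the parent constructor
        try:
            super().__init__(spec)
        finally:
            spec.proj = proj  # restored even on failure, so the shared spec object
                              # is never left mutated for later layers

spec = _Spec()
_Child(spec)  # afterwards spec.proj is still 'real_proj_spec'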
+import torch +from contextlib import nullcontext +from megatron.core import tensor_parallel +from megatron.core.enums import Fp8Recipe +from megatron.core.extensions.transformer_engine import te_checkpoint +from megatron.core.fp4_utils import get_fp4_context +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from megatron.core.utils import WrappedTensor, deprecate_inference_params, get_pg_rank, make_viewless_tensor +from typing import List, Optional, Set, Union, cast + +try: + from megatron.core.typed_torch import apply_module +except ImportError: + apply_module = None + + +# Code borrowed from NVIDIA/Megatron-LM +class CustomTransformerBlock(TransformerBlock): + + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + context: torch.Tensor, + context_mask: torch.Tensor, + rotary_pos_emb: torch.Tensor, + attention_bias: torch.Tensor, + packed_seq_params: PackedSeqParams, + use_inner_quantization_context: bool, + padding_mask: Optional[torch.Tensor] = None, + extract_layer_indices: Optional[Set[int]] = None, + layer_offset: int = 0, + **kwargs, + ): + """Forward method with activation checkpointing. + + Args: + extract_layer_indices (Set[int], optional): Global layer + indices (across all pipeline stages) from which to + extract features. + layer_offset (int): The global layer offset for the current + pipeline stage. Used to convert local layer indices to + global indices when checking extract_layer_indices. + + Returns: + If extract_layer_indices is empty: hidden_states tensor + If extract_layer_indices is non-empty: (hidden_states, intermediate_hidden_states) tuple + """ + if extract_layer_indices is None: + extract_layer_indices = set() + intermediate_hidden_states: List[torch.Tensor] = [] + + def custom(start: int, end: int): + + def custom_forward( + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + padding_mask=None, + **kwargs, + ): + for index in range(start, end): + layer = self._get_layer(index) + + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context(self.config, layer.layer_number - 1) + # TODO: check if fp4 is supported in this case + elif self.config.fp4: + inner_quantization_context = get_fp4_context(self.config, layer.layer_number - 1) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with inner_quantization_context: + hidden_states, context = self._layer_forward( + layer, + hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + **kwargs, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + # TODO: check if fp4 is supported in this case + if self.config.fp8 or self.config.fp4: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + 
self.pg_collection.tp,
+                    hidden_states,
+                    attention_mask,
+                    context,
+                    context_mask,
+                    rotary_pos_emb,
+                    padding_mask,
+                    **kwargs,
+                )
+            else:
+                return tensor_parallel.checkpoint(
+                    forward_func,
+                    self.config.distribute_saved_activations,
+                    hidden_states,
+                    attention_mask,
+                    context,
+                    context_mask,
+                    rotary_pos_emb,
+                    padding_mask,
+                    **kwargs,
+                )
+
+        if self.config.recompute_method == 'uniform':
+            # Uniformly divide the total number of Transformer layers and checkpoint
+            # the input activation of each divided chunk.
+            # A method to further reduce memory usage by reducing the number of checkpoints.
+            layer_idx = 0
+            while layer_idx < self.num_layers_per_pipeline_rank:
+                chunk_end = min(layer_idx + self.config.recompute_num_layers, self.num_layers_per_pipeline_rank)
+                hidden_states, context = checkpoint_handler(custom(layer_idx, chunk_end))
+
+                # Feature extraction for uniform recompute: collect at the end of each chunk.
+                # Note: only the last layer of each chunk can have features collected.
+                for idx in range(layer_idx, chunk_end):
+                    if (idx + layer_offset) in extract_layer_indices:
+                        # For uniform recompute, we can only get features at chunk boundaries.
+                        # Limitation: for fine-grained extraction, use 'block'.
+                        if idx == chunk_end - 1:
+                            intermediate_hidden_states.append(hidden_states)
+
+                layer_idx += self.config.recompute_num_layers
+
+        elif self.config.recompute_method == 'block':
+            # Checkpoint the input activation of only a set number of individual
+            # Transformer layers and skip the rest.
+            # A method that fully uses the device memory by removing redundant re-computation.
+            recompute_skip_num_layers = 0
+            for layer_idx in range(self.num_layers_per_pipeline_rank):
+                # Skip recomputation when input grad computation is not needed.
+                # Need to have at least one input tensor with gradient computation
+                # for the re-entrant autograd engine.
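                # Worked example (illustrative; the values are assumed, not from the patch):
                # with num_layers_per_pipeline_rank=8, recompute_num_layers=2 and no
                # fp8/fp4-induced skipping, only local layers 0 and 1 are run under
                # checkpoint_handler; layers 2-7 run without recomputation. Every leading
                # fp8/fp4 layer whose input does not require grad bumps
                # recompute_skip_num_layers, which shifts the checkpointed window to start
                # after the skipped layers.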
+ # TODO: check if fp4 is supported in this case + if (self.config.fp8 or self.config.fp4) and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if (layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers): + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, context, + context_mask, rotary_pos_emb, **kwargs) + + # Feature extraction: collect hidden states at specified global layer indices + if (layer_idx + layer_offset) in extract_layer_indices: + intermediate_hidden_states.append(hidden_states) + else: + raise ValueError('Invalid activation recompute method.') + + # Return intermediate hidden states if feature extraction was requested + if len(extract_layer_indices) > 0: + return hidden_states, intermediate_hidden_states + + return hidden_states + + def _layer_forward(self, layer, hidden_states, **kwargs): + return layer(hidden_states=hidden_states, **kwargs) + + def forward( + self, + hidden_states: Union[torch.Tensor, WrappedTensor], + attention_mask: Optional[torch.Tensor], + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + rotary_pos_cos: Optional[torch.Tensor] = None, + rotary_pos_sin: Optional[torch.Tensor] = None, + rotary_pos_cos_sin: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + extract_layer_indices: Optional[Set[int]] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + dynamic_inference_decode_only: Optional[bool] = None, + **kwargs, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine. + Currently used exclusively for inference with dynamic batching and flashinfer RoPE. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. 
+ extract_layer_indices (Set[int], optional): A set of global + layer indices (0-based across all pipeline stages) from + which to extract intermediate hidden states. If + non-empty, the forward pass will collect hidden_states + after each specified layer. + dynamic_inference_decode_only: Optional[bool]: If true, indicates that the current + inference context is for decode-only. This args is only used to uniquely + identify decode and non-decode cuda graph runners in the cuda graph manager. + + Returns: + Union[Tensor, Tuple[Tensor, List[Tensor]]]: + - If extract_layer_indices is None or empty: Returns the output hidden states tensor + of shape [s, b, h]. + - If extract_layer_indices is non-empty: Returns a tuple + of (hidden_states, intermediate_hidden_states) where + intermediate_hidden_states is a list of tensors + corresponding to hidden states after each layer in + extract_layer_indices. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + # Remove 'dynamic_inference_decode_only' from kwargs if present + # this is only used to uniquely identify decode and non-decode cuda graph + # runners in the cuda graph manager + + # Initialize feature collection (consistent with FastGen's Wan implementation) + if extract_layer_indices is None: + extract_layer_indices = set() + intermediate_hidden_states: List[torch.Tensor] = [] + + # Calculate the global layer offset for this pipeline stage + # This is needed to convert local layer indices to global indices for feature extraction + pp_group = self.pg_collection.pp if hasattr(self.pg_collection, 'pp') else None + layer_offset = get_transformer_layer_offset(self.config, self.vp_stage, get_pg_rank(pp_group)) + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. 
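        # Illustrative usage sketch (editorial, not part of the patch): how a caller can
        # consume the extract_layer_indices contract documented above. `block`, the tensor
        # shapes and the index values are assumed for the example.
        #
        #     out = block(
        #         hidden_states=hidden_states,        # [s, b, h]
        #         attention_mask=attention_mask,
        #         extract_layer_indices={2, 5},       # global indices across PP stages
        #     )
        #     if isinstance(out, tuple):              # non-empty set -> (final, intermediates)
        #         hidden_states, intermediates = out  # one entry per requested global index
        #     else:
        #         hidden_states = out                 # empty/None set -> unchanged behavior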
+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + # For FP4: NVFP4BlockScaling doesn't have delayed scaling, always uses inner context + if self.config.fp8: + use_outer_quantization_context = self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_quantization_context = self.config.fp8_recipe != Fp8Recipe.delayed + outer_quantization_context = ( + get_fp8_context(self.config) if use_outer_quantization_context else nullcontext()) + elif self.config.fp4: + use_outer_quantization_context = False + use_inner_quantization_context = True + outer_quantization_context = nullcontext() + else: + # No quantization + use_outer_quantization_context = False + use_inner_quantization_context = False + outer_quantization_context = nullcontext() + + with rng_context, outer_quantization_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + checkpointed_result = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + use_inner_quantization_context=use_inner_quantization_context, + padding_mask=padding_mask, + extract_layer_indices=extract_layer_indices, + layer_offset=layer_offset, + **kwargs, + ) + # Handle return value from _checkpointed_forward + if len(extract_layer_indices) > 0: + # (hidden_states, intermediate_hidden_states) tuple + hidden_states, intermediate_hidden_states = checkpointed_result + else: + # No intermediate_hidden_states requested: just hidden_states + hidden_states = checkpointed_result + else: + for l_no, layer in enumerate(self.layers): + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context(self.config, layer.layer_number - 1) + elif self.config.fp4: + inner_quantization_context = get_fp4_context(self.config, layer.layer_number - 1) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with self.offload_context, inner_quantization_context: + hidden_states, context = self._layer_forward( + layer, + hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + padding_mask=padding_mask, + **kwargs) + + if (torch.is_grad_enabled() and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Extract intermediate embeddings using global layer index + if (l_no + layer_offset) in extract_layer_indices: + intermediate_hidden_states.append(hidden_states) + + # 
Final layer norm. + if self.final_layernorm is not None: + if apply_module is None: + hidden_states = self.final_layernorm(hidden_states) + else: + hidden_states = apply_module(self.final_layernorm)(cast(torch.Tensor, hidden_states)) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + # If this TransformerBlock is empty, input and output hidden states will be the same node + # on the computational graph and will lead to unexpected errors in pipeline schedules. + if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: + hidden_states = hidden_states.clone() + + if len(extract_layer_indices) > 0: + return hidden_states, intermediate_hidden_states + + return hidden_states diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 024135e..a23cb0e 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,11 +1,15 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import enum import inspect import torch +from megatron.core.extensions.transformer_engine import TEFusedMLP from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region) from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP +from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules, @@ -126,40 +130,13 @@ def __init__( hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, ) - # [Module 8: MLP block] - additional_mlp_kwargs = {} - # import here to avoid circular import - from megatron.core.extensions.transformer_engine import TEFusedMLP - from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP - from megatron.core.transformer.moe.moe_layer import MoELayer - - from mcore_bridge.model.gpts.glm4 import Glm4MLP - # MLP expects tp_group but MoELayer expects pg_collection to be passed in. - # We can change MLP to accept pg_collection but it makes the logic implicit - # The conditional below is to make the logic explicit - # if submodules.mlp is not a ModuleSpec,we dont have to handle passing additional kwargs - if isinstance(submodules.mlp, ModuleSpec): - if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): - additional_mlp_kwargs['pg_collection'] = pg_collection - # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. 
-                if submodules.mlp.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters:
-                    additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer
-            elif submodules.mlp.module in (MLP, Glm4MLP):
-                assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer'
-                additional_mlp_kwargs['tp_group'] = pg_collection.tp
-            elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP:
-                assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer'
-                additional_mlp_kwargs['tp_group'] = pg_collection.tp
-            else:
-                logger.warning_once(f'Unknown MLP type: {submodules.mlp.module}. Using default kwargs.')
-        self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs)
+        # [Module 8: MLP block]
+        self.mlp = self._build_mlp(submodules.mlp)
         if hasattr(self.mlp, 'set_layer_number'):
             self.mlp.set_layer_number(self.layer_number)

-        # [Module 9: BiasDropoutFusion]
         self.mlp_bda = build_module(submodules.mlp_bda)

-        self.is_moe_layer = isinstance(self.mlp, MoELayer)
         self.recompute_input_layernorm = False
@@ -231,6 +208,32 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph():
         # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
         self.bias_dropout_add_exec_handler = torch.enable_grad

+    def _build_mlp(self, mlp_spec):
+        pg_collection = self.pg_collection
+        additional_mlp_kwargs = {}
+        # import here to avoid circular import
+        from mcore_bridge.model.gpts.glm4 import Glm4MLP
+
+        # MLP expects tp_group but MoELayer expects pg_collection to be passed in.
+        # We can change MLP to accept pg_collection but it makes the logic implicit
+        # The conditional below is to make the logic explicit
+        # if mlp_spec is not a ModuleSpec, we don't have to handle passing additional kwargs
+        if isinstance(mlp_spec, ModuleSpec):
+            if mlp_spec.module in (MoELayer, TEGroupedMLP, SequentialMLP):
+                additional_mlp_kwargs['pg_collection'] = pg_collection
+                # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers.
+                if mlp_spec.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters:
+                    additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer
+            elif mlp_spec.module in (MLP, Glm4MLP):
+                assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer'
+                additional_mlp_kwargs['tp_group'] = pg_collection.tp
+            elif TEFusedMLP is not None and mlp_spec.module == TEFusedMLP:
+                assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer'
+                additional_mlp_kwargs['tp_group'] = pg_collection.tp
+            else:
+                logger.warning_once(f'Unknown MLP type: {mlp_spec.module}. Using default kwargs.')
+        return build_module(mlp_spec, config=self.config, **additional_mlp_kwargs)
+
     def forward(self, *args, **kwargs):
         """
         Perform a forward pass through the transformer layer.
@@ -238,6 +241,10 @@ def forward(self, *args, **kwargs):
         This method calls the core computation of a transformer layer, including
         self-attention, cross-attention (if applicable), and feed-forward operations.
         """
+        # Compatible with megatron-core 0.15
+        for key in ['padding_mask']:
+            if kwargs.get(key) is None:
+                kwargs.pop(key, None)
         hidden_states, context = self._forward_attention(*args, **kwargs)
         # If padding_free is set, attention_mask does not exist.
mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 15b37fe..57be2b6 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -1,9 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import megatron.core +from contextlib import contextmanager from dataclasses import dataclass from megatron.core import mpu from megatron.core.enums import ModelType from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear +from megatron.core.models.gpt import gpt_model from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec from packaging import version from torch import nn @@ -13,7 +15,7 @@ from mcore_bridge.config import ModelConfig from mcore_bridge.utils import get_logger -from .modules import CustomTransformerLayer, MultiTokenPredictionLayer +from .modules import CustomTransformerBlock, CustomTransformerLayer, MultiTokenPredictionLayer if TYPE_CHECKING: from .gpt_model import GPTModel @@ -66,6 +68,7 @@ def get_model_meta(mcore_model_type: str) -> ModelMeta: class ModelLoader: model_cls = None + transformer_block = CustomTransformerBlock def __init__(self, config: ModelConfig): from mcore_bridge.model import GPTModel, MultimodalGPTModel @@ -131,17 +134,27 @@ def build_model( mtp_block_spec = None if self.config.mtp_num_layers is not None: mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) - model = self.model_cls( - config=self.config, - transformer_layer_spec=transformer_layer_spec, - pre_process=pre_process, - post_process=post_process, - mtp_block_spec=mtp_block_spec, - vp_stage=vp_stage, - ) + with self._patch_transformer_block(): + model = self.model_cls( + config=self.config, + transformer_layer_spec=transformer_layer_spec, + pre_process=pre_process, + post_process=post_process, + mtp_block_spec=mtp_block_spec, + vp_stage=vp_stage, + ) self._set_linear_is_expert(model) return model + @contextmanager + def _patch_transformer_block(self): + TransformerBlock = gpt_model.TransformerBlock + gpt_model.TransformerBlock = self.transformer_block + try: + yield + finally: + gpt_model.TransformerBlock = TransformerBlock + def _set_linear_is_expert(self, model): for n, module in model.named_modules(): if '.local_experts.' in n and isinstance(module, (TELinear, TELayerNormColumnParallelLinear)) or isinstance(
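# --- Illustrative sketch (editorial, not part of the patch): how the class-level
# `transformer_block` hook and the `_patch_transformer_block` context manager above are
# meant to compose. `MyBlock` and `MyLoader` are placeholder names; the other identifiers
# appear in this diff.
from mcore_bridge.model.modules import CustomTransformerBlock
from mcore_bridge.model.register import ModelLoader

class MyBlock(CustomTransformerBlock):

    def _layer_forward(self, layer, hidden_states, **kwargs):
        # per-layer hook point, used by Qwen3VLTransformerBlock for deepstack injection
        return super()._layer_forward(layer, hidden_states, **kwargs)

class MyLoader(ModelLoader):
    transformer_block = MyBlock  # picked up inside build_model via _patch_transformer_block,
                                 # which swaps gpt_model.TransformerBlock only for the scope
                                 # of model construction and restores it in `finally`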