
refactor transformer_block #72

Merged
Jintao-Huang merged 6 commits into modelscope:main from Jintao-Huang:refactor_transformer_block on May 11, 2026

Conversation

@Jintao-Huang
Collaborator

No description provided.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a CustomTransformerBlock to support activation checkpointing and feature extraction, refactors MLP construction in TransformerLayer, and implements a dynamic patching mechanism for transformer blocks. It also updates various model bridges and layers (Qwen3, Minimax, Gemma4) for better consistency and error handling. Feedback highlights critical logic errors in GPTBridge where bitwise intersection operators on key sets would result in empty sets, and potential AttributeError crashes in CustomTransformerBlock due to uninitialized pg_collection attributes. A misleading comment regarding argument handling in the transformer block's forward method also needs correction.
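To make the GPTBridge issue concrete, here is a minimal, hypothetical illustration (the key names below are made up, not taken from the PR): intersecting two disjoint key sets with & yields an empty set, so any loop over the result silently does nothing, and the likely intent is a union.

# Hypothetical key groups for illustration only; not the PR's actual names.
attn_keys = {"q_proj.weight", "k_proj.weight", "v_proj.weight"}
mlp_keys = {"gate_proj.weight", "up_proj.weight", "down_proj.weight"}

# `&` is set intersection: for disjoint groups the result is empty.
keys = attn_keys & mlp_keys
assert keys == set()

# The likely intent is the union of both groups.
keys = attn_keys | mlp_keys
assert len(keys) == 6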

Comment thread src/mcore_bridge/model/modules/transformer_block.py
Comment thread src/mcore_bridge/model/modules/transformer_block.py
Comment thread src/mcore_bridge/bridge/gpt_bridge.py Outdated
Comment thread src/mcore_bridge/bridge/gpt_bridge.py Outdated
Comment thread src/mcore_bridge/model/modules/transformer_block.py
@Jintao-Huang
Collaborator Author

/gemini review

@Jintao-Huang
Collaborator Author

TODO: compat megatron-core 0.15


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a CustomTransformerBlock to centralize activation checkpointing, quantization support (FP8/FP4), and intermediate feature extraction, refactoring models like Qwen3-VL to utilize this shared logic. It also improves initialization safety across various modules using try...finally blocks and implements a dynamic patching mechanism in the ModelLoader to swap transformer block implementations. Feedback highlights critical issues in the new transformer block regarding the use of non-standard attributes for process groups and unreliable rank retrieval, recommending the use of parallel_state for better compatibility with Megatron-Core. A limitation regarding feature extraction during uniform recomputation was also identified.
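As a rough sketch of what such a dynamic patching mechanism in the ModelLoader could look like (the helper name patch_transformer_block and the usage names are illustrative, not the PR's actual API), the loader can swap the TransformerBlock symbol before model construction and restore it in a finally block, mirroring the try...finally init-safety pattern described above:

from contextlib import contextmanager

import megatron.core.transformer.transformer_block as tb_module


@contextmanager
def patch_transformer_block(block_cls):
    # Swap in the custom block class and always restore the original,
    # even if model construction raises.
    original = tb_module.TransformerBlock
    tb_module.TransformerBlock = block_cls
    try:
        yield
    finally:
        tb_module.TransformerBlock = original


# Illustrative usage:
# with patch_transformer_block(CustomTransformerBlock):
#     model = build_gpt_model(config)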

# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
from contextlib import nullcontext
from megatron.core import tensor_parallel


Severity: high

The parallel_state module from megatron.core is required to correctly access process groups and ranks, as self.pg_collection is not a standard attribute of the TransformerBlock class in Megatron-Core.

Suggested change
- from megatron.core import tensor_parallel
+ from megatron.core import parallel_state, tensor_parallel

forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
self.pg_collection.tp,


Severity: high

self.pg_collection is not a standard attribute of TransformerBlock. Accessing it will raise an AttributeError. Use parallel_state.get_tensor_model_parallel_group() instead.

Suggested change
- self.pg_collection.tp,
+ parallel_state.get_tensor_model_parallel_group(),

Comment on lines +266 to +267
pp_group = self.pg_collection.pp if hasattr(self.pg_collection, 'pp') else None
layer_offset = get_transformer_layer_offset(self.config, self.vp_stage, get_pg_rank(pp_group))


Severity: high

Retrieving the pipeline rank via get_pg_rank(pp_group) where pp_group might be None is unreliable and may return the global rank instead of the pipeline-parallel rank. Use parallel_state.get_pipeline_model_parallel_rank() for accuracy.

Suggested change
- pp_group = self.pg_collection.pp if hasattr(self.pg_collection, 'pp') else None
- layer_offset = get_transformer_layer_offset(self.config, self.vp_stage, get_pg_rank(pp_group))
+ layer_offset = get_transformer_layer_offset(self.config, self.vp_stage, parallel_state.get_pipeline_model_parallel_rank())

Comment on lines +141 to +148
# Feature extraction for uniform recompute: collect at end of each chunk
# Note: Only the last layer of each chunk can have features collected
for idx in range(layer_idx, chunk_end):
if (idx + layer_offset) in extract_layer_indices:
# For uniform recompute, we can only get features at chunk boundaries
# Limitation: for fine-grained extraction, use 'block'
if idx == chunk_end - 1:
intermediate_hidden_states.append(hidden_states)


Severity: medium

This logic for feature extraction during uniform recomputation is limited to chunk boundaries. If a requested layer index falls in the middle of a recompute chunk, it will be silently ignored. Consider adding a warning if extract_layer_indices contains layers that cannot be captured due to the current recompute_num_layers setting.
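A hedged sketch of that warning, assuming uniform recompute chunks start at layer 0 so that only the last layer of each chunk of size recompute_num_layers is observable (all names below are illustrative, not the PR's actual API):

import warnings


def warn_uncapturable_layers(extract_layer_indices, num_layers, layer_offset, recompute_num_layers):
    # Under uniform recompute, only the last layer of each chunk can have
    # its hidden states collected; warn about requested layers that fall mid-chunk.
    capturable = {
        idx + layer_offset
        for idx in range(num_layers)
        if (idx + 1) % recompute_num_layers == 0
    }
    missed = set(extract_layer_indices) - capturable
    if missed:
        warnings.warn(
            f"extract_layer_indices {sorted(missed)} fall inside recompute chunks "
            f"(recompute_num_layers={recompute_num_layers}) and will be skipped; "
            "use recompute_method='block' for fine-grained feature extraction."
        )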

Jintao-Huang merged commit 4d55427 into modelscope:main on May 11, 2026
1 check passed