13 changes: 8 additions & 5 deletions src/mcore_bridge/bridge/gpt_bridge.py

@@ -40,9 +40,12 @@ class GPTBridge:
     hf_o_proj_key = 'o_proj'
     hf_attn_prefix = 'self_attn'
     hf_mlp_prefix = 'mlp'
+    hf_post_attention_layernorm = 'post_attention_layernorm'
     hf_gate_key = 'gate.weight'
     hf_shared_expert_key = None
     hf_expert_bias_key = 'gate.e_score_correction_bias'
+    additional_dim0_keys = set()
+    additional_dim1_keys = set()

     def __init__(self, config: ModelConfig):
         self.config = config
@@ -124,11 +127,11 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
             'linear_kv_up_proj',
             # mtp
             'eh_proj',
-        }
+        } | self.additional_dim0_keys
         if self.config.task_type in {'causal_lm', 'generative_reranker'}:
            dim0_keys.add('output_layer')
         # RowLinear
-        dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'}
+        dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'} | self.additional_dim1_keys
         if 'lora_A' not in mg_key and 'lora_B' not in mg_key:
             key, suffix = mg_key.rsplit('.', 2)[-2:]
             if suffix == 'layer_norm_weight':
@@ -1587,13 +1590,13 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool
             hf_state_dict.update(
                 self._set_moe_state(
                     mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp))
-            self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight',
-                                 to_mcore)
+            self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict,
+                                 f'{self.hf_post_attention_layernorm}.weight', to_mcore)
         else:
             hf_state_dict.update(
                 self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore))
             self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict,
-                                 'post_attention_layernorm.weight', to_mcore)
+                                 f'{self.hf_post_attention_layernorm}.weight', to_mcore)
         return hf_state_dict

     def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
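The three new class attributes turn hard-coded HF names into subclass hooks: a model-specific bridge can rename the post-attention layer norm key or extend the tensor-parallel split sets without overriding _get_tp_split_dim or _set_layer_mlp. A minimal sketch of how a subclass might use the hooks; the class name and keys below are hypothetical, not from this PR:

class MyModelBridge(GPTBridge):
    # Hypothetical HF checkpoint that uses a different norm name.
    hf_post_attention_layernorm = 'post_feedforward_layernorm'
    # Extra weights to shard along dim 0 (ColumnLinear) and dim 1 (RowLinear).
    additional_dim0_keys = {'my_column_parallel_proj'}
    additional_dim1_keys = {'my_row_parallel_proj'}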
8 changes: 5 additions & 3 deletions src/mcore_bridge/model/gpts/minimax_m2.py

@@ -27,9 +27,11 @@ def __init__(
         k_layernorm = submodules.k_layernorm
         submodules.q_layernorm = IdentityOp
         submodules.k_layernorm = IdentityOp
-        super().__init__(config, submodules, *args, **kwargs)
-        submodules.q_layernorm = q_layernorm
-        submodules.k_layernorm = k_layernorm
+        try:
+            super().__init__(config, submodules, *args, **kwargs)
+        finally:
+            submodules.q_layernorm = q_layernorm
+            submodules.k_layernorm = k_layernorm
         self.q_norm = build_module(
             submodules.q_layernorm,
             hidden_size=self.hidden_size_per_attention_head * config.num_attention_heads,
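This try/finally (the same hardening is applied below in gated_delta_net.py and mtp_layer.py) protects an existing swap-and-restore idiom: the layer norm specs are replaced with IdentityOp so the parent constructor skips building them, and the shared submodules spec is now restored even when super().__init__ raises. The same guarantee can be packaged as a context manager; swap_attr below is an illustrative sketch, not part of this codebase:

from contextlib import contextmanager

@contextmanager
def swap_attr(obj, name, temporary):
    # Temporarily replace obj.<name>, restoring it even if the body raises.
    original = getattr(obj, name)
    setattr(obj, name, temporary)
    try:
        yield
    finally:
        setattr(obj, name, original)

# Usage sketch: skip building q_layernorm in the parent constructor.
# with swap_attr(submodules, 'q_layernorm', IdentityOp):
#     super().__init__(config, submodules, *args, **kwargs)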
4 changes: 2 additions & 2 deletions src/mcore_bridge/model/gpts/qwen3_next.py

@@ -517,8 +517,8 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool
         mg_mlp = None if mg_layer is None else mg_layer.mlp
         hf_state_dict.update(
             self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp))
-        self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight',
-                             to_mcore)
+        self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict,
+                             f'{self.hf_post_attention_layernorm}.weight', to_mcore)
         return hf_state_dict

     def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict):
308 changes: 19 additions & 289 deletions src/mcore_bridge/model/mm_gpts/qwen3_vl.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/mcore_bridge/model/modules/__init__.py

@@ -2,4 +2,5 @@
 from .gated_delta_net import GatedDeltaNet
 from .gated_self_attention import GatedSelfAttention
 from .mtp_layer import MultiTokenPredictionLayer
+from .transformer_block import CustomTransformerBlock
 from .transformer_layer import CustomTransformerLayer
7 changes: 5 additions & 2 deletions src/mcore_bridge/model/modules/gated_delta_net.py

@@ -96,10 +96,13 @@ def __init__(self, config: ModelConfig, submodules: 'GatedDeltaNetSubmodules', *
             submodules.in_proj = IdentityOp
         if 'cp_comm_type' not in inspect.signature(_GatedDeltaNet).parameters:
             kwargs.pop('cp_comm_type', None)
-        super().__init__(config, submodules, *args, **kwargs)
+        try:
+            super().__init__(config, submodules, *args, **kwargs)
+        finally:
+            if config.linear_decoupled_in_proj:
+                submodules.in_proj = in_proj
         if not config.linear_decoupled_in_proj:
             return
-        submodules.in_proj = in_proj
         self.in_proj_qkvz_dim = self.qk_dim * 2 + self.v_dim * 2
         self.in_proj_ba_dim = self.num_value_heads * 2
         del self.in_proj
8 changes: 6 additions & 2 deletions src/mcore_bridge/model/modules/mtp_layer.py

@@ -1,3 +1,4 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
 import torch
 import transformer_engine
 from contextlib import nullcontext

@@ -29,11 +30,14 @@ def __init__(self, config: ModelConfig, submodules, *args, **kwargs):
         if config.fp8_param:
             eh_proj = submodules.eh_proj
             submodules.eh_proj = IdentityOp
-        super().__init__(config, submodules, *args, **kwargs)
+        try:
+            super().__init__(config, submodules, *args, **kwargs)
+        finally:
+            if config.fp8_param:
+                submodules.eh_proj = eh_proj
         self.tp_group = getattr(self, 'tp_group', None)
         if not config.fp8_param:
             return
-        submodules.eh_proj = eh_proj
         fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False)
         with fp8_context:
             self.eh_proj = build_module(
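The guarded tail of this constructor is why eh_proj is swapped out above: with fp8_param enabled, the parent constructor would allocate eh_proj with FP8 parameters, so the layer is instead rebuilt under Transformer Engine's fp8_model_init(enabled=False) context, which keeps its weights in ordinary precision. A reduced sketch of that pattern, assuming a plain te.Linear and an illustrative size rather than this repo's build_module spec:

import torch
import transformer_engine.pytorch as te

hidden_size = 1024  # illustrative, not from this PR

# Modules created under fp8_model_init(enabled=True) hold FP8 weights;
# enabled=False forces regular high-precision parameters instead.
with te.fp8_model_init(enabled=False):
    eh_proj = te.Linear(2 * hidden_size, hidden_size, bias=False)

assert eh_proj.weight.dtype == torch.float32  # not an FP8 tensor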