62 changes: 49 additions & 13 deletions README.md
@@ -33,24 +33,46 @@ from defuser import convert_model, replace_fused_blocks
```

- `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction.
- `convert_model(model, cleanup_original=True, max_layers=None, filter=None)` converts an already loaded model in place. This is the runtime defusion path for supported post-load expert and MLP conversions, including `qwen3_5_moe` style checkpoints.
- Defuser is designed and CI-tested for `transformers>=5.3.0`, and support is offered only for that version range. On older versions, calls to these public APIs log a warning and are skipped as unsupported.

`filter` is an optional list of regex rules evaluated against full module paths such as `model.layers.0.mlp.experts` (see the sketch after this list):

- `+:regex` explicitly includes matching candidate module paths
- `-:regex` explicitly excludes matching candidate module paths
- `regex` is shorthand for `+:regex`
- negative rules take priority over positive rules
- when `filter` is provided, a candidate module is defused only if it matches at least one positive rule and no negative rules
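
The following is a minimal sketch of these matching semantics only, not Defuser's actual internals; the helper name `path_matches_filter` and the use of Python's `re` module are assumptions:

```python
import re

def path_matches_filter(path: str, rules: list[str]) -> bool:
    """Hypothetical helper: True if `path` should be defused under `rules`."""
    positive, negative = [], []
    for rule in rules:
        if rule.startswith("-:"):
            negative.append(rule[2:])
        elif rule.startswith("+:"):
            positive.append(rule[2:])
        else:
            positive.append(rule)  # bare regex is shorthand for +:regex
    # Negative rules take priority over positive rules.
    if any(re.search(pattern, path) for pattern in negative):
        return False
    return any(re.search(pattern, path) for pattern in positive)

path_matches_filter("model.layers.0.mlp.experts", [r"^model\.layers\.0\."])  # True
```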

## Supported Models

Defuser currently supports the following `model_type` values from `transformers==5.3.0`.

### `replace_fused_blocks(model_type)` before load

| Model type | Defused op performed |
| --- | --- |
| `glm4_moe` | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
| `glm4v` | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |
| `mixtral` | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. |
| `qwen2_moe` | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
| `qwen3_moe` | Replaces `Qwen3MoeSparseMoeBlock` with a defused per-expert linear MoE block. |
| `qwen3_next` | Replaces `Qwen3NextSparseMoeBlock` with a defused per-expert linear MoE block. |
| `qwen3_omni_moe` | Replaces both thinker and talker text sparse MoE blocks with defused per-expert linear blocks and applies small runtime compatibility patches for text `forward()` and `generate()`. |
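
For example, a minimal before-load sketch (the checkpoint name `mistralai/Mixtral-8x7B-Instruct-v0.1` is only an illustration; any `model_type` from the table above works the same way):

```python
from transformers import AutoModelForCausalLM

from defuser import replace_fused_blocks

# Patch the fused MoE block class before loading, so from_pretrained()
# constructs the defused per-expert modules directly.
replace_fused_blocks("mixtral")

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
```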

### `convert_model(model)` after load

| Pattern | Supported model types | Defused op performed |
| --- | --- | --- |
| Standard routed expert tensors | `deepseek_v2`, `dots1`, `ernie4_5_moe`, `ernie4_5_vl_moe`, `exaone_moe`, `flex_olmo`, `glm4_moe_lite`, `glm4v_moe`, `hunyuan_v1_moe`, `jamba`, `lfm2_moe`, `minimax`, `minimax_m2`, `olmoe`, `qwen3_vl_moe`, `solar_open` | Splits fused expert tensors into numbered expert `nn.Linear` modules with per-expert `gate_proj`, `up_proj`, and `down_proj`. |
| Mixed sparse and shared experts | `deepseek_v3`, `glm_moe_dsa`, `qwen3_5_moe`, `qwen3_5_moe_text` | Runtime expert tensor defusion for routed experts while preserving the model's shared-expert path. |
| Transposed or packed expert tensors | `gpt_oss`, `phimoe` | Splits transposed fused expert `gate_up_proj` tensors into per-expert `gate_proj` + `up_proj`, preserves expert bias when present, and converts expert tensors into numbered expert `nn.Linear` modules. |
| Flattened expert layout | `dbrx` | Rebuilds the flattened DBRX expert FFN weights into numbered expert `gate_proj`, `up_proj`, and `down_proj` `nn.Linear` modules. |
| Batched expert-input execution | `llama4` | Runtime expert tensor defusion plus preservation of the llama4 batched expert-input execution contract. |
| Non-gated expert MLPs | `nemotron_h` | Converts routed expert tensors into numbered `up_proj` and `down_proj` `nn.Linear` modules for non-gated experts. |
| Parallel expert blocks | `granitemoe`, `granitemoehybrid`, `granitemoeshared`, `jetmoe` | Converts packed expert weight tensors into numbered expert `linear` modules while keeping grouped expert execution intact. |
| Routed experts with identity experts | `longcat_flash` | Defuses routed experts into numbered `gate_proj`, `up_proj`, and `down_proj` modules and preserves zero or identity experts. |
| Fused dense `gate_up_proj` MLPs | `dia`, `glm`, `glm4`, `glm_image`, `glm_ocr`, `phi3`, `phi4_multimodal`, `zamba2` | Splits fused dense `gate_up_proj` layers into `gate_proj` + `up_proj` and updates the block `forward()` to preserve the original MLP math. |
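
For instance, a minimal after-load sketch (the checkpoint name `openai/gpt-oss-20b` is only an illustration of the `gpt_oss` entry above):

```python
from transformers import AutoModelForCausalLM

from defuser import convert_model

# Load the fused checkpoint first, then defuse the expert tensors in place.
model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
convert_model(model, cleanup_original=True)
```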

## Workflow Summary

@@ -77,6 +99,20 @@ converted = convert_model(model)
print(converted) # True when runtime defusion happened
```

Use `filter` when only specific blocks should be defused:

```python
from defuser import convert_model

convert_model(
    model,
    filter=[
        r"+:^model\.layers\.0\.mlp\.experts$",
        r"-:^model\.layers\.0\.mlp\.experts\.shared_",
    ],
)
```

## Real Qwen3.5 MoE Example

The example below is written for the `transformers==5.3.0` public API surface and uses the real Hugging Face model `Qwen/Qwen3.5-35B-A3B-Instruct`. Defuser supports `transformers>=5.3.0`.
4 changes: 3 additions & 1 deletion defuser/defuser.py
@@ -117,6 +117,7 @@ def convert_model(
    model: nn.Module,
    cleanup_original: bool = False,
    max_layers: int | None = None,
    filter: list[str] | None = None,
) -> bool:
    """Convert one loaded model in place from fused experts to defused modules."""
    if warn_if_public_api_transformers_unsupported("convert_model()", logger):
@@ -200,7 +201,7 @@ def convert_model(
    if not check_model_compatibility(model):
        return False

    apply_model_patches(model, max_layers=max_layers, filter_rules=filter)

    # If fused blocks have already been structurally replaced at load time,
    # there is no need to perform runtime defusion again
@@ -214,6 +215,7 @@
        model,
        cleanup_original=cleanup_original,
        max_layers=max_layers,
        filter_rules=filter,
    )

    return True
103 changes: 103 additions & 0 deletions defuser/model_registry.py
@@ -16,6 +16,39 @@ class PATCH(str, Enum):


MODEL_CONFIG = {
"dbrx": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"deepseek_v2": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"deepseek_v3": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"dia": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"dots1": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"ernie4_5_moe": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"ernie4_5_vl_moe": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"exaone_moe": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"flex_olmo": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"glm": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"glm4": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
},
"mixtral": {
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
PATCH.REPLACE_MODULE: [
@@ -84,6 +117,10 @@ class PATCH(str, Enum):
            (
                "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock",
                "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeThinkerTextSparseMoeBlock",
            ),
            (
                "transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeTalkerTextSparseMoeBlock",
                "defuser.modeling.unfused_moe.qwen3_omni_moe.LinearQwen3OmniMoeTalkerTextSparseMoeBlock",
            )
        ],
    },
@@ -96,6 +133,9 @@ class PATCH(str, Enum):
            )
        ],
    },
    "glm4_moe_lite": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "glm4v": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
        PATCH.REPLACE_MODULE: [
@@ -116,9 +156,39 @@
            ),
        ],
    },
    "glm4v_moe": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "glm_image": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "glm_moe_dsa": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "glm_ocr": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "gpt_oss": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "granitemoe": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "granitemoehybrid": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "granitemoeshared": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "hunyuan_v1_moe": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "jamba": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "jetmoe": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "llama4": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
        PATCH.EXPERTS_DEFUSE: [
@@ -128,7 +198,40 @@
            }
        ],
    },
    "lfm2_moe": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "longcat_flash": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "minimax": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "minimax_m2": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "nemotron_h": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "olmoe": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "phi3": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "phi4_multimodal": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "phimoe": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "qwen3_vl_moe": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "solar_open": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
    "zamba2": {
        "min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
    },
}