Merged
6 changes: 5 additions & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
version: 2
updates:
- package-ecosystem: github-actions
@@ -16,4 +20,4 @@ updates:
groups:
python-dependencies:
patterns:
- "*"
- "*"
4 changes: 4 additions & 0 deletions .github/scripts/deps.yaml
@@ -1,2 +1,6 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
common:
- transformers
4 changes: 4 additions & 0 deletions .github/scripts/install_deps.py
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
import os
import subprocess
import sys
5 changes: 4 additions & 1 deletion .github/workflows/release.yml
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
name: Release

concurrency:
@@ -118,4 +122,3 @@ jobs:
with:
name: ${{ env.WHL_NAME }}
path: dist/${{ env.WHL_NAME }}

5 changes: 4 additions & 1 deletion .github/workflows/unit_tests.yml
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
name: Unit Tests

defaults:
@@ -155,4 +159,3 @@ jobs:
run: |
mkdir -p artifacts
pytest --durations=0 tests/${{ matrix.test_script }}.py --junitxml=artifacts/${{ runner.os }}-${{ matrix.test_script }}.xml

5 changes: 4 additions & 1 deletion .gitignore
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
@@ -156,4 +160,3 @@ dmypy.json
cython_debug/

.idea/

11 changes: 7 additions & 4 deletions README.md
@@ -23,7 +23,7 @@ Depending on the model family, Defuser can:

- patch a supported model class before load so HF instantiates a defused block directly
- split fused tensors such as `gate_up_proj` into `gate_proj` + `up_proj`
- convert 3D expert tensors into numbered expert `nn.Linear` modules
- convert 3D expert tensors, including registered expert buffers, into numbered expert `nn.Linear` modules
- preserve the original fused math while presenting a naive module structure again

Public API:
@@ -33,8 +33,9 @@ from defuser import convert_model, replace_fused_blocks
```

- `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction.
- `convert_model(model, cleanup_original=True, max_layers=None, filter=None)` converts an already loaded model in place. This is the runtime defusion path for supported post-load expert and MLP conversions, including `qwen3_5_moe` style checkpoints.
- `convert_model(model, cleanup_original=False, max_layers=None, filter=None)` converts an already loaded model in place. This is the runtime defusion path for supported post-load expert and MLP conversions, including `qwen3_5_moe` style checkpoints.
- Defuser is designed and CI-tested for `transformers>=5.3.0`, and support is only offered for that version range. Older versions log a warning on these public APIs and are skipped as unsupported.
- Some model families appear in both support tables. Full models can be prepatched with `replace_fused_blocks(...)`, while standalone fused expert modules from those same families can still be runtime-defused with `convert_model(...)`.

`filter` is an optional list of PCRE regex rules evaluated against full module paths such as `model.layers.0.mlp.experts`:
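The rule-matching idea above can be sketched with the standard library's `re` module. This is an illustration of how PCRE-style rules can be evaluated against full module paths, not defuser's actual implementation; the rule strings and the `path_matches` helper are hypothetical examples.

```python
# Illustration only: evaluating PCRE-style filter rules against module paths.
# `path_matches` is a hypothetical helper, not part of the defuser API.
import re

filter_rules = [
    r"^model\.layers\.0\.",  # only blocks under the first decoder layer
    r"\.mlp\.experts$",      # any experts block, regardless of layer
]

def path_matches(path: str, rules: list[str]) -> bool:
    """Return True when any rule matches somewhere in the full module path."""
    return any(re.search(rule, path) for rule in rules)

print(path_matches("model.layers.0.mlp.experts", filter_rules))      # True
print(path_matches("model.layers.5.self_attn.q_proj", filter_rules)) # False
```

A matched path would be eligible for defusion; unmatched modules are left untouched.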

@@ -46,7 +47,7 @@ from defuser import convert_model, replace_fused_blocks

## Supported Models

Defuser currently supports the following `transformers==5.3.0` `model_type` values.
Defuser currently supports the following `transformers>=5.3.0` `model_type` values.

### `replace_fused_blocks(model_type)` before load

@@ -65,7 +66,7 @@ Defuser currently supports the following `transformers==5.3.0` `model_type` valu

| Pattern | Supported model types | Defused op performed |
| --- | --- | --- |
| Standard routed expert tensors | `deepseek_v2`, `dots1`, `ernie4_5_moe`, `ernie4_5_vl_moe`, `exaone_moe`, `flex_olmo`, `glm4_moe_lite`, `glm4v_moe`, `hunyuan_v1_moe`, `jamba`, `lfm2_moe`, `minimax`, `minimax_m2`, `olmoe`, `qwen3_vl_moe`, `solar_open` | Splits fused expert tensors into numbered expert `nn.Linear` modules with per-expert `gate_proj`, `up_proj`, and `down_proj`. |
| Standard routed expert tensors | `deepseek_v2`, `dots1`, `ernie4_5_moe`, `ernie4_5_vl_moe`, `exaone_moe`, `flex_olmo`, `glm4_moe_lite`, `glm4v_moe`, `hunyuan_v1_moe`, `jamba`, `lfm2_moe`, `minimax`, `minimax_m2`, `olmoe`, `qwen3_vl_moe`, `solar_open` | Splits fused expert tensors or registered expert buffers into numbered expert `nn.Linear` modules with per-expert `gate_proj`, `up_proj`, and `down_proj`. |
| Mixed sparse and shared experts | `deepseek_v3`, `glm_moe_dsa`, `qwen3_5_moe`, `qwen3_5_moe_text` | Runtime expert tensor defusion for routed experts while preserving the model's shared-expert path. |
| Transposed or packed expert tensors | `gpt_oss`, `phimoe` | Splits transposed fused expert `gate_up_proj` tensors into per-expert `gate_proj` + `up_proj`, preserves expert bias when present, and converts expert tensors into numbered expert `nn.Linear` modules. |
| Flattened expert layout | `dbrx` | Rebuilds the flattened DBRX expert FFN weights into numbered expert `gate_proj`, `up_proj`, and `down_proj` `nn.Linear` modules. |
@@ -100,6 +101,8 @@ converted = convert_model(model)
print(converted) # True when runtime defusion happened
```

`convert_model(model)` also preserves meta-device construction for supported meta-initialized models, so structural validation can run without materializing weights.
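The meta-device behavior referenced here can be illustrated with plain PyTorch (this sketch shows the general meta-device mechanism, not defuser internals): modules built on the `meta` device carry shapes, dtypes, and structure but allocate no data, so structural checks run without materializing weights.

```python
# Sketch of meta-device construction in plain PyTorch: the module exists
# structurally, but its parameters hold no storage.
import torch
from torch import nn

with torch.device("meta"):
    expert = nn.Linear(4096, 11008, bias=False)

print(expert.weight.device.type)   # meta
print(tuple(expert.weight.shape))  # (11008, 4096)
```

A converter that only rearranges module structure can therefore be validated on such a skeleton before any real checkpoint is loaded.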

Use `filter` when only specific blocks should be defused:

```python
4 changes: 4 additions & 0 deletions defuser/checkpoint_ops.py
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
import torch
from transformers.core_model_loading import Chunk, Concatenate, ConversionOps, MergeModulelist

24 changes: 21 additions & 3 deletions defuser/defuser.py
@@ -36,9 +36,26 @@ def get_checkpoint_conversion_mapping(model_type):


class PatchError(Exception):
"""Raised when Defuser cannot patch a registered Transformers class."""

pass


def _has_prebuilt_replacements(model: nn.Module, model_type: str) -> bool:
"""Detect models that were already instantiated with registry-backed replacements."""
replacement_paths = MODEL_CONFIG[model_type].get(PATCH.REPLACE_MODULE, [])
replacement_class_paths = {custom_path for _, custom_path in replacement_paths}
if not replacement_class_paths:
return False

for module in model.modules():
class_path = f"{module.__class__.__module__}.{module.__class__.__name__}"
if class_path in replacement_class_paths:
return True

return False


def replace_fused_blocks(model_type: str) -> bool:
"""Patch supported HF model classes so future loads instantiate defused blocks."""
if warn_if_public_api_transformers_unsupported("replace_fused_blocks()", logger):
@@ -202,9 +219,10 @@ def convert_model(

apply_model_patches(model, max_layers=max_layers, filter_rules=filter)

# If fused blocks have already been structurally replaced at load model before,
# there is no need to perform runtime defusing again
if MODEL_CONFIG[model.config.model_type].get(PATCH.REPLACE_MODULE):
# Full models patched at construction time already contain the defused
# replacement modules, but standalone experts from those model types can
# still use runtime tensor defusion.
if _has_prebuilt_replacements(model, model.config.model_type):
return False

# Perform runtime defusing of fused projections
4 changes: 4 additions & 0 deletions defuser/modeling/__init__.py
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
4 changes: 4 additions & 0 deletions defuser/modeling/glm4v.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
from torch import nn
from transformers.activations import ACT2FN

89 changes: 60 additions & 29 deletions defuser/modeling/moe_experts_interface.py
@@ -292,35 +292,66 @@ def _install_instance_forward(module: nn.Module, implementation: str) -> None:
def _detect_expert_projections(module: nn.Module) -> dict[str, dict]:
"""Detect which expert projections exist in the module.

This function scans the module for any 3D nn.Parameter attributes.
It first checks known projection names, then discovers any unknown 3D parameters.
This function scans the module for any registered 3D Parameter / Tensor
attributes. It first checks known projection names, then discovers any
unknown 3D registered tensors.

Returns:
Dict mapping projection names to their config, only for projections that exist
as 3D nn.Parameter in the module.
as 3D registered Parameter / Tensor in the module.
"""
detected = {}
# Only inspect registered tensors here. Scanning arbitrary attributes can
# trigger unrelated properties such as Transformers' `loss_function`.
local_tensors = {
name: tensor
for registry in (module._parameters, module._buffers)
for name, tensor in registry.items()
if isinstance(tensor, torch.Tensor)
}

# First, check known projection patterns
for proj_name, config in KNOWN_PROJECTION_PATTERNS.items():
param = getattr(module, proj_name, None)
if param is not None and isinstance(param, nn.Parameter) and param.dim() == 3:
param = local_tensors.get(proj_name)
if param is not None and param.dim() == 3:
detected[proj_name] = config

# If no known patterns found, scan for any 3D Parameter (future-proofing)
# If no known patterns found, scan for any 3D registered tensor (future-proofing)
if not detected:
for attr_name in dir(module):
for attr_name, param in local_tensors.items():
if attr_name.startswith("_"):
continue
param = getattr(module, attr_name, None)
if param is not None and isinstance(param, nn.Parameter) and param.dim() == 3:
if param is not None and param.dim() == 3:
# Use default config for unknown projections
if DEBUG_ON: logger.debug(f"Discovered unknown 3D projection: {attr_name}")
detected[attr_name] = {"is_input_proj": True, "output_multiplier": 1}

return detected


def _get_registered_tensor(module: nn.Module, name: str) -> torch.Tensor | None:
"""Return a direct registered parameter or buffer without touching properties."""

tensor = module._parameters.get(name)
if isinstance(tensor, torch.Tensor):
return tensor

tensor = module._buffers.get(name)
if isinstance(tensor, torch.Tensor):
return tensor

return None


def _set_registered_tensor_like(module: nn.Module, name: str, tensor: torch.Tensor, source: torch.Tensor) -> None:
"""Register ``tensor`` using the same parameter-vs-buffer kind as ``source``."""

if isinstance(source, nn.Parameter):
module.register_parameter(name, nn.Parameter(tensor, requires_grad=source.requires_grad))
else:
module.register_buffer(name, tensor)


def _experts_supports_decorator(module: nn.Module) -> bool:
"""Check if experts module supports @use_experts_implementation decorator.

@@ -334,11 +365,11 @@ def _experts_supports_decorator(module: nn.Module) -> bool:
return hasattr(forward_method, "__wrapped__")


def _infer_dimensions(param: nn.Parameter, config: dict, is_transposed: bool) -> tuple[int, int]:
def _infer_dimensions(param: torch.Tensor, config: dict, is_transposed: bool) -> tuple[int, int]:
"""Infer input and output dimensions for a projection.

Args:
param: The 3D parameter (num_experts, dim1, dim2)
param: The 3D projection tensor (num_experts, dim1, dim2)
config: Projection config with is_input_proj and output_multiplier
is_transposed: Whether weights are stored transposed

@@ -370,7 +401,7 @@ def _unfuse_single_projection(
dtype: torch.dtype,
target_device: torch.device,
) -> list | None:
"""Unfuse a single projection from 3D Parameter to a list of Linear layers.
"""Unfuse a single projection from a 3D registered tensor to Linear layers.

Optimized to keep peak device memory low while preserving the module's
original device placement:
@@ -390,8 +421,8 @@
Returns:
List of Linear layers, or None if projection doesn't exist
"""
param = getattr(module, proj_name, None)
if param is None or not isinstance(param, nn.Parameter) or param.dim() != 3:
param = _get_registered_tensor(module, proj_name)
if param is None or param.dim() != 3:
return None

# Get projection config
@@ -402,17 +433,17 @@

# Check for bias
bias_name = f"{proj_name}_bias"
bias_param = getattr(module, bias_name, None)
has_bias = bias_param is not None
bias_param = _get_registered_tensor(module, bias_name)
has_bias = isinstance(bias_param, torch.Tensor)

source_device = param.device
is_meta = source_device.type == "meta"
weight_requires_grad = param.requires_grad
weight_requires_grad = param.requires_grad if isinstance(param, nn.Parameter) else False

# Prepare weight slices on CPU in batch (single D2H transfer + batch transpose)
if not is_meta:
# Single transfer: GPU -> CPU (or no-op if already on CPU)
weights_cpu = param.data.cpu() # (num_experts, dim1, dim2)
weights_cpu = param.detach().cpu() # (num_experts, dim1, dim2)
if is_transposed:
# Batch transpose: (num_experts, in, out) -> (num_experts, out, in)
weights_cpu = weights_cpu.transpose(1, 2)
@@ -425,12 +456,12 @@
weight_slices = weights_cpu.unbind(0)

if has_bias:
bias_cpu = bias_param.data.cpu()
bias_cpu = bias_param.detach().cpu()
# Ensure contiguous — bias may come from a chunk split (Phase 1)
if not bias_cpu.is_contiguous():
bias_cpu = bias_cpu.contiguous()
bias_slices = bias_cpu.unbind(0)
bias_requires_grad = bias_param.requires_grad
bias_requires_grad = bias_param.requires_grad if isinstance(bias_param, nn.Parameter) else False

# Drop the original fused parameter before allocating the defused
# per-expert linears back on the original device.
@@ -588,7 +619,7 @@ def _unfuse_experts_weights_inplace(

# Get first projection to determine num_experts and layout
first_proj_name = next(iter(detected_projections))
first_param = getattr(module, first_proj_name)
first_param = _get_registered_tensor(module, first_proj_name)
num_experts = first_param.shape[0]

# Detect if transposed
@@ -608,33 +639,33 @@
dtype = first_param.dtype
target_device = first_param.device if first_param.device.type != "meta" else "cpu"

# Phase 1: Split fused projections (e.g., gate_up_proj -> gate_proj + up_proj) into separate 3D params
# Phase 1: Split fused projections (e.g., gate_up_proj -> gate_proj + up_proj) into separate 3D tensors
extra_projections = {}
fused_to_remove = []
for proj_name, config in detected_projections.items():
split_into = config.get("split_into")
if not split_into:
continue
param = getattr(module, proj_name, None)
if param is None or not isinstance(param, nn.Parameter) or param.dim() != 3:
param = _get_registered_tensor(module, proj_name)
if param is None or param.dim() != 3:
continue
# Split along output dimension
split_dim = 2 if is_transposed else 1
split_params = param.chunk(len(split_into), dim=split_dim)

# Also split bias if present (e.g., gate_up_proj_bias -> gate_proj_bias + up_proj_bias)
bias_name = f"{proj_name}_bias"
bias_param = getattr(module, bias_name, None)
bias_param = _get_registered_tensor(module, bias_name)
bias_splits = None
if bias_param is not None and isinstance(bias_param, nn.Parameter) and bias_param.dim() == 2:
if isinstance(bias_param, torch.Tensor) and bias_param.dim() == 2:
bias_splits = bias_param.chunk(len(split_into), dim=1)

for i, (split_name, split_param) in enumerate(zip(split_into, split_params)):
# Avoid .contiguous() here — _unfuse_single_projection will handle it
# during batch transpose/unbind, saving a full-tensor copy
setattr(module, split_name, nn.Parameter(split_param))
_set_registered_tensor_like(module, split_name, split_param, param)
if bias_splits is not None:
setattr(module, f"{split_name}_bias", nn.Parameter(bias_splits[i]))
_set_registered_tensor_like(module, f"{split_name}_bias", bias_splits[i], bias_param)
extra_projections[split_name] = KNOWN_PROJECTION_PATTERNS.get(
split_name, {"is_input_proj": True, "output_multiplier": 1}
)
@@ -649,7 +680,7 @@
del detected_projections[name]
detected_projections.update(extra_projections)

# Phase 2: Unfuse all 3D params into per-expert Linear layers
# Phase 2: Unfuse all 3D tensors into per-expert Linear layers
proj_linears = {} # proj_name -> [Linear_expert0, Linear_expert1, ...]
for proj_name in detected_projections:
linears = _unfuse_single_projection(module, proj_name, num_experts, is_transposed, dtype, target_device)
4 changes: 4 additions & 0 deletions defuser/modeling/unfused_moe/__init__.py
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
4 changes: 4 additions & 0 deletions defuser/utils/__init__.py
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium