diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 88dda8b..11a957a 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium version: 2 updates: - package-ecosystem: github-actions @@ -16,4 +20,4 @@ updates: groups: python-dependencies: patterns: - - "*" \ No newline at end of file + - "*" diff --git a/.github/scripts/deps.yaml b/.github/scripts/deps.yaml index 4696490..788dad9 100644 --- a/.github/scripts/deps.yaml +++ b/.github/scripts/deps.yaml @@ -1,2 +1,6 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium common: - transformers diff --git a/.github/scripts/install_deps.py b/.github/scripts/install_deps.py index 1f3c2f4..50dad0a 100644 --- a/.github/scripts/install_deps.py +++ b/.github/scripts/install_deps.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium import os import subprocess import sys diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a1037fe..907482d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium name: Release concurrency: @@ -118,4 +122,3 @@ jobs: with: name: ${{ env.WHL_NAME }} path: dist/${{ env.WHL_NAME }} - diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 37e0fd6..705cc85 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium name: Unit Tests defaults: @@ -155,4 +159,3 @@ jobs: run: | mkdir -p artifacts pytest --durations=0 tests/${{ matrix.test_script }}.py --junitxml=artifacts/${{ runner.os }}-${{ matrix.test_script }}.xml - diff --git a/.gitignore b/.gitignore index d5c293f..8b34b54 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium ### Python template # Byte-compiled / optimized / DLL files __pycache__/ @@ -156,4 +160,3 @@ dmypy.json cython_debug/ .idea/ - diff --git a/README.md b/README.md index dd529b7..a37f60b 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ Depending on the model family, Defuser can: - patch a supported model class before load so HF instantiates a defused block directly - split fused tensors such as `gate_up_proj` into `gate_proj` + `up_proj` -- convert 3D expert tensors into numbered expert `nn.Linear` modules +- convert 3D expert tensors, including registered expert buffers, into numbered expert `nn.Linear` modules - preserve the original fused math while presenting a naive module structure again Public API: @@ -33,8 +33,9 @@ from defuser import convert_model, replace_fused_blocks ``` - `replace_fused_blocks(model_type)` patches supported HF model classes before `from_pretrained()` or direct model construction. -- `convert_model(model, cleanup_original=True, max_layers=None, filter=None)` converts an already loaded model in place. This is the runtime defusion path for supported post-load expert and MLP conversions, including `qwen3_5_moe` style checkpoints. +- `convert_model(model, cleanup_original=False, max_layers=None, filter=None)` converts an already loaded model in place. This is the runtime defusion path for supported post-load expert and MLP conversions, including `qwen3_5_moe` style checkpoints. - Defuser is designed and CI-tested for `transformers>=5.3.0`, and support is only offered for that version range. Older versions log a warning on these public APIs and are skipped as unsupported. +- Some model families appear in both support tables. Full models can be prepatched with `replace_fused_blocks(...)`, while standalone fused expert modules from those same families can still be runtime-defused with `convert_model(...)`. `filter` is an optional list of PCRE regex rules evaluated against full module paths such as `model.layers.0.mlp.experts`: @@ -46,7 +47,7 @@ from defuser import convert_model, replace_fused_blocks ## Supported Models -Defuser currently supports the following `transformers==5.3.0` `model_type` values. +Defuser currently supports the following `transformers>=5.3.0` `model_type` values. ### `replace_fused_blocks(model_type)` before load @@ -65,7 +66,7 @@ Defuser currently supports the following `transformers==5.3.0` `model_type` valu | Pattern | Supported model types | Defused op performed | | --- | --- | --- | -| Standard routed expert tensors | `deepseek_v2`, `dots1`, `ernie4_5_moe`, `ernie4_5_vl_moe`, `exaone_moe`, `flex_olmo`, `glm4_moe_lite`, `glm4v_moe`, `hunyuan_v1_moe`, `jamba`, `lfm2_moe`, `minimax`, `minimax_m2`, `olmoe`, `qwen3_vl_moe`, `solar_open` | Splits fused expert tensors into numbered expert `nn.Linear` modules with per-expert `gate_proj`, `up_proj`, and `down_proj`. | +| Standard routed expert tensors | `deepseek_v2`, `dots1`, `ernie4_5_moe`, `ernie4_5_vl_moe`, `exaone_moe`, `flex_olmo`, `glm4_moe_lite`, `glm4v_moe`, `hunyuan_v1_moe`, `jamba`, `lfm2_moe`, `minimax`, `minimax_m2`, `olmoe`, `qwen3_vl_moe`, `solar_open` | Splits fused expert tensors or registered expert buffers into numbered expert `nn.Linear` modules with per-expert `gate_proj`, `up_proj`, and `down_proj`. | | Mixed sparse and shared experts | `deepseek_v3`, `glm_moe_dsa`, `qwen3_5_moe`, `qwen3_5_moe_text` | Runtime expert tensor defusion for routed experts while preserving the model's shared-expert path. | | Transposed or packed expert tensors | `gpt_oss`, `phimoe` | Splits transposed fused expert `gate_up_proj` tensors into per-expert `gate_proj` + `up_proj`, preserves expert bias when present, and converts expert tensors into numbered expert `nn.Linear` modules. | | Flattened expert layout | `dbrx` | Rebuilds the flattened DBRX expert FFN weights into numbered expert `gate_proj`, `up_proj`, and `down_proj` `nn.Linear` modules. | @@ -100,6 +101,8 @@ converted = convert_model(model) print(converted) # True when runtime defusion happened ``` +`convert_model(model)` also preserves meta-device construction for supported meta-initialized models, so structural validation can run without materializing weights. + Use `filter` when only specific blocks should be defused: ```python diff --git a/defuser/checkpoint_ops.py b/defuser/checkpoint_ops.py index 9b5d32d..e90793c 100644 --- a/defuser/checkpoint_ops.py +++ b/defuser/checkpoint_ops.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium import torch from transformers.core_model_loading import Chunk, Concatenate, ConversionOps, MergeModulelist diff --git a/defuser/defuser.py b/defuser/defuser.py index 8126a16..75c2d67 100644 --- a/defuser/defuser.py +++ b/defuser/defuser.py @@ -36,9 +36,26 @@ def get_checkpoint_conversion_mapping(model_type): class PatchError(Exception): + """Raised when Defuser cannot patch a registered Transformers class.""" + pass +def _has_prebuilt_replacements(model: nn.Module, model_type: str) -> bool: + """Detect models that were already instantiated with registry-backed replacements.""" + replacement_paths = MODEL_CONFIG[model_type].get(PATCH.REPLACE_MODULE, []) + replacement_class_paths = {custom_path for _, custom_path in replacement_paths} + if not replacement_class_paths: + return False + + for module in model.modules(): + class_path = f"{module.__class__.__module__}.{module.__class__.__name__}" + if class_path in replacement_class_paths: + return True + + return False + + def replace_fused_blocks(model_type: str) -> bool: """Patch supported HF model classes so future loads instantiate defused blocks.""" if warn_if_public_api_transformers_unsupported("replace_fused_blocks()", logger): @@ -202,9 +219,10 @@ def convert_model( apply_model_patches(model, max_layers=max_layers, filter_rules=filter) - # If fused blocks have already been structurally replaced at load model before, - # there is no need to perform runtime defusing again - if MODEL_CONFIG[model.config.model_type].get(PATCH.REPLACE_MODULE): + # Full models patched at construction time already contain the defused + # replacement modules, but standalone experts from those model types can + # still use runtime tensor defusion. + if _has_prebuilt_replacements(model, model.config.model_type): return False # Perform runtime defusing of fused projections diff --git a/defuser/modeling/__init__.py b/defuser/modeling/__init__.py index e69de29..9a20d2b 100644 --- a/defuser/modeling/__init__.py +++ b/defuser/modeling/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium diff --git a/defuser/modeling/glm4v.py b/defuser/modeling/glm4v.py index 462b0d4..62825b1 100644 --- a/defuser/modeling/glm4v.py +++ b/defuser/modeling/glm4v.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium from torch import nn from transformers.activations import ACT2FN diff --git a/defuser/modeling/moe_experts_interface.py b/defuser/modeling/moe_experts_interface.py index 8f48a51..ad5b35a 100644 --- a/defuser/modeling/moe_experts_interface.py +++ b/defuser/modeling/moe_experts_interface.py @@ -292,28 +292,36 @@ def _install_instance_forward(module: nn.Module, implementation: str) -> None: def _detect_expert_projections(module: nn.Module) -> dict[str, dict]: """Detect which expert projections exist in the module. - This function scans the module for any 3D nn.Parameter attributes. - It first checks known projection names, then discovers any unknown 3D parameters. + This function scans the module for any registered 3D Parameter / Tensor + attributes. It first checks known projection names, then discovers any + unknown 3D registered tensors. Returns: Dict mapping projection names to their config, only for projections that exist - as 3D nn.Parameter in the module. + as 3D registered Parameter / Tensor in the module. """ detected = {} + # Only inspect registered tensors here. Scanning arbitrary attributes can + # trigger unrelated properties such as Transformers' `loss_function`. + local_tensors = { + name: tensor + for registry in (module._parameters, module._buffers) + for name, tensor in registry.items() + if isinstance(tensor, torch.Tensor) + } # First, check known projection patterns for proj_name, config in KNOWN_PROJECTION_PATTERNS.items(): - param = getattr(module, proj_name, None) - if param is not None and isinstance(param, nn.Parameter) and param.dim() == 3: + param = local_tensors.get(proj_name) + if param is not None and param.dim() == 3: detected[proj_name] = config - # If no known patterns found, scan for any 3D Parameter (future-proofing) + # If no known patterns found, scan for any 3D registered tensor (future-proofing) if not detected: - for attr_name in dir(module): + for attr_name, param in local_tensors.items(): if attr_name.startswith("_"): continue - param = getattr(module, attr_name, None) - if param is not None and isinstance(param, nn.Parameter) and param.dim() == 3: + if param is not None and param.dim() == 3: # Use default config for unknown projections if DEBUG_ON: logger.debug(f"Discovered unknown 3D projection: {attr_name}") detected[attr_name] = {"is_input_proj": True, "output_multiplier": 1} @@ -321,6 +329,29 @@ def _detect_expert_projections(module: nn.Module) -> dict[str, dict]: return detected +def _get_registered_tensor(module: nn.Module, name: str) -> torch.Tensor | None: + """Return a direct registered parameter or buffer without touching properties.""" + + tensor = module._parameters.get(name) + if isinstance(tensor, torch.Tensor): + return tensor + + tensor = module._buffers.get(name) + if isinstance(tensor, torch.Tensor): + return tensor + + return None + + +def _set_registered_tensor_like(module: nn.Module, name: str, tensor: torch.Tensor, source: torch.Tensor) -> None: + """Register ``tensor`` using the same parameter-vs-buffer kind as ``source``.""" + + if isinstance(source, nn.Parameter): + module.register_parameter(name, nn.Parameter(tensor, requires_grad=source.requires_grad)) + else: + module.register_buffer(name, tensor) + + def _experts_supports_decorator(module: nn.Module) -> bool: """Check if experts module supports @use_experts_implementation decorator. @@ -334,11 +365,11 @@ def _experts_supports_decorator(module: nn.Module) -> bool: return hasattr(forward_method, "__wrapped__") -def _infer_dimensions(param: nn.Parameter, config: dict, is_transposed: bool) -> tuple[int, int]: +def _infer_dimensions(param: torch.Tensor, config: dict, is_transposed: bool) -> tuple[int, int]: """Infer input and output dimensions for a projection. Args: - param: The 3D parameter (num_experts, dim1, dim2) + param: The 3D projection tensor (num_experts, dim1, dim2) config: Projection config with is_input_proj and output_multiplier is_transposed: Whether weights are stored transposed @@ -370,7 +401,7 @@ def _unfuse_single_projection( dtype: torch.dtype, target_device: torch.device, ) -> list | None: - """Unfuse a single projection from 3D Parameter to a list of Linear layers. + """Unfuse a single projection from a 3D registered tensor to Linear layers. Optimized to keep peak device memory low while preserving the module's original device placement: @@ -390,8 +421,8 @@ def _unfuse_single_projection( Returns: List of Linear layers, or None if projection doesn't exist """ - param = getattr(module, proj_name, None) - if param is None or not isinstance(param, nn.Parameter) or param.dim() != 3: + param = _get_registered_tensor(module, proj_name) + if param is None or param.dim() != 3: return None # Get projection config @@ -402,17 +433,17 @@ def _unfuse_single_projection( # Check for bias bias_name = f"{proj_name}_bias" - bias_param = getattr(module, bias_name, None) - has_bias = bias_param is not None + bias_param = _get_registered_tensor(module, bias_name) + has_bias = isinstance(bias_param, torch.Tensor) source_device = param.device is_meta = source_device.type == "meta" - weight_requires_grad = param.requires_grad + weight_requires_grad = param.requires_grad if isinstance(param, nn.Parameter) else False # Prepare weight slices on CPU in batch (single D2H transfer + batch transpose) if not is_meta: # Single transfer: GPU -> CPU (or no-op if already on CPU) - weights_cpu = param.data.cpu() # (num_experts, dim1, dim2) + weights_cpu = param.detach().cpu() # (num_experts, dim1, dim2) if is_transposed: # Batch transpose: (num_experts, in, out) -> (num_experts, out, in) weights_cpu = weights_cpu.transpose(1, 2) @@ -425,12 +456,12 @@ def _unfuse_single_projection( weight_slices = weights_cpu.unbind(0) if has_bias: - bias_cpu = bias_param.data.cpu() + bias_cpu = bias_param.detach().cpu() # Ensure contiguous — bias may come from a chunk split (Phase 1) if not bias_cpu.is_contiguous(): bias_cpu = bias_cpu.contiguous() bias_slices = bias_cpu.unbind(0) - bias_requires_grad = bias_param.requires_grad + bias_requires_grad = bias_param.requires_grad if isinstance(bias_param, nn.Parameter) else False # Drop the original fused parameter before allocating the defused # per-expert linears back on the original device. @@ -588,7 +619,7 @@ def _unfuse_experts_weights_inplace( # Get first projection to determine num_experts and layout first_proj_name = next(iter(detected_projections)) - first_param = getattr(module, first_proj_name) + first_param = _get_registered_tensor(module, first_proj_name) num_experts = first_param.shape[0] # Detect if transposed @@ -608,15 +639,15 @@ def _unfuse_experts_weights_inplace( dtype = first_param.dtype target_device = first_param.device if first_param.device.type != "meta" else "cpu" - # Phase 1: Split fused projections (e.g., gate_up_proj -> gate_proj + up_proj) into separate 3D params + # Phase 1: Split fused projections (e.g., gate_up_proj -> gate_proj + up_proj) into separate 3D tensors extra_projections = {} fused_to_remove = [] for proj_name, config in detected_projections.items(): split_into = config.get("split_into") if not split_into: continue - param = getattr(module, proj_name, None) - if param is None or not isinstance(param, nn.Parameter) or param.dim() != 3: + param = _get_registered_tensor(module, proj_name) + if param is None or param.dim() != 3: continue # Split along output dimension split_dim = 2 if is_transposed else 1 @@ -624,17 +655,17 @@ def _unfuse_experts_weights_inplace( # Also split bias if present (e.g., gate_up_proj_bias -> gate_proj_bias + up_proj_bias) bias_name = f"{proj_name}_bias" - bias_param = getattr(module, bias_name, None) + bias_param = _get_registered_tensor(module, bias_name) bias_splits = None - if bias_param is not None and isinstance(bias_param, nn.Parameter) and bias_param.dim() == 2: + if isinstance(bias_param, torch.Tensor) and bias_param.dim() == 2: bias_splits = bias_param.chunk(len(split_into), dim=1) for i, (split_name, split_param) in enumerate(zip(split_into, split_params)): # Avoid .contiguous() here — _unfuse_single_projection will handle it # during batch transpose/unbind, saving a full-tensor copy - setattr(module, split_name, nn.Parameter(split_param)) + _set_registered_tensor_like(module, split_name, split_param, param) if bias_splits is not None: - setattr(module, f"{split_name}_bias", nn.Parameter(bias_splits[i])) + _set_registered_tensor_like(module, f"{split_name}_bias", bias_splits[i], bias_param) extra_projections[split_name] = KNOWN_PROJECTION_PATTERNS.get( split_name, {"is_input_proj": True, "output_multiplier": 1} ) @@ -649,7 +680,7 @@ def _unfuse_experts_weights_inplace( del detected_projections[name] detected_projections.update(extra_projections) - # Phase 2: Unfuse all 3D params into per-expert Linear layers + # Phase 2: Unfuse all 3D tensors into per-expert Linear layers proj_linears = {} # proj_name -> [Linear_expert0, Linear_expert1, ...] for proj_name in detected_projections: linears = _unfuse_single_projection(module, proj_name, num_experts, is_transposed, dtype, target_device) diff --git a/defuser/modeling/unfused_moe/__init__.py b/defuser/modeling/unfused_moe/__init__.py index e69de29..9a20d2b 100644 --- a/defuser/modeling/unfused_moe/__init__.py +++ b/defuser/modeling/unfused_moe/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium diff --git a/defuser/utils/__init__.py b/defuser/utils/__init__.py index e69de29..9a20d2b 100644 --- a/defuser/utils/__init__.py +++ b/defuser/utils/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium diff --git a/pyproject.toml b/pyproject.toml index 4a7b140..530327c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "Defuser" -version = "0.0.18" +version = "0.0.19" description = "Model defuser helper for HF Transformers." readme = "README.md" requires-python = ">=3.9" @@ -25,6 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: 3.15", "Programming Language :: C", "Topic :: Software Development :: Libraries", "Topic :: Text Processing :: Linguistic" diff --git a/requirements.txt b/requirements.txt index a6f6563..9311d03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium pypcre>=0.2.13 transformers logbar>=0.4.1 diff --git a/setup.py b/setup.py index 6068493..8c65856 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium from setuptools import setup setup() diff --git a/tests/test_candidate_coverage.py b/tests/test_candidate_coverage.py index a2045db..bec021c 100644 --- a/tests/test_candidate_coverage.py +++ b/tests/test_candidate_coverage.py @@ -17,12 +17,16 @@ class _DummyLayer(nn.Module): + """Wrap one candidate module so convert_model sees a layer-shaped path.""" + def __init__(self, module: nn.Module): super().__init__() self.module = module class _DummyModel(nn.Module): + """Minimal model shell for exercising runtime conversion helpers.""" + def __init__(self, model_type: str, module: nn.Module): super().__init__() self.config = getattr(module, "config", None) @@ -33,10 +37,21 @@ def __init__(self, model_type: str, module: nn.Module): def _load(module_path: str, attr_name: str): + """Import one symbol lazily so candidate tables stay compact.""" return getattr(import_module(module_path), attr_name) +def _normalize_attr_value(obj, attr: str, value): + """Preserve list-valued config shapes when tests shrink scalar attributes.""" + current = getattr(obj, attr) + if isinstance(current, list) and not isinstance(value, list): + width = len(current) or 1 + return [value for _ in range(width)] + return value + + def _build_config(case: dict): + """Build a tiny but valid config for one candidate module case.""" config = _load(case["config_module"], case["config_name"])() sub_attr = case.get("sub_attr") if sub_attr is not None: @@ -73,24 +88,29 @@ def _build_config(case: dict): "dropout_rate": 0.0, }.items(): if hasattr(config, attr): - setattr(config, attr, value) + setattr(config, attr, _normalize_attr_value(config, attr, value)) for attr, value in case.get("config_updates", {}).items(): if hasattr(config, attr): - setattr(config, attr, value) + setattr(config, attr, _normalize_attr_value(config, attr, value)) return config def _build_module(case: dict) -> nn.Module: + """Instantiate one candidate module in the format its constructor expects.""" module_cls = _load(case["module_path"], case["class_name"]) kind = case.get("kind", "config") if kind == "parallel": return module_cls(case["num_experts"], case["input_size"], case["output_size"]).eval() if kind == "zamba2": return module_cls(_build_config(case), num_fwd_mem_blocks=1, block_id=0).eval() - return module_cls(_build_config(case)).eval() + config = _build_config(case) + if case["model_type"] == "ernie4_5_vl_moe" and isinstance(getattr(config, "moe_intermediate_size", None), list): + return module_cls(config, intermediate_size=config.moe_intermediate_size[0]).eval() + return module_cls(config).eval() def _wrapped_model(case: dict) -> tuple[_DummyModel, nn.Module]: + """Seed one module deterministically and wrap it as a dummy model.""" module = _build_module(case) generator = torch.Generator(device="cpu").manual_seed(0) with torch.no_grad(): @@ -101,10 +121,12 @@ def _wrapped_model(case: dict) -> tuple[_DummyModel, nn.Module]: def _patched_module(model: _DummyModel) -> nn.Module: + """Return the converted test module from the dummy model wrapper.""" return model.layers[0].module def _assert_expert_container(module: nn.Module, attrs: tuple[str, ...]) -> None: + """Verify that expert ``0`` exposes the expected defused projections.""" assert hasattr(module, "0") expert0 = getattr(module, "0") for attr in attrs: @@ -112,6 +134,7 @@ def _assert_expert_container(module: nn.Module, attrs: tuple[str, ...]) -> None: def _standard_hidden(case: dict) -> torch.Tensor: + """Generate default hidden states for sparse-MoE candidate coverage tests.""" return torch.randn(case.get("hidden_shape", (5, case["input_dim"])), dtype=torch.float32) @@ -553,11 +576,13 @@ def _standard_hidden(case: dict) -> torch.Tensor: def test_model_registry_covers_all_scanned_candidates(): + """Every scanned candidate model type should be represented in the registry.""" assert ALL_CANDIDATE_MODEL_TYPES.issubset(MODEL_CONFIG) @pytest.mark.parametrize("case", STANDARD_MOE_CASES, ids=[case["model_type"] for case in STANDARD_MOE_CASES]) def test_standard_moe_candidates_convert_and_preserve_forward(case): + """Standard sparse-MoE candidates should convert without changing outputs.""" torch.manual_seed(0) model, original_module = _wrapped_model(case) hidden_states = _standard_hidden(case) @@ -593,6 +618,7 @@ def test_standard_moe_candidates_convert_and_preserve_forward(case): @pytest.mark.parametrize("case", PARALLEL_CASES, ids=[case["model_type"] for case in PARALLEL_CASES]) def test_parallel_expert_candidates_convert_and_preserve_forward(case): + """Parallel expert containers should convert without changing outputs.""" torch.manual_seed(0) model, original_module = _wrapped_model(case) expert_size = [2, 1, 0, 3] @@ -616,6 +642,7 @@ def test_parallel_expert_candidates_convert_and_preserve_forward(case): @pytest.mark.parametrize("case", DENSE_CASES, ids=[case["label"] for case in DENSE_CASES]) def test_dense_candidates_convert_and_preserve_forward(case): + """Dense split-gate MLP candidates should preserve forward results after conversion.""" torch.manual_seed(0) model, original_module = _wrapped_model(case) hidden_states = torch.randn(3, case["hidden_size"], dtype=torch.float32) @@ -644,6 +671,7 @@ def test_dense_candidates_convert_and_preserve_forward(case): def test_runtime_model_patches_respect_max_layers(): + """Runtime-only patches should stop once ``max_layers`` is reached.""" module0 = _build_module(next(case for case in DENSE_CASES if case["label"] == "phi3")) module1 = _build_module(next(case for case in DENSE_CASES if case["label"] == "phi3")) diff --git a/tests/test_meta_model_defusion.py b/tests/test_meta_model_defusion.py index 2db5d2e..b8a29ef 100644 --- a/tests/test_meta_model_defusion.py +++ b/tests/test_meta_model_defusion.py @@ -15,16 +15,28 @@ def _load(module_path: str, attr_name: str): + """Import one symbol lazily so the model case table stays readable.""" return getattr(import_module(module_path), attr_name) +def _normalize_attr_value(obj, attr: str, value): + """Preserve list-valued config fields when shrinking configs for meta tests.""" + current = getattr(obj, attr) + if isinstance(current, list) and not isinstance(value, list): + width = len(current) or 1 + return [value for _ in range(width)] + return value + + def _set_if_has(obj, **kwargs) -> None: + """Apply overrides only to attributes exposed by the current config node.""" for attr, value in kwargs.items(): if hasattr(obj, attr): - setattr(obj, attr, value) + setattr(obj, attr, _normalize_attr_value(obj, attr, value)) def _mutate_common_config_tree(config, visited: set[int] | None = None) -> None: + """Shrink nested configs recursively so meta-model construction stays lightweight.""" if config is None or isinstance(config, (int, float, str, bool, list, tuple, dict)): return @@ -125,6 +137,7 @@ def _mutate_common_config_tree(config, visited: set[int] | None = None) -> None: def _build_model_config(case: dict): + """Construct a small config tree for one registered public model type.""" config = _load(case["config_module"], case["config_class"])() _mutate_common_config_tree(config) @@ -180,6 +193,7 @@ def _build_model_config(case: dict): def _find_module_hits(model, class_paths: tuple[str, ...]) -> list[tuple[str, str]]: + """Collect modules whose class path matches one of the expected targets.""" hits = [] wanted = set(class_paths) for name, module in model.named_modules(): @@ -190,16 +204,19 @@ def _find_module_hits(model, class_paths: tuple[str, ...]) -> list[tuple[str, st def _assert_meta_parameters(module) -> None: + """Assert that one module keeps all of its parameters on meta.""" for _, param in module.named_parameters(recurse=True): assert param.is_meta def _assert_all_model_parameters_meta(model) -> None: + """Assert that conversion does not materialize weights during meta tests.""" for _, param in model.named_parameters(): assert param.is_meta def _validate_defused_module(case: dict, module) -> None: + """Run the case-specific structural checks on one converted module.""" kind = case["validator"] if kind == "experts": @@ -729,11 +746,13 @@ def _validate_defused_module(case: dict, module) -> None: def test_meta_model_cases_cover_registered_public_models(): + """Meta-model coverage should stay aligned with the public registry.""" assert {case["model_type"] for case in META_MODEL_CASES} == set(MODEL_CONFIG) - {"qwen3_5_moe_text"} @pytest.mark.parametrize("case", META_MODEL_CASES, ids=[case["model_type"] for case in META_MODEL_CASES]) def test_each_model_defuses_direct_meta_model(case): + """Each registered public model should expose the expected defused modules on meta.""" if case["mode"] == "replace": replace_fused_blocks(case["model_type"]) diff --git a/tests/test_moe_experts_interface.py b/tests/test_moe_experts_interface.py new file mode 100644 index 0000000..7b286ac --- /dev/null +++ b/tests/test_moe_experts_interface.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import torch +import torch.nn as nn + +from defuser.modeling.moe_experts_interface import _detect_expert_projections, _unfuse_experts_weights_inplace + + +class _UnknownExpertsWithLossProperty(nn.Module): + """Regression fixture for expert detection on modules with unrelated properties.""" + + def __init__(self) -> None: + super().__init__() + self.expert_weight = nn.Parameter(torch.randn(4, 8, 16)) + self.loss_property_accesses = 0 + + @property + def loss_function(self): + self.loss_property_accesses += 1 + raise AssertionError("expert detection should not touch unrelated properties") + + +def test_detect_expert_projections_ignores_unrelated_properties() -> None: + """Projection detection should ignore unrelated properties on expert fixtures.""" + module = _UnknownExpertsWithLossProperty() + + detected = _detect_expert_projections(module) + + assert detected == { + "expert_weight": {"is_input_proj": True, "output_multiplier": 1}, + } + assert module.loss_property_accesses == 0 + + +class _BufferBackedExperts(nn.Module): + """Exercise the buffer-backed fused expert path end to end.""" + + def __init__(self) -> None: + super().__init__() + self.register_buffer("gate_up_proj", torch.arange(2 * 6 * 4, dtype=torch.float32).reshape(2, 6, 4)) + self.register_buffer("down_proj", torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)) + + +def test_unfuse_experts_supports_registered_buffers() -> None: + """Buffer-backed fused experts should unfuse into per-expert Linear layers.""" + module = _BufferBackedExperts() + expected_gate_proj = module.gate_up_proj[0, :3].clone() + expected_up_proj = module.gate_up_proj[0, 3:].clone() + expected_down_proj = module.down_proj[0].clone() + + changed = _unfuse_experts_weights_inplace(module, check_decorator=False) + + assert changed is True + expert0 = getattr(module, "0") + torch.testing.assert_close(expert0.gate_proj.weight, expected_gate_proj) + torch.testing.assert_close(expert0.up_proj.weight, expected_up_proj) + torch.testing.assert_close(expert0.down_proj.weight, expected_down_proj)