From 56471fcf5e5f4ab1cad04c2f14df6896d2574bda Mon Sep 17 00:00:00 2001 From: James Chapman Date: Wed, 8 Apr 2026 16:45:58 -0700 Subject: [PATCH] Add tracking for unused quantization rules. This change introduces a mechanism to track which quantization rules defined in a QuantizationConfig are actually applied during model quantization. A new method, `get_unused_rules()`, is added to the config to return any rules that did not match any operations in the model. This helps in verifying that all intended quantization rules are being used. The documentation is updated to reflect this new feature, and a test file is added to cover the functionality. PiperOrigin-RevId: 896769860 --- qwix/_src/providers/odml.py | 1 + qwix/_src/qconfig.py | 27 +++++++++ tests/_src/model_test.py | 1 + tests/_src/qconfig_test.py | 117 ++++++++++++++++++++++++++++++++++++ 4 files changed, 146 insertions(+) create mode 100644 tests/_src/qconfig_test.py diff --git a/qwix/_src/providers/odml.py b/qwix/_src/providers/odml.py index 756bf8e9..1e1e2080 100644 --- a/qwix/_src/providers/odml.py +++ b/qwix/_src/providers/odml.py @@ -183,6 +183,7 @@ def process_model_inputs( def process_model_output(self, method_name: str, model_output: Any) -> Any: """Quantize the output of the model.""" + self._initial_run_complete = True if method_name == '__call__': method_name = 'final' # backwards compatibility. # Quantize the model output if needed. diff --git a/qwix/_src/qconfig.py b/qwix/_src/qconfig.py index a0dee43f..909d80c5 100644 --- a/qwix/_src/qconfig.py +++ b/qwix/_src/qconfig.py @@ -113,8 +113,10 @@ def __init__( rules: The quantization rules in the order of precedence. disable_jit: Whether to disable JIT when wrapping methods. """ + self._rule_matches = [0] * len(rules) self._rules = [self._init_rule(rule) for rule in rules] self._logged_ops = set() + self._initial_run_complete = False self.disable_jit = disable_jit def _init_rule(self, rule: QuantizationRule) -> QuantizationRule: @@ -176,6 +178,7 @@ def process_model_inputs( def process_model_output(self, method_name: str, model_output: Any) -> Any: """Process the model output before it is returned.""" del method_name + self._initial_run_complete = True return model_output def _get_current_rule_and_op_id( @@ -208,6 +211,8 @@ def _get_current_rule_and_op_id( rule_idx = idx break rule = self._rules[rule_idx] if rule_idx is not None else None + if rule_idx is not None: + self._rule_matches[rule_idx] += 1 if only_rule: return rule, None @@ -228,3 +233,25 @@ def _get_current_rule_and_op_id( '[QWIX] module=%r op=%s rule=%s', module_path, op_id, rule_idx ) return rule, op_id + + def get_unused_rules(self) -> Sequence[QuantizationRule]: + """Returns the quantization rules that did not match any operations. + + This should be called after model quantization (e.g., `quantize_model`) to + verify that all rules were applied as expected. A rule is considered unused + if its `module_path` regex did not match any module's path, or if its + `op_names` did not match any intercepted operation within a matching module. + + Returns: + A sequence of unused quantization rules. + """ + if not self._initial_run_complete: + raise ValueError( + 'Quantization is not completed yet. Please call `quantize_model`' + ' before calling `get_unused_rules`.' + ) + return [ + self._rules[i] + for i, rule_matches in enumerate(self._rule_matches) + if rule_matches == 0 + ] diff --git a/tests/_src/model_test.py b/tests/_src/model_test.py index 2b886c18..ab2c900a 100644 --- a/tests/_src/model_test.py +++ b/tests/_src/model_test.py @@ -46,6 +46,7 @@ def get_intercept_map(self) -> Mapping[str, Callable[..., Any]]: return self._intercept_map def process_model_output(self, method_name: str, model_output: Any) -> Any: + self._initial_run_complete = True return model_output + 100 diff --git a/tests/_src/qconfig_test.py b/tests/_src/qconfig_test.py new file mode 100644 index 00000000..e774e8e5 --- /dev/null +++ b/tests/_src/qconfig_test.py @@ -0,0 +1,117 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl.testing import absltest +from flax import nnx +from jax import numpy as jnp +from qwix._src import model as qwix_model +from qwix._src import qconfig +from qwix._src.core import qarray +from qwix._src.providers import ptq + + +class QconfigTest(absltest.TestCase): + + def setUp(self): + super().setUp() + dim: int = 16 + + class MyModel(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.lin1 = nnx.Linear(dim, dim, rngs=rngs) + self.lin2 = nnx.Linear(dim, dim, rngs=rngs) + self.layers = nnx.List( + [nnx.Linear(dim, dim, rngs=rngs) for _ in range(2)] + ) + + def __call__(self, x): + return self.lin1(x) + self.lin2(x) + sum(l(x) for l in self.layers) + + self.model = MyModel(rngs=nnx.Rngs(0)) + self.x = jnp.ones((1, dim)) + + def test_all_rules_used(self): + rules = [ + qconfig.QuantizationRule( + weight_qtype="float8_e4m3fn", + act_qtype="float8_e4m3fn", + act_static_scale=False, + ), + ] + provider = ptq.PtqProvider(rules) + quant_model = qwix_model.quantize_model(self.model, provider, self.x) + + # Check unused rules. + self.assertEmpty(provider.get_unused_rules()) + + # Check that all layers are quantized. + self.assertIsInstance(quant_model.lin1.kernel.array, qarray.QArray) + self.assertIsInstance(quant_model.lin2.kernel.array, qarray.QArray) + self.assertIsInstance(quant_model.layers[0].kernel.array, qarray.QArray) + self.assertIsInstance(quant_model.layers[1].kernel.array, qarray.QArray) + + def test_some_rules_unused(self): + rules = [ + qconfig.QuantizationRule( + module_path=r"layers/\d+", + weight_qtype="float8_e4m3fn", + act_qtype="float8_e4m3fn", + act_static_scale=False, + ), + qconfig.QuantizationRule( + module_path=r"LIN\d+", # Typo in module path. + weight_qtype="float8_e4m3fn", + act_qtype="float8_e4m3fn", + act_static_scale=False, + ), + ] + provider = ptq.PtqProvider(rules) + quant_model = qwix_model.quantize_model(self.model, provider, self.x) + unused_rules = provider.get_unused_rules() + + # Check unused rules. + self.assertLen(unused_rules, 1) + self.assertEqual(unused_rules[0].module_path, rules[1].module_path) + + # Check that lin1 and lin2 are not quantized. + self.assertFalse(hasattr(quant_model.lin1.kernel, "array")) + self.assertFalse(hasattr(quant_model.lin2.kernel, "array")) + + # Check that layers are quantized. + self.assertIsInstance(quant_model.layers[0].kernel.array, qarray.QArray) + self.assertIsInstance(quant_model.layers[1].kernel.array, qarray.QArray) + + def test_get_unused_rules_before_quantize_model(self): + rules = [ + qconfig.QuantizationRule( + module_path=r"layers/\d+", + weight_qtype="float8_e4m3fn", + act_qtype="float8_e4m3fn", + act_static_scale=False, + ), + ] + provider = ptq.PtqProvider(rules) + with self.assertRaisesRegex( + ValueError, + "Quantization is not completed yet. Please call `quantize_model`" + " before calling `get_unused_rules`.", + ): + provider.get_unused_rules() + + qwix_model.quantize_model(self.model, provider, self.x) + self.assertEmpty(provider.get_unused_rules()) + + +if __name__ == "__main__": + absltest.main()