diff --git a/qwix/_src/providers/odml.py b/qwix/_src/providers/odml.py
index 756bf8e..1e1e208 100644
--- a/qwix/_src/providers/odml.py
+++ b/qwix/_src/providers/odml.py
@@ -183,6 +183,7 @@ def process_model_inputs(
 
   def process_model_output(self, method_name: str, model_output: Any) -> Any:
     """Quantize the output of the model."""
+    self._initial_run_complete = True
     if method_name == '__call__':
       method_name = 'final'  # backwards compatibility.
     # Quantize the model output if needed.
diff --git a/qwix/_src/qconfig.py b/qwix/_src/qconfig.py
index a0dee43..909d80c 100644
--- a/qwix/_src/qconfig.py
+++ b/qwix/_src/qconfig.py
@@ -113,8 +113,11 @@ def __init__(
       rules: The quantization rules in the order of precedence.
       disable_jit: Whether to disable JIT when wrapping methods.
     """
     self._rules = [self._init_rule(rule) for rule in rules]
+    # One counter per rule; sized from self._rules so any iterable is accepted.
+    self._rule_matches = [0] * len(self._rules)
     self._logged_ops = set()
+    self._initial_run_complete = False
     self.disable_jit = disable_jit
 
   def _init_rule(self, rule: QuantizationRule) -> QuantizationRule:
@@ -176,6 +179,7 @@ def process_model_inputs(
   def process_model_output(self, method_name: str, model_output: Any) -> Any:
     """Process the model output before it is returned."""
     del method_name
+    self._initial_run_complete = True
     return model_output
 
   def _get_current_rule_and_op_id(
@@ -208,6 +212,8 @@ def _get_current_rule_and_op_id(
         rule_idx = idx
         break
     rule = self._rules[rule_idx] if rule_idx is not None else None
+    if rule_idx is not None:
+      self._rule_matches[rule_idx] += 1
 
     if only_rule:
       return rule, None
@@ -228,3 +234,28 @@ def _get_current_rule_and_op_id(
           '[QWIX] module=%r op=%s rule=%s', module_path, op_id, rule_idx
       )
     return rule, op_id
+
+  def get_unused_rules(self) -> Sequence[QuantizationRule]:
+    """Returns the quantization rules that did not match any operations.
+
+    This should be called after model quantization (e.g., `quantize_model`) to
+    verify that all rules were applied as expected. A rule is considered unused
+    if its `module_path` regex did not match any module's path, or if its
+    `op_names` did not match any intercepted operation within a matching module.
+
+    Returns:
+      A sequence of unused quantization rules, in rule-precedence order.
+
+    Raises:
+      ValueError: If called before any model output has been processed.
+    """
+    if not self._initial_run_complete:
+      raise ValueError(
+          'Quantization is not completed yet. Please call `quantize_model`'
+          ' before calling `get_unused_rules`.'
+      )
+    return [
+        rule
+        for rule, num_matches in zip(self._rules, self._rule_matches)
+        if num_matches == 0
+    ]
diff --git a/tests/_src/model_test.py b/tests/_src/model_test.py
index 2b886c1..ab2c900 100644
--- a/tests/_src/model_test.py
+++ b/tests/_src/model_test.py
@@ -46,6 +46,7 @@ def get_intercept_map(self) -> Mapping[str, Callable[..., Any]]:
     return self._intercept_map
 
   def process_model_output(self, method_name: str, model_output: Any) -> Any:
+    self._initial_run_complete = True
     return model_output + 100
 
 
diff --git a/tests/_src/qconfig_test.py b/tests/_src/qconfig_test.py
new file mode 100644
index 0000000..e774e8e
--- /dev/null
+++ b/tests/_src/qconfig_test.py
@@ -0,0 +1,117 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from absl.testing import absltest
+from flax import nnx
+from jax import numpy as jnp
+from qwix._src import model as qwix_model
+from qwix._src import qconfig
+from qwix._src.core import qarray
+from qwix._src.providers import ptq
+
+
+class QconfigTest(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    dim: int = 16
+
+    class MyModel(nnx.Module):
+
+      def __init__(self, rngs: nnx.Rngs):
+        self.lin1 = nnx.Linear(dim, dim, rngs=rngs)
+        self.lin2 = nnx.Linear(dim, dim, rngs=rngs)
+        self.layers = nnx.List(
+            [nnx.Linear(dim, dim, rngs=rngs) for _ in range(2)]
+        )
+
+      def __call__(self, x):
+        return self.lin1(x) + self.lin2(x) + sum(l(x) for l in self.layers)
+
+    self.model = MyModel(rngs=nnx.Rngs(0))
+    self.x = jnp.ones((1, dim))
+
+  def test_all_rules_used(self):
+    rules = [
+        qconfig.QuantizationRule(
+            weight_qtype="float8_e4m3fn",
+            act_qtype="float8_e4m3fn",
+            act_static_scale=False,
+        ),
+    ]
+    provider = ptq.PtqProvider(rules)
+    quant_model = qwix_model.quantize_model(self.model, provider, self.x)
+
+    # Check unused rules.
+    self.assertEmpty(provider.get_unused_rules())
+
+    # Check that all layers are quantized.
+    self.assertIsInstance(quant_model.lin1.kernel.array, qarray.QArray)
+    self.assertIsInstance(quant_model.lin2.kernel.array, qarray.QArray)
+    self.assertIsInstance(quant_model.layers[0].kernel.array, qarray.QArray)
+    self.assertIsInstance(quant_model.layers[1].kernel.array, qarray.QArray)
+
+  def test_some_rules_unused(self):
+    rules = [
+        qconfig.QuantizationRule(
+            module_path=r"layers/\d+",
+            weight_qtype="float8_e4m3fn",
+            act_qtype="float8_e4m3fn",
+            act_static_scale=False,
+        ),
+        qconfig.QuantizationRule(
+            module_path=r"LIN\d+",  # Typo in module path.
+            weight_qtype="float8_e4m3fn",
+            act_qtype="float8_e4m3fn",
+            act_static_scale=False,
+        ),
+    ]
+    provider = ptq.PtqProvider(rules)
+    quant_model = qwix_model.quantize_model(self.model, provider, self.x)
+    unused_rules = provider.get_unused_rules()
+
+    # Only the mistyped `LIN\d+` rule should be reported as unused.
+    self.assertLen(unused_rules, 1)
+    self.assertEqual(unused_rules[0].module_path, rules[1].module_path)
+
+    # Check that lin1 and lin2 are not quantized.
+    self.assertFalse(hasattr(quant_model.lin1.kernel, "array"))
+    self.assertFalse(hasattr(quant_model.lin2.kernel, "array"))
+
+    # Check that layers are quantized.
+    self.assertIsInstance(quant_model.layers[0].kernel.array, qarray.QArray)
+    self.assertIsInstance(quant_model.layers[1].kernel.array, qarray.QArray)
+
+  def test_get_unused_rules_before_quantize_model(self):
+    rules = [
+        qconfig.QuantizationRule(
+            module_path=r"layers/\d+",
+            weight_qtype="float8_e4m3fn",
+            act_qtype="float8_e4m3fn",
+            act_static_scale=False,
+        ),
+    ]
+    provider = ptq.PtqProvider(rules)
+    with self.assertRaisesWithLiteralMatch(  # Exact text, not a regex.
+        ValueError,
+        "Quantization is not completed yet. Please call `quantize_model`"
+        " before calling `get_unused_rules`.",
+    ):
+      provider.get_unused_rules()
+
+    qwix_model.quantize_model(self.model, provider, self.x)
+    self.assertEmpty(provider.get_unused_rules())
+
+
+if __name__ == "__main__":
+  absltest.main()