diff --git a/cpmpy/expressions/globalconstraints.py b/cpmpy/expressions/globalconstraints.py index ed5e977b7..ff9071ad1 100644 --- a/cpmpy/expressions/globalconstraints.py +++ b/cpmpy/expressions/globalconstraints.py @@ -970,11 +970,28 @@ def decompose(self) -> tuple[list[Expression], list[Expression]]: Returns: tuple[list[Expression], list[Expression]]: A tuple containing the constraints representing the constraint value and the defining constraints """ - # there are multiple decompositions possible, Recursively using sum allows it to be efficient for all solvers. - decomp = [sum(self.args[:2]) == 1] - if len(self.args) > 2: - decomp = Xor(decomp + list(self.args[2:])).decompose()[0] - return decomp, [] + # lets first simplify the Xor by removing all constants: + # True Xor x :: ~x and False Xor x :: x + new_args: list[Expression] = [] + parity = False # base case + for a in self.args: + if isinstance(a, Expression) and not isinstance(a, BoolVal): + new_args.append(a) + else: # a constant, don't store but update parity + if a: # True Xor x :: ~x + parity = not parity + if len(new_args) == 0: + return [BoolVal(parity)], [] + if parity: # negate last argument + new_args[-1] = ~new_args[-1] + + # There are multiple decompositions possible, + # recursively using sum allows it to be efficient for all solvers. + prev: Expression = new_args[0] + for a in new_args[1:]: + prev = (prev + a == 1) # recursive pairwise Xor decomposition + + return [prev], [] def value(self) -> Optional[bool]: arrvals = argvals(self.args) diff --git a/tests/test_globalconstraints.py b/tests/test_globalconstraints.py index 35bc3e6c7..77013cc58 100644 --- a/tests/test_globalconstraints.py +++ b/tests/test_globalconstraints.py @@ -859,12 +859,12 @@ def test_div(self): assert all_sols == decomp_sols# same on decision vars assert count == decomp_count# same on all vars - def test_xor(self): + def test_xor(self, solver): bv = cp.boolvar(5) - assert cp.Model(cp.Xor(bv)).solve() + assert cp.Model(cp.Xor(bv)).solve(solver=solver) assert cp.Xor(bv).value() - def test_xor_with_constants(self): + def test_xor_with_constants(self, solver): bvs = cp.boolvar(shape=3) @@ -879,17 +879,18 @@ def test_xor_with_constants(self): expr = cp.Xor(args) model = cp.Model(expr) - assert model.solve() + assert model.solve(solver=solver) assert expr.value() # also check with decomposition model = cp.Model(expr.decompose()) - assert model.solve() + assert model.solve(solver=solver) assert expr.value() # edge case with False constants - assert not cp.Model(cp.Xor([False, False])).solve() - assert not cp.Model(cp.Xor([False, False, False])).solve() + assert not cp.Model(cp.Xor([False, False])).solve(solver=solver) + assert not cp.Model(cp.Xor([False, False, False])).solve(solver=solver) + assert cp.Model(cp.Xor([False, True, False])).solve(solver=solver) def test_ite_with_constants(self): x,y,z = cp.boolvar(shape=3)