Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions cpmpy/expressions/globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,11 +970,28 @@ def decompose(self) -> tuple[list[Expression], list[Expression]]:
Returns:
tuple[list[Expression], list[Expression]]: A tuple containing the constraints representing the constraint value and the defining constraints
"""
# there are multiple decompositions possible, Recursively using sum allows it to be efficient for all solvers.
decomp = [sum(self.args[:2]) == 1]
if len(self.args) > 2:
decomp = Xor(decomp + list(self.args[2:])).decompose()[0]
return decomp, []
# lets first simplify the Xor by removing all constants:
# True Xor x :: ~x and False Xor x :: x
new_args: list[Expression] = []
parity = False # base case
for a in self.args:
if isinstance(a, Expression) and not isinstance(a, BoolVal):
new_args.append(a)
else: # a constant, don't store but update parity
if a: # True Xor x :: ~x
parity = not parity
if len(new_args) == 0:
return [BoolVal(parity)], []
if parity: # negate last argument
new_args[-1] = ~new_args[-1]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if new_args[-1] is an expression? Then it will introduce a complex negation, right?
Ideally, we want to negate a variable here (see the .negate() of xor).

If there is no variable in any of the args, we can indeed negate this one, but we'll have to run push_down_negation again once we move that transformation before decompose (which I currently have in a local branch)


# There are multiple decompositions possible,
# recursively using sum allows it to be efficient for all solvers.
prev: Expression = new_args[0]
for a in new_args[1:]:
prev = (prev + a == 1) # recursive pairwise Xor decomposition

return [prev], []

def value(self) -> Optional[bool]:
arrvals = argvals(self.args)
Expand Down
15 changes: 8 additions & 7 deletions tests/test_globalconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,12 +859,12 @@ def test_div(self):
assert all_sols == decomp_sols# same on decision vars
assert count == decomp_count# same on all vars

def test_xor(self):
def test_xor(self, solver):
bv = cp.boolvar(5)
assert cp.Model(cp.Xor(bv)).solve()
assert cp.Model(cp.Xor(bv)).solve(solver=solver)
assert cp.Xor(bv).value()

def test_xor_with_constants(self):
def test_xor_with_constants(self, solver):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought you had to use the pytest fixture for this to work?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agent says that tests/conftest.py defines it for all tests... @ThomSerg can you confirm that just adding 'solver' here (without the fixture being mentioned on the function nor the class) will work because the conftest does it at project level?


bvs = cp.boolvar(shape=3)

Expand All @@ -879,17 +879,18 @@ def test_xor_with_constants(self):
expr = cp.Xor(args)
model = cp.Model(expr)

assert model.solve()
assert model.solve(solver=solver)
assert expr.value()

# also check with decomposition
model = cp.Model(expr.decompose())
assert model.solve()
assert model.solve(solver=solver)
assert expr.value()

# edge case with False constants
assert not cp.Model(cp.Xor([False, False])).solve()
assert not cp.Model(cp.Xor([False, False, False])).solve()
assert not cp.Model(cp.Xor([False, False])).solve(solver=solver)
assert not cp.Model(cp.Xor([False, False, False])).solve(solver=solver)
assert cp.Model(cp.Xor([False, True, False])).solve(solver=solver)

def test_ite_with_constants(self):
x,y,z = cp.boolvar(shape=3)
Expand Down
Loading