Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 80 additions & 47 deletions cpmpy/expressions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(self, name: str, arg_list: tuple[Any, ...], has_subexpr: Optional[b
"""
self.name = name
if not isinstance(arg_list, tuple):
warnings.warn(f"DEPRECATED: Argument list of {name} is not a tuple, updated the constructor!", UserWarning)
warnings.warn(f"DEPRECATED: Argument list of {name} is not a tuple, update the constructor!", UserWarning)
arg_list = tuple(arg_list)
self._args = arg_list
self._has_subexpr = has_subexpr
Expand Down Expand Up @@ -189,11 +189,11 @@ def __repr__(self) -> str:
strargs.append(f"{arg}")
return "{}({})".format(self.name, ",".join(strargs))

def __hash__(self):
def __hash__(self) -> int:
return hash(self.__repr__())

def has_subexpr(self):
""" Does it contains nested :class:`Expressions <cpmpy.expressions.core.Expression>` (anything other than a :class:`~cpmpy.expressions.variables._NumVarImpl` or a constant)?
def has_subexpr(self) -> bool:
""" Does it contain nested :class:`Expressions <cpmpy.expressions.core.Expression>` (anything other than a :class:`~cpmpy.expressions.variables._NumVarImpl` or a constant)?
Is of importance when deciding whether certain transformations are needed
along particular paths of the expression tree.
Results are cached for future calls and reset when the expression changes
Expand Down Expand Up @@ -233,35 +233,59 @@ def has_subexpr(self):
self._has_subexpr = False
return False

def is_bool(self):
def is_bool(self) -> bool:
""" is it a Boolean (return type) Operator?
Default: yes
"""
return True

def value(self):
return None # default
def value(self) -> Optional[int]:
return None # default

def get_bounds(self):
def get_bounds(self) -> tuple[int, int]:
if self.is_bool():
return 0, 1 #default for boolean expressions
return 0, 1 # default for boolean expressions
raise NotImplementedError(f"`get_bounds` is not implemented for type {self}")

# keep for backwards compatibility
def deepcopy(self, memodict={}):
""" DEPRECATED: use copy.deepcopy() instead

Will be removed in stable version.
"""
warnings.warn("Deprecated, use copy.deepcopy() instead, will be removed in stable version", DeprecationWarning)
return copy.deepcopy(self, memodict)

# implication constraint: self -> other
# Python does not offer relevant syntax...
# for double implication, use equivalence self == other
def implies(self, other):
# other constant
if is_true_cst(other):
return BoolVal(True)
if is_false_cst(other):
return ~self
return Operator('->', [self, other])
def implies(self, other: ExprLike, simplify: bool = False) -> "Expression":
"""Implication constraint: ``self -> other``.

Python does not offer relevant syntax for implication, call this method instead.
For double reification (<->), use equivalence ``self == other``.

Args:
other (ExprLike): the right-hand-side of the implication
simplify (bool): if True, simplify True/False constants (might remove expressions & their variables from user-view)

Returns:
Expression: the implication constraint or a BoolVal if simplified

Simplification rules:
- self -> True :: BoolVal(True)
- self -> False :: ~self (Boolean inversion)
"""
if not simplify:
return Operator('->', (self, other))

if isinstance(other, Expression):
if isinstance(other, BoolVal): # simplify
if other.args[0]:
return BoolVal(True)
return self.__invert__() # not self
return Operator('->', (self, other))
else: # simplify
assert isinstance(other, bool) or isinstance(other, np.bool_), f"implies: other must be a boolean, got {other}"
if other:
return BoolVal(True)
return self.__invert__() # not self

# Comparisons
def __eq__(self, other):
Expand Down Expand Up @@ -358,7 +382,7 @@ def __radd__(self, other):
return self
return Operator("sum", [other, self])

# substraction
# subtraction
def __sub__(self, other):
# if is_num(other) and other == 0:
# return self
Expand All @@ -382,16 +406,16 @@ def __rmul__(self, other):
return self
return cp.Multiplication(other, self)

# matrix multipliciation TODO?
# matrix multiplication TODO?
#object.__matmul__(self, other)

# other mathematical ones
def __truediv__(self, other):
warnings.warn("We only support floordivision, use // in stead of /", SyntaxWarning)
warnings.warn("We only support floordivision, use // instead of /", SyntaxWarning)
return self.__floordiv__(other)

def __rtruediv__(self, other):
warnings.warn("We only support floordivision, use // in stead of /", SyntaxWarning)
warnings.warn("We only support floordivision, use // instead of /", SyntaxWarning)
return self.__rfloordiv__(other)

def __floordiv__(self, other):
Expand Down Expand Up @@ -427,19 +451,19 @@ def __rpow__(self, other: Any):
def __neg__(self):
if self.name == 'wsum':
# negate the constant weights
return Operator(self.name, [[-a for a in self.args[0]], self.args[1]])
return Operator("-", [self])
return Operator(self.name, ([-a for a in self.args[0]], self.args[1]))
return Operator("-", (self,))

def __pos__(self):
return self

def __abs__(self):
return cp.Abs(self)

def __invert__(self):
if not (is_boolexpr(self)):
def __invert__(self) -> "Expression":
if not (self.is_bool()):
raise TypeError("Not operator is only allowed on boolean expressions: {0}".format(self))
return Operator("not", [self])
return Operator("not", (self,))

def __bool__(self) -> bool:
raise ValueError(f"__bool__ should not be called on a CPMPy expression {self} as it will always return True\n"
Expand All @@ -452,7 +476,7 @@ class BoolVal(Expression):
"""

def __init__(self, arg: bool|np.bool_) -> None:
arg = bool(arg) # will raise ValueError if not a Boolean-able
arg = bool(arg)
super(BoolVal, self).__init__("boolval", (arg,))

def value(self) -> bool:
Expand Down Expand Up @@ -539,27 +563,36 @@ def __rxor__(self, other):


def has_subexpr(self) -> bool:
""" Does it contains nested Expressions (anything other than a _NumVarImpl or a constant)?
""" Does it contain nested Expressions (anything other than a _NumVarImpl or a constant)?
Is of importance when deciding whether certain transformations are needed
along particular paths of the expression tree.
"""
return False # BoolVal is a wrapper for a python or numpy constant boolean.

def implies(self, other: ExprLike) -> Expression:
my_val: bool = self.args[0]
if isinstance(other, Expression):
assert other.is_bool(), "implies: other must be a boolean expression"
if my_val: # T -> other :: other
return other
return Operator("->", [self, other]) # do not simplify to True, would remove other from user view
else:
# should we check whether it actually is bool and not int?
if my_val: # T -> other :: other
return BoolVal(bool(other))
else: # F -> other :: True
return BoolVal(True)
# note that this can return a BoolVal(True)
def implies(self, other: ExprLike, simplify: bool = False) -> Expression:
"""Implication constraint: ``BoolVal -> other``.

Args:
other (ExprLike): the right-hand-side of the implication
simplify (bool): if True, simplify True/False constants (might remove expressions & their variables from user-view)

Returns:
Expression: the implication constraint or a BoolVal if simplified

Simplification rules:
- BoolVal(True) -> other :: other (BoolVal-ified if needed)
- BoolVal(False) -> other :: BoolVal(True)
"""
if not simplify:
return Operator('->', (self, other))

if self.args[0]:
if not isinstance(other, Expression):
assert isinstance(other, bool) or isinstance(other, np.bool_), f"implies: other must be a boolean, got {other}"
return BoolVal(other)
return other
else:
return BoolVal(True)

class Comparison(Expression):
"""Represents a comparison between two sub-expressions
Expand Down Expand Up @@ -706,7 +739,7 @@ def __repr__(self) -> str:
return f"sum({self.args[0]} * {self.args[1]})"

if len(self.args) == 1:
return "{}({})".format(self.name, self.args[0]) # tuple of size 1 ommited in print
return "{}({})".format(self.name, self.args[0]) # tuple of size 1 omitted in print
elif len(self.args) == 2: # infix printing of two arguments
printname = Operator.printmap.get(self.name, self.name) # default to self.name if not in printmap
arg0, arg1 = self.args
Expand Down Expand Up @@ -793,7 +826,7 @@ def _wsum_should(arg) -> bool:
True if the arg is already a wsum,
or if it is a Multiplication with is_lhs_num
(negation '-' does not mean it SHOULD be a wsum, because then
all substractions are transformed into less readable wsums)
all subtractions are transformed into less readable wsums)
"""
name = getattr(arg, 'name', None)
return name == 'wsum' or (name == 'mul' and arg.is_lhs_num)
Expand Down
17 changes: 12 additions & 5 deletions cpmpy/expressions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@
import math
from collections.abc import Iterable # for flatten
from itertools import combinations
from typing import TYPE_CHECKING, TypeGuard, Union, Optional
from typing import TYPE_CHECKING, TypeGuard, overload
from cpmpy.exceptions import IncompleteFunctionError

if TYPE_CHECKING:
# only import for type checking
from cpmpy.expressions.core import ListLike, ExprLike
from cpmpy.expressions.core import ExprLike, Expression
from cpmpy.expressions.variables import NDVarArray


def is_bool(arg):
Expand Down Expand Up @@ -208,17 +209,23 @@ def get_bounds(expr):
return int(expr), int(expr)
return math.floor(expr), math.ceil(expr)

def implies(expr, other):
# first to are declarations for typing purposes only
@overload
def implies(expr: NDVarArray, other: ExprLike, simplify: bool = False) -> NDVarArray: ...
@overload
def implies(expr: Expression|bool|np.bool_, other: ExprLike, simplify: bool = False) -> Expression: ...

def implies(expr: NDVarArray|Expression|bool|np.bool_, other: ExprLike, simplify: bool = False) -> NDVarArray|ExprLike:
""" like :func:`~cpmpy.expressions.core.Expression.implies`, but also safe to use for non-expressions """
if isinstance(expr, (cp.expressions.core.Expression, cp.expressions.variables.NDVarArray)):
# both implement .implies()
return expr.implies(other)
return expr.implies(other, simplify=simplify)
elif is_true_cst(expr):
return other
elif is_false_cst(expr):
return cp.BoolVal(True)
else:
return expr.implies(other)
raise ValueError(f"implies: expr must be an Expression or a boolean, got {type(expr)}")

# Specific stuff for scheduling constraints

Expand Down
6 changes: 4 additions & 2 deletions cpmpy/expressions/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,10 @@ def __xor__(self, other):
def __rxor__(self, other):
return self._vectorized(other, '__rxor__')

def implies(self, other):
return self._vectorized(other, 'implies')
def implies(self, other, simplify=False):
if not isinstance(other, Iterable):
other = [other] * len(self)
return cpm_array([s.implies(o, simplify=simplify) for s, o in zip(self, other)])

#in __contains__(self, value) Check membership
# CANNOT meaningfully overwrite, python always returns True/False
Expand Down
4 changes: 2 additions & 2 deletions cpmpy/transformations/int2bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _encode_expr(ivarmap, expr, encoding, csemap=None):
p, consequent = expr.args
constraints, domain_constraints = _encode_expr(ivarmap, consequent, encoding, csemap=csemap)
return (
[p.implies(constraint) for constraint in constraints],
[p.implies(constraint, simplify=True) for constraint in constraints],
domain_constraints,
)
elif isinstance(expr, Comparison):
Expand Down Expand Up @@ -355,7 +355,7 @@ def encode_domain_constraint(self):
if len(self._xs) <= 1:
return []
# Encode implication chain `x>=d -> x>=d-1` (using `zip` to create a sliding window)
return [curr.implies(prev) for prev, curr in zip(self._xs, self._xs[1:])]
return [curr.implies(prev, simplify=True) for prev, curr in zip(self._xs, self._xs[1:])]

def _offset(self, d):
return d - self._x.lb - 1
Expand Down
4 changes: 4 additions & 0 deletions tests/test_trans_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ def test_bool_ops(self):

expr = Operator("->", [self.bvs[0], True])
assert str(self.transform(expr)) == "[boolval(True)]"
expr = Operator("->", [self.bvs[0], BoolVal(True)])
assert str(self.transform(expr)) == "[boolval(True)]"
expr = Operator("->", [self.bvs[0], False])
assert str(self.transform(expr)) == "[~bv[0]]"
expr = Operator("->", [self.bvs[0], BoolVal(False)])
assert str(self.transform(expr)) == "[~bv[0]]"
expr = Operator("->", [True, self.bvs[0]])
assert str(self.transform(expr)) == "[bv[0]]"
expr = Operator("->", [False, self.bvs[0]])
Expand Down
Loading