diff --git a/cpmpy/expressions/core.py b/cpmpy/expressions/core.py index 6934a6d4a..061d2b5a5 100644 --- a/cpmpy/expressions/core.py +++ b/cpmpy/expressions/core.py @@ -141,7 +141,7 @@ def __init__(self, name: str, arg_list: tuple[Any, ...], has_subexpr: Optional[b """ self.name = name if not isinstance(arg_list, tuple): - warnings.warn(f"DEPRECATED: Argument list of {name} is not a tuple, updated the constructor!", UserWarning) + warnings.warn(f"DEPRECATED: Argument list of {name} is not a tuple, update the constructor!", UserWarning) arg_list = tuple(arg_list) self._args = arg_list self._has_subexpr = has_subexpr @@ -189,11 +189,11 @@ def __repr__(self) -> str: strargs.append(f"{arg}") return "{}({})".format(self.name, ",".join(strargs)) - def __hash__(self): + def __hash__(self) -> int: return hash(self.__repr__()) - def has_subexpr(self): - """ Does it contains nested :class:`Expressions ` (anything other than a :class:`~cpmpy.expressions.variables._NumVarImpl` or a constant)? + def has_subexpr(self) -> bool: + """ Does it contain nested :class:`Expressions ` (anything other than a :class:`~cpmpy.expressions.variables._NumVarImpl` or a constant)? Is of importance when deciding whether certain transformations are needed along particular paths of the expression tree. Results are cached for future calls and reset when the expression changes @@ -233,35 +233,59 @@ def has_subexpr(self): self._has_subexpr = False return False - def is_bool(self): + def is_bool(self) -> bool: """ is it a Boolean (return type) Operator? Default: yes """ return True - def value(self): - return None # default + def value(self) -> Optional[int]: + return None # default - def get_bounds(self): + def get_bounds(self) -> tuple[int, int]: if self.is_bool(): - return 0, 1 #default for boolean expressions + return 0, 1 # default for boolean expressions raise NotImplementedError(f"`get_bounds` is not implemented for type {self}") - # keep for backwards compatibility def deepcopy(self, memodict={}): + """ DEPRECATED: use copy.deepcopy() instead + + Will be removed in stable version. + """ warnings.warn("Deprecated, use copy.deepcopy() instead, will be removed in stable version", DeprecationWarning) return copy.deepcopy(self, memodict) - # implication constraint: self -> other - # Python does not offer relevant syntax... - # for double implication, use equivalence self == other - def implies(self, other): - # other constant - if is_true_cst(other): - return BoolVal(True) - if is_false_cst(other): - return ~self - return Operator('->', [self, other]) + def implies(self, other: ExprLike, simplify: bool = False) -> "Expression": + """Implication constraint: ``self -> other``. + + Python does not offer relevant syntax for implication, call this method instead. + For double reification (<->), use equivalence ``self == other``. + + Args: + other (ExprLike): the right-hand-side of the implication + simplify (bool): if True, simplify True/False constants (might remove expressions & their variables from user-view) + + Returns: + Expression: the implication constraint or a BoolVal if simplified + + Simplification rules: + - self -> True :: BoolVal(True) + - self -> False :: ~self (Boolean inversion) + """ + if not simplify: + return Operator('->', (self, other)) + + if isinstance(other, Expression): + if isinstance(other, BoolVal): # simplify + if other.args[0]: + return BoolVal(True) + return self.__invert__() # not self + return Operator('->', (self, other)) + else: # simplify + assert isinstance(other, bool) or isinstance(other, np.bool_), f"implies: other must be a boolean, got {other}" + if other: + return BoolVal(True) + return self.__invert__() # not self # Comparisons def __eq__(self, other): @@ -358,7 +382,7 @@ def __radd__(self, other): return self return Operator("sum", [other, self]) - # substraction + # subtraction def __sub__(self, other): # if is_num(other) and other == 0: # return self @@ -382,16 +406,16 @@ def __rmul__(self, other): return self return cp.Multiplication(other, self) - # matrix multipliciation TODO? + # matrix multiplication TODO? #object.__matmul__(self, other) # other mathematical ones def __truediv__(self, other): - warnings.warn("We only support floordivision, use // in stead of /", SyntaxWarning) + warnings.warn("We only support floordivision, use // instead of /", SyntaxWarning) return self.__floordiv__(other) def __rtruediv__(self, other): - warnings.warn("We only support floordivision, use // in stead of /", SyntaxWarning) + warnings.warn("We only support floordivision, use // instead of /", SyntaxWarning) return self.__rfloordiv__(other) def __floordiv__(self, other): @@ -427,8 +451,8 @@ def __rpow__(self, other: Any): def __neg__(self): if self.name == 'wsum': # negate the constant weights - return Operator(self.name, [[-a for a in self.args[0]], self.args[1]]) - return Operator("-", [self]) + return Operator(self.name, ([-a for a in self.args[0]], self.args[1])) + return Operator("-", (self,)) def __pos__(self): return self @@ -436,10 +460,10 @@ def __pos__(self): def __abs__(self): return cp.Abs(self) - def __invert__(self): - if not (is_boolexpr(self)): + def __invert__(self) -> "Expression": + if not (self.is_bool()): raise TypeError("Not operator is only allowed on boolean expressions: {0}".format(self)) - return Operator("not", [self]) + return Operator("not", (self,)) def __bool__(self) -> bool: raise ValueError(f"__bool__ should not be called on a CPMPy expression {self} as it will always return True\n" @@ -452,7 +476,7 @@ class BoolVal(Expression): """ def __init__(self, arg: bool|np.bool_) -> None: - arg = bool(arg) # will raise ValueError if not a Boolean-able + arg = bool(arg) super(BoolVal, self).__init__("boolval", (arg,)) def value(self) -> bool: @@ -539,27 +563,36 @@ def __rxor__(self, other): def has_subexpr(self) -> bool: - """ Does it contains nested Expressions (anything other than a _NumVarImpl or a constant)? + """ Does it contain nested Expressions (anything other than a _NumVarImpl or a constant)? Is of importance when deciding whether certain transformations are needed along particular paths of the expression tree. """ return False # BoolVal is a wrapper for a python or numpy constant boolean. - def implies(self, other: ExprLike) -> Expression: - my_val: bool = self.args[0] - if isinstance(other, Expression): - assert other.is_bool(), "implies: other must be a boolean expression" - if my_val: # T -> other :: other - return other - return Operator("->", [self, other]) # do not simplify to True, would remove other from user view - else: - # should we check whether it actually is bool and not int? - if my_val: # T -> other :: other - return BoolVal(bool(other)) - else: # F -> other :: True - return BoolVal(True) - # note that this can return a BoolVal(True) + def implies(self, other: ExprLike, simplify: bool = False) -> Expression: + """Implication constraint: ``BoolVal -> other``. + + Args: + other (ExprLike): the right-hand-side of the implication + simplify (bool): if True, simplify True/False constants (might remove expressions & their variables from user-view) + + Returns: + Expression: the implication constraint or a BoolVal if simplified + Simplification rules: + - BoolVal(True) -> other :: other (BoolVal-ified if needed) + - BoolVal(False) -> other :: BoolVal(True) + """ + if not simplify: + return Operator('->', (self, other)) + + if self.args[0]: + if not isinstance(other, Expression): + assert isinstance(other, bool) or isinstance(other, np.bool_), f"implies: other must be a boolean, got {other}" + return BoolVal(other) + return other + else: + return BoolVal(True) class Comparison(Expression): """Represents a comparison between two sub-expressions @@ -706,7 +739,7 @@ def __repr__(self) -> str: return f"sum({self.args[0]} * {self.args[1]})" if len(self.args) == 1: - return "{}({})".format(self.name, self.args[0]) # tuple of size 1 ommited in print + return "{}({})".format(self.name, self.args[0]) # tuple of size 1 omitted in print elif len(self.args) == 2: # infix printing of two arguments printname = Operator.printmap.get(self.name, self.name) # default to self.name if not in printmap arg0, arg1 = self.args @@ -793,7 +826,7 @@ def _wsum_should(arg) -> bool: True if the arg is already a wsum, or if it is a Multiplication with is_lhs_num (negation '-' does not mean it SHOULD be a wsum, because then - all substractions are transformed into less readable wsums) + all subtractions are transformed into less readable wsums) """ name = getattr(arg, 'name', None) return name == 'wsum' or (name == 'mul' and arg.is_lhs_num) diff --git a/cpmpy/expressions/utils.py b/cpmpy/expressions/utils.py index 37ad904d5..c47b4e638 100644 --- a/cpmpy/expressions/utils.py +++ b/cpmpy/expressions/utils.py @@ -35,12 +35,13 @@ import math from collections.abc import Iterable # for flatten from itertools import combinations -from typing import TYPE_CHECKING, TypeGuard, Union, Optional +from typing import TYPE_CHECKING, TypeGuard, overload from cpmpy.exceptions import IncompleteFunctionError if TYPE_CHECKING: # only import for type checking - from cpmpy.expressions.core import ListLike, ExprLike + from cpmpy.expressions.core import ExprLike, Expression + from cpmpy.expressions.variables import NDVarArray def is_bool(arg): @@ -208,17 +209,23 @@ def get_bounds(expr): return int(expr), int(expr) return math.floor(expr), math.ceil(expr) -def implies(expr, other): +# first to are declarations for typing purposes only +@overload +def implies(expr: NDVarArray, other: ExprLike, simplify: bool = False) -> NDVarArray: ... +@overload +def implies(expr: Expression|bool|np.bool_, other: ExprLike, simplify: bool = False) -> Expression: ... + +def implies(expr: NDVarArray|Expression|bool|np.bool_, other: ExprLike, simplify: bool = False) -> NDVarArray|ExprLike: """ like :func:`~cpmpy.expressions.core.Expression.implies`, but also safe to use for non-expressions """ if isinstance(expr, (cp.expressions.core.Expression, cp.expressions.variables.NDVarArray)): # both implement .implies() - return expr.implies(other) + return expr.implies(other, simplify=simplify) elif is_true_cst(expr): return other elif is_false_cst(expr): return cp.BoolVal(True) else: - return expr.implies(other) + raise ValueError(f"implies: expr must be an Expression or a boolean, got {type(expr)}") # Specific stuff for scheduling constraints diff --git a/cpmpy/expressions/variables.py b/cpmpy/expressions/variables.py index d3258a357..541f2f68a 100644 --- a/cpmpy/expressions/variables.py +++ b/cpmpy/expressions/variables.py @@ -773,8 +773,10 @@ def __xor__(self, other): def __rxor__(self, other): return self._vectorized(other, '__rxor__') - def implies(self, other): - return self._vectorized(other, 'implies') + def implies(self, other, simplify=False): + if not isinstance(other, Iterable): + other = [other] * len(self) + return cpm_array([s.implies(o, simplify=simplify) for s, o in zip(self, other)]) #in __contains__(self, value) Check membership # CANNOT meaningfully overwrite, python always returns True/False diff --git a/cpmpy/transformations/int2bool.py b/cpmpy/transformations/int2bool.py index 1d45185db..5ca9dd5d1 100644 --- a/cpmpy/transformations/int2bool.py +++ b/cpmpy/transformations/int2bool.py @@ -50,7 +50,7 @@ def _encode_expr(ivarmap, expr, encoding, csemap=None): p, consequent = expr.args constraints, domain_constraints = _encode_expr(ivarmap, consequent, encoding, csemap=csemap) return ( - [p.implies(constraint) for constraint in constraints], + [p.implies(constraint, simplify=True) for constraint in constraints], domain_constraints, ) elif isinstance(expr, Comparison): @@ -355,7 +355,7 @@ def encode_domain_constraint(self): if len(self._xs) <= 1: return [] # Encode implication chain `x>=d -> x>=d-1` (using `zip` to create a sliding window) - return [curr.implies(prev) for prev, curr in zip(self._xs, self._xs[1:])] + return [curr.implies(prev, simplify=True) for prev, curr in zip(self._xs, self._xs[1:])] def _offset(self, d): return d - self._x.lb - 1 diff --git a/tests/test_trans_simplify.py b/tests/test_trans_simplify.py index 4d03d6a55..69761e918 100644 --- a/tests/test_trans_simplify.py +++ b/tests/test_trans_simplify.py @@ -25,8 +25,12 @@ def test_bool_ops(self): expr = Operator("->", [self.bvs[0], True]) assert str(self.transform(expr)) == "[boolval(True)]" + expr = Operator("->", [self.bvs[0], BoolVal(True)]) + assert str(self.transform(expr)) == "[boolval(True)]" expr = Operator("->", [self.bvs[0], False]) assert str(self.transform(expr)) == "[~bv[0]]" + expr = Operator("->", [self.bvs[0], BoolVal(False)]) + assert str(self.transform(expr)) == "[~bv[0]]" expr = Operator("->", [True, self.bvs[0]]) assert str(self.transform(expr)) == "[bv[0]]" expr = Operator("->", [False, self.bvs[0]])