diff --git a/cpmpy/solvers/gurobi.py b/cpmpy/solvers/gurobi.py index 7f83216d0..b03c7fb69 100644 --- a/cpmpy/solvers/gurobi.py +++ b/cpmpy/solvers/gurobi.py @@ -42,9 +42,12 @@ ============== """ +import cpmpy as cp from typing import Optional, List import warnings +import math + from .solver_interface import SolverInterface, SolverStatus, ExitStatus, Callback from ..exceptions import NotSupportedError from ..expressions.core import Expression, Comparison, Operator, BoolVal @@ -52,9 +55,10 @@ from ..expressions.variables import _BoolVarImpl, NegBoolView, _IntVarImpl, _NumVarImpl, intvar from ..expressions.globalconstraints import DirectConstraint from ..transformations.comparison import only_numexpr_equality -from ..transformations.flatten_model import flatten_constraint, flatten_objective +from ..transformations.flatten_model import flatten_constraint, flatten_objective, get_or_make_var_or_list from ..transformations.get_variables import get_variables from ..transformations.linearize import linearize_constraint, linearize_reified_variables, only_positive_bv, only_positive_bv_wsum, decompose_linear, decompose_linear_objective +from ..transformations.decompose_global import decompose_in_tree from ..transformations.normalize import toplevel_list from ..transformations.reification import only_implies, reify_rewrite, only_bv_reifies from ..transformations.safening import no_partial_functions, safen_objective @@ -192,6 +196,9 @@ def solve(self, time_limit:Optional[float]=None, solution_callback=None, **kwarg for param, val in kwargs.items(): self.grb_model.setParam(param, val) + # write LP file for debugging + self.grb_model.write("/tmp/model.lp") + _ = self.grb_model.optimize(callback=solution_callback) grb_objective = self.grb_model.getObjective() @@ -259,7 +266,7 @@ def solver_var(self, cpm_var): # special case, negative-bool-view. Should be eliminated in linearize if isinstance(cpm_var, NegBoolView): - raise NotSupportedError("Negative literals should not be left as part of any equation. Please report.") + return 1 - self.solver_var(cpm_var._bv) # create if it does not exit if cpm_var not in self._varmap: @@ -297,7 +304,7 @@ def objective(self, expr, minimize=True): supported=self.supported_global_constraints, supported_reified=self.supported_reified_global_constraints, csemap=self._csemap) - obj, flat_cons = flatten_objective(obj, csemap=self._csemap) + obj, flat_cons = flatten_objective(obj, csemap=self._csemap, supported={"pow", "mul"}) obj = only_positive_bv_wsum(obj) # remove negboolviews self.add(safe_cons + decomp_cons + flat_cons) @@ -358,20 +365,37 @@ def transform(self, cpm_expr): # apply transformations, then post internally # expressions have to be linearized to fit in MIP model. See /transformations/linearize cpm_cons = toplevel_list(cpm_expr) - cpm_cons = no_partial_functions(cpm_cons, safen_toplevel={"mod", "div", "element"}) # linearize and decompose expect safe exprs + cpm_cons = no_partial_functions(cpm_cons, safen_toplevel=frozenset(["mod", "div", "element"])) # linearize and decompose expect safe exprs cpm_cons = decompose_linear(cpm_cons, supported=self.supported_global_constraints, supported_reified=self.supported_reified_global_constraints, csemap=self._csemap) - cpm_cons = flatten_constraint(cpm_cons, csemap=self._csemap) # flat normal form + cpm_cons = flatten_constraint(cpm_cons, csemap=self._csemap, supported=frozenset(["mul", "pow", "-", "sum"])) # flat normal form cpm_cons = reify_rewrite(cpm_cons, supported=frozenset(['sum', 'wsum']), csemap=self._csemap) # constraints that support reification cpm_cons = only_numexpr_equality(cpm_cons, supported=frozenset(["sum", "wsum", "sub"]), csemap=self._csemap) # supports >, <, != - cpm_cons = linearize_reified_variables(cpm_cons, min_values=2, csemap=self._csemap) + + # cpm_cons = linearize_reified_variables(cpm_cons, min_values=2, csemap=self._csemap) cpm_cons = only_bv_reifies(cpm_cons, csemap=self._csemap) - cpm_cons = only_implies(cpm_cons, csemap=self._csemap) # anything that can create full reif should go above... + cpm_cons = only_implies( + cpm_cons, + csemap=self._csemap, + is_supported=lambda cpm_expr: + ( + cpm_expr.name == "==" and + ( + (isinstance(cpm_expr.args[1], Operator) and cpm_expr.args[1].name in {"or", "and"}) or + isinstance(cpm_expr.args[1], _BoolVarImpl) + ) + ) or ( + cpm_expr.name == "->" and isinstance(cpm_expr.args[1], Comparison) + ) + ) # anything that can create full reif should go above... + # gurobi does not round towards zero, so no 'div' in supported set: https://github.com/CPMpy/cpmpy/pull/593#issuecomment-2786707188 - cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"sum", "wsum","->","sub","min","max","mul","abs","pow"}), csemap=self._csemap) # the core of the MIP-linearization + cpm_cons = linearize_constraint(cpm_cons, supported=frozenset({"-","sum", "wsum","->","sub","min","max","mul","abs","pow","and","or"}), csemap=self._csemap) # the core of the MIP-linearization + cpm_cons = only_positive_bv(cpm_cons, csemap=self._csemap) # after linearization, rewrite ~bv into 1-bv + # TODO don't rewrite ~p -> LinCons return cpm_cons def add(self, cpm_expr_orig): @@ -392,13 +416,113 @@ def add(self, cpm_expr_orig): :return: self """ - from gurobipy import GRB + import gurobipy as gp + + def add(cpm_expr): + """Recursively create a Gurobi constraint from a CPMpy expression.""" + + def add_(cpm_expr, depth): + indent = " " * depth + depth+=1 + self.grb_model.update() + print(f"{indent}Con:", cpm_expr, type(cpm_expr)) + + if is_num(cpm_expr): + return int(cpm_expr) + elif isinstance(cpm_expr, NegBoolView): + return 1 - self.solver_var(cpm_expr._bv) + elif isinstance(cpm_expr, _NumVarImpl): + return self.solver_var(cpm_expr) + elif isinstance(cpm_expr, Operator): + match cpm_expr.name: + case "or": # TODO only usefull if we can represent an or as a sum of binary vars, which is not generally correct + return gp.or_(add_(arg, depth) for arg in cpm_expr.args) + case "and": + return gp.and_(add_(arg, depth) for arg in cpm_expr.args) + case "->": # Gurobi indicator constraint: (Var == 0|1) >> (LinExpr sense LinExpr) + a, b = cpm_expr.args + assert isinstance(a, _BoolVarImpl), f"Implication constraint {cpm_expr} must have BoolVar as lhs, but had {a}" + # To not complicate linearize, we could see a _BoolVarImpl as consequent, which we rewrite here to a unary LinExpr + is_pos = not isinstance(a, NegBoolView) + consequent = add_(b, depth) >= 1 if isinstance(b, _BoolVarImpl) else add_(b, depth) + return (add_(a if is_pos else a._bv, depth) == int(is_pos)) >> consequent + case "not": + return 1 - add_(cpm_expr.args[0], depth) + case "-": + return -add_(cpm_expr.args[0], depth=depth) + case "sum": + return sum(add_(arg, depth) for arg in cpm_expr.args) + case "wsum": + return sum(weight * add_(arg, depth) for weight, arg in zip(cpm_expr.args[0], cpm_expr.args[1])) + case "sub": + return add_(cpm_expr.args[0], depth) - add_(cpm_expr.args[1], depth) + case "div": + assert False, "TODO" + # TODO + # if not is_num(lhs.args[1]): + # raise NotSupportedError(f"Gurobi only supports division by constants, but got {lhs.args[1]}") + elif isinstance(cpm_expr, Comparison): + a, b = add_(cpm_expr.args[0], depth), add_(cpm_expr.args[1], depth) + match cpm_expr.name: + case "==": + if isinstance(a, gp.NLExpr): + # if flattening led to a non-linear expression, then it has to be constraint `y == f(x)` with `y` a `Var` + y = b if isinstance(b, gp.Var) else self.grb_model.addVar(lb=b, ub=b) + # TODO unclear why this is no longer normalized to be a constant `b` (same below) + return y == a + else: + # Else, this is a function constraint + # Note: Gurobi functions are called by `y == f(x)`, like normalized CPMpy boolexprs, but CPMpy numexprs are normalized to `f(x) == y` (e.g. `abs(x) == IV0`), so they need to be flipped + y, fx = (a, b) if cpm_expr.args[0].is_bool() else (b, a) + y = y if isinstance(y, gp.Var) else self.grb_model.addVar(lb=y, ub=y) + return y == fx + case "<=": + return a <= b + case ">=": + return a >= b + case _: + raise Exception(f"Expected comparator to be ==,<=,>= in Comparison expression {cpm_expr}, but was {cpm_expr.name}") + elif isinstance(cpm_expr, cp.expressions.globalfunctions.GlobalFunction): + args = [add_(a, depth) for a in cpm_expr.args] + match cpm_expr.name: + case "mul": + return args[0] * args[1] + case "pow": + return args[0] ** args[1] + # remaining global function cannot be part of the expression tree + case "abs": # y = abs(x) + # TODO we could support this inside the expression tree with sqrt(pow(x,2))? + return gp.abs_(args[0]) + case "min": + return gp.min_(args) + case "max": + return gp.max_(args) + + elif isinstance(cpm_expr, DirectConstraint): + cpm_expr.callSolver(self, self.grb_model) + return True + else: + raise NotImplementedError(f"add_() not implemented for {cpm_expr}, {type(cpm_expr)}, {getattr(cpm_expr, 'name', None)}") + + grb_expr = add_(cpm_expr, 0) + if isinstance(grb_expr, (gp.Var, gp.LinExpr)): + # If add() returned a Gurobi Var (not a constraint), wrap it as >= 1 + return grb_expr >= 1 + elif isinstance(grb_expr, gp.TempConstr): + return grb_expr + else: + return self.grb_model.addVar(lb=1, ub=1) == grb_expr # add new user vars to the set get_variables(cpm_expr_orig, collect=self.user_vars) # transform and post the constraints for cpm_expr in self.transform(cpm_expr_orig): + con = add(cpm_expr) + self.grb_model.update() + print("out", con) + self.grb_model.addConstr(con) + continue # Comparisons: only numeric ones as 'only_implies()' has removed the '==' reification for Boolean expressions # numexpr `comp` bvar|const @@ -487,8 +611,11 @@ def add(self, cpm_expr_orig): elif isinstance(cpm_expr, DirectConstraint): cpm_expr.callSolver(self, self.grb_model) + elif isinstance(cpm_expr, _BoolVarImpl): + self.grb_model.addConstr(self.solver_var(cpm_expr) >= 1) + else: - raise NotImplementedError(cpm_expr) # if you reach this... please report on github + raise NotImplementedError(f"Please report unsupported constraint in Gurobi interface: {cpm_expr} of type {type(cpm_expr)}") # if you reach this... please report on github return self __add__ = add # avoid redirect in superclass diff --git a/cpmpy/transformations/flatten_model.py b/cpmpy/transformations/flatten_model.py index d8344fd8f..b7d44c25d 100644 --- a/cpmpy/transformations/flatten_model.py +++ b/cpmpy/transformations/flatten_model.py @@ -121,7 +121,7 @@ def flatten_model(orig_model, csemap=None): return cp.Model(*basecons, maximize=newobj) -def flatten_constraint(expr, csemap=None): +def flatten_constraint(expr, csemap=None, supported={}, reified=False): """ input is any expression; except is_num(), pure _NumVarImpl, or Operator/GlobalConstraint with not is_bool() @@ -156,7 +156,9 @@ def flatten_constraint(expr, csemap=None): Var -> Boolexpr (CPMpy class 'Operator', is_bool()) """ # does not type-check that arguments are bool... Could do now with expr.is_bool()! - if expr.name == 'or': + if expr.name in supported: + newlist.extend(flatten_constraint(expr.args, csemap=csemap, reified=reified)) + elif expr.name == 'or': # rewrites that avoid auxiliary var creation, should go to normalize? # in case of an implication in a disjunction, merge in if builtins.any(isinstance(a, Operator) and a.name == '->' for a in expr.args): @@ -165,7 +167,7 @@ def flatten_constraint(expr, csemap=None): if isinstance(a, Operator) and a.name == '->': newargs[i:i+1] = [~a.args[0],a.args[1]] # there could be nested implications - newlist.extend(flatten_constraint(Operator('or', newargs), csemap=csemap)) + newlist.extend(flatten_constraint(Operator('or', newargs), csemap=csemap, reified=reified)) continue # conjunctions in disjunctions could be split out by applying distributivity, # but this would explode the number of constraints in favour of having less auxiliary variables. @@ -176,30 +178,30 @@ def flatten_constraint(expr, csemap=None): if expr.args[1].name == 'and': a1s = expr.args[1].args a0 = expr.args[0] - newlist.extend(flatten_constraint([a0.implies(a1) for a1 in a1s], csemap=csemap)) + newlist.extend(flatten_constraint([a0.implies(a1) for a1 in a1s], csemap=csemap, reified=True)) continue # 2) if lhs is 'or' then or([a01..a0n])->a1 :: ~a1->and([~a01..~a0n] and split elif expr.args[0].name == 'or': a0s = expr.args[0].args a1 = expr.args[1] - newlist.extend(flatten_constraint([(~a1).implies(~a0) for a0 in a0s], csemap=csemap)) + newlist.extend(flatten_constraint([(~a1).implies(~a0) for a0 in a0s], csemap=csemap, reified=True)) continue # 2b) if lhs is ->, like 'or': a01->a02->a1 :: (~a01|a02)->a1 :: ~a1->a01,~a1->~a02 elif expr.args[0].name == '->': a01,a02 = expr.args[0].args a1 = expr.args[1] - newlist.extend(flatten_constraint([(~a1).implies(a01), (~a1).implies(~a02)], csemap=csemap)) + newlist.extend(flatten_constraint([(~a1).implies(a01), (~a1).implies(~a02)], csemap=csemap, reified=True)) continue # ->, allows a boolexpr on one side elif isinstance(expr.args[0], _BoolVarImpl): # LHS is var, ensure RHS is normalized 'Boolexpr' lhs,lcons = expr.args[0], () - rhs,rcons = normalized_boolexpr(expr.args[1], csemap=csemap) + rhs,rcons = normalized_boolexpr(expr.args[1], csemap=csemap, supported=supported, reified=True) else: # make LHS normalized 'Boolexpr', RHS must be a var - lhs,lcons = normalized_boolexpr(expr.args[0], csemap=csemap) - rhs,rcons = get_or_make_var(expr.args[1], csemap=csemap) + lhs,lcons = normalized_boolexpr(expr.args[0], csemap=csemap, supported=supported, reified=True) + rhs,rcons = get_or_make_var(expr.args[1], csemap=csemap, supported=supported, reified=True) newlist.append(Operator(expr.name, (lhs,rhs))) newlist.extend(lcons) @@ -210,7 +212,7 @@ def flatten_constraint(expr, csemap=None): # if none of the above cases + continue matched: # a normalizable boolexpr - (con, flatcons) = normalized_boolexpr(expr, csemap=csemap) + (con, flatcons) = normalized_boolexpr(expr, csemap=csemap, supported=supported, reified=reified) newlist.append(con) newlist.extend(flatcons) @@ -254,18 +256,20 @@ def flatten_constraint(expr, csemap=None): continue # ensure rhs is var - (rvar, rcons) = get_or_make_var(rexpr, csemap=csemap) + (rvar, rcons) = get_or_make_var(rexpr, csemap=csemap, reified=reified) # Reification (double implication): Boolexpr == Var # normalize the lhs (does not have to be a var, hence we call normalize instead of get_or_make_var if exprname == '==' and lexpr.is_bool(): if rvar.is_bool(): # this is a reification - (lhs, lcons) = normalized_boolexpr(lexpr, csemap=csemap) + (lhs, lcons) = normalized_boolexpr(lexpr, csemap=csemap, supported=supported, reified=reified) else: # integer comparison - (lhs, lcons) = get_or_make_var(lexpr, csemap=csemap) + (lhs, lcons) = get_or_make_var(lexpr, csemap=csemap, supported=supported, reified=reified) + elif expr.name == "!=": # TODO ; this risks making a reified LinCons with a QuadExpr + (lhs, lcons) = normalized_numexpr(lexpr, csemap=csemap, reified=reified) else: - (lhs, lcons) = normalized_numexpr(lexpr, csemap=csemap) + (lhs, lcons) = normalized_numexpr(lexpr, csemap=csemap, supported=supported, reified=reified) newlist.append(Comparison(exprname, lhs, rvar)) newlist.extend(lcons) @@ -286,7 +290,7 @@ def flatten_constraint(expr, csemap=None): return newlist -def flatten_objective(expr, supported=frozenset(["sum", "wsum"]), csemap=None): +def flatten_objective(expr, supported=frozenset(["sum", "wsum"]), csemap=None, reified=False): """ - Decision variable: Var - Linear: @@ -301,12 +305,12 @@ def flatten_objective(expr, supported=frozenset(["sum", "wsum"]), csemap=None): raise Exception(f"Objective expects a single variable/expression, not a list of expressions: {expr}") expr = simplify_boolean([expr])[0] - (flatexpr, flatcons) = normalized_numexpr(expr, csemap=csemap) # might rewrite expr into a (w)sum + (flatexpr, flatcons) = normalized_numexpr(expr, csemap=csemap, supported=supported) # might rewrite expr into a (w)sum if isinstance(flatexpr, Expression) and flatexpr.name in supported: return (flatexpr, flatcons) else: # any other numeric expression, - var, cons = get_or_make_var(flatexpr, csemap=csemap) + var, cons = get_or_make_var(flatexpr, csemap=csemap, reified=reified) return (var, cons+flatcons) @@ -323,7 +327,7 @@ def __is_flat_var_or_list(arg): is_any_list(arg) and all(__is_flat_var_or_list(el) for el in arg) or \ is_star(arg) -def get_or_make_var(expr, csemap=None): +def get_or_make_var(expr, csemap=None, supported={}, reified=False): """ Must return a variable, and list of flat normal constraints Determines whether this is a Boolean or Integer variable and returns @@ -341,7 +345,7 @@ def get_or_make_var(expr, csemap=None): if expr.is_bool(): # normalize expr into a boolexpr LHS, reify LHS == bvar - (flatexpr, flatcons) = normalized_boolexpr(expr, csemap=csemap) + (flatexpr, flatcons) = normalized_boolexpr(expr, csemap=csemap, reified=reified) if isinstance(flatexpr,_BoolVarImpl): # avoids unnecessary bv == bv or bv == ~bv assignments @@ -356,7 +360,7 @@ def get_or_make_var(expr, csemap=None): else: # normalize expr into a numexpr LHS, # then compute bounds and return (newintvar, LHS == newintvar) - (flatexpr, flatcons) = normalized_numexpr(expr, csemap=csemap) + (flatexpr, flatcons) = normalized_numexpr(expr, csemap=csemap, supported=supported, reified=reified) lb, ub = flatexpr.get_bounds() if not is_int(lb) or not is_int(ub): @@ -370,7 +374,7 @@ def get_or_make_var(expr, csemap=None): csemap[expr] = ivar return ivar, [flatexpr == ivar] + flatcons -def get_or_make_var_or_list(expr, csemap=None): +def get_or_make_var_or_list(expr, csemap=None, supported={}, reified=False): """ Like get_or_make_var() but also accepts and recursively transforms lists Used to convert arguments of globals """ @@ -378,13 +382,13 @@ def get_or_make_var_or_list(expr, csemap=None): if __is_flat_var_or_list(expr): return (expr,[]) elif is_any_list(expr): - flatvars, flatcons = zip(*[get_or_make_var(arg, csemap=csemap) for arg in expr]) + flatvars, flatcons = zip(*[get_or_make_var(arg, csemap=csemap, supported=supported, reified=reified) for arg in expr]) return (flatvars, [c for con in flatcons for c in con]) else: - return get_or_make_var(expr, csemap=csemap) + return get_or_make_var(expr, csemap=csemap, supported=supported, reified=reified) -def normalized_boolexpr(expr, csemap=None): +def normalized_boolexpr(expr, csemap=None, supported={}, reified=False): """ input is any Boolean (is_bool()) expression output are all 'flat normal form' Boolean expressions that can be 'reified', meaning that @@ -420,18 +424,18 @@ def normalized_boolexpr(expr, csemap=None): # apply De Morgan's transform for "implies" if expr.name == '->': # TODO, optimisation if args0 is an 'and'? - (lhs,lcons) = get_or_make_var(expr.args[0], csemap=csemap) + (lhs,lcons) = get_or_make_var(expr.args[0], csemap=csemap, reified=reified) # TODO, optimisation if args1 is an 'or'? - (rhs,rcons) = get_or_make_var(expr.args[1], csemap=csemap) + (rhs,rcons) = get_or_make_var(expr.args[1], csemap=csemap, reified=reified) return ((~lhs | rhs), lcons+rcons) if expr.name == 'not': - flatvar, flatcons = get_or_make_var(expr.args[0], csemap=csemap) + flatvar, flatcons = get_or_make_var(expr.args[0], csemap=csemap, reified=reified) return (~flatvar, flatcons) if not expr.has_subexpr(): return (expr, []) else: # one of the arguments is not flat, flatten all - flatvars, flatcons = zip(*[get_or_make_var(arg, csemap=csemap) for arg in expr.args]) + flatvars, flatcons = zip(*[get_or_make_var(arg, csemap=csemap, reified=reified) for arg in expr.args]) newexpr = Operator(expr.name, flatvars) return (newexpr, [c for con in flatcons for c in con]) @@ -450,19 +454,19 @@ def normalized_boolexpr(expr, csemap=None): lexpr, rexpr = rexpr, lexpr # ensure rhs is var - (rvar, rcons) = get_or_make_var(rexpr, csemap=csemap) + (rvar, rcons) = get_or_make_var(rexpr, csemap=csemap, reified=reified) # LHS: check if Boolexpr == smth: if (exprname == '==' or exprname == '!=') and lexpr.is_bool(): # this is a reified constraint, so lhs must be var too to be in normal form - (lhs, lcons) = get_or_make_var(lexpr, csemap=csemap) + (lhs, lcons) = get_or_make_var(lexpr, csemap=csemap, reified=reified) if expr.name == '!=' and rvar.is_bool(): # != not needed, negate RHS variable rvar = ~rvar exprname = '==' else: # other cases: LHS is numexpr - (lhs, lcons) = normalized_numexpr(lexpr, csemap=csemap) + (lhs, lcons) = normalized_numexpr(lexpr, csemap=csemap, supported=supported, reified=reified) return (Comparison(exprname, lhs, rvar), lcons+rcons) @@ -483,7 +487,7 @@ def normalized_boolexpr(expr, csemap=None): return (newexpr, [c for con in flatcons for c in con]) -def normalized_numexpr(expr, csemap=None): +def normalized_numexpr(expr, csemap=None, supported={}, reified=False): """ all 'flat normal form' numeric expressions... @@ -507,17 +511,17 @@ def normalized_numexpr(expr, csemap=None): elif expr.is_bool(): # unusual case, but its truth-value is a valid numexpr # so reify and return the boolvar - return get_or_make_var(expr, csemap=csemap) + return get_or_make_var(expr, csemap=csemap, reified=reified) # rewrite const*a into a weighted sum, so it can be used as objective elif expr.name == "mul" and getattr(expr, "is_lhs_num", False): w, e = expr.args - return normalized_numexpr(Operator("wsum", ([w], [e])), csemap=csemap) + return normalized_numexpr(Operator("wsum", ([w], [e])), csemap=csemap, supported=supported, reified=reified) elif isinstance(expr, Operator): # rewrite -a into a weighted sum, so it can be used as objective if expr.name == '-': - return normalized_numexpr(Operator("wsum", _wsum_make(expr)), csemap=csemap) + return normalized_numexpr(Operator("wsum", _wsum_make(expr)), csemap=csemap, supported=supported, reified=reified) if not expr.has_subexpr(): return (expr, []) @@ -529,7 +533,8 @@ def normalized_numexpr(expr, csemap=None): we = [_wsum_make(a) for a in expr.args] w = [wi for w,_ in we for wi in w] e = [ei for _,e in we for ei in e] - return normalized_numexpr(Operator("wsum", (w,e)), csemap=csemap) + return normalized_numexpr(Operator("wsum", (w,e)), csemap=csemap, supported=supported, reified=reified) + # wsum needs special handling because expr.args is a tuple of which only 2nd one has exprs if expr.name == 'wsum': @@ -550,13 +555,17 @@ def normalized_numexpr(expr, csemap=None): i = i+1 # now flatten the resulting subexprs - flatvars, flatcons = map(list, zip(*[get_or_make_var(arg, csemap=csemap) for arg in sub_exprs])) # also bool, reified... + flatvars, flatcons = map(list, zip(*[get_or_make_var(arg, csemap=csemap, reified=reified) for arg in sub_exprs])) # also bool, reified... newexpr = Operator(expr.name, (weights, flatvars)) return (newexpr, [c for con in flatcons for c in con]) else: # generic operator # recursively flatten all children - flatvars, flatcons = zip(*[get_or_make_var(arg, csemap=csemap) for arg in expr.args]) + # TODO make this some isinstance _IntVarImpl i/o hasattr? + flatvars, flatcons = zip(*( + normalized_numexpr(arg, csemap=csemap, supported=supported, reified=reified) + if not reified and hasattr(arg, "name") and arg.name in supported else + get_or_make_var(arg, csemap=csemap, reified=reified) for arg in expr.args)) newexpr = Operator(expr.name, flatvars) return (newexpr, [c for con in flatcons for c in con]) @@ -568,7 +577,11 @@ def normalized_numexpr(expr, csemap=None): return (expr, []) else: # recursively flatten all children - flatvars, flatcons = zip(*[get_or_make_var_or_list(arg, csemap=csemap) for arg in expr.args]) + flatvars, flatcons = zip(*[ + normalized_numexpr(arg, supported=supported, reified=reified) + if not reified and hasattr(arg, "name") and arg.name in supported + else get_or_make_var_or_list(arg, csemap=csemap, supported=supported) + for arg in expr.args]) # take copy, replace args newexpr = copy.copy(expr) # shallow or deep? currently shallow diff --git a/cpmpy/transformations/linearize.py b/cpmpy/transformations/linearize.py index 4a97b2b87..5424c1357 100644 --- a/cpmpy/transformations/linearize.py +++ b/cpmpy/transformations/linearize.py @@ -77,7 +77,7 @@ from ..expressions.core import Comparison, Expression, Operator, BoolVal from ..expressions.globalconstraints import GlobalConstraint, DirectConstraint, AllDifferent from ..expressions.globalfunctions import GlobalFunction, Element -from ..expressions.utils import is_bool, is_num, is_int, eval_comparison, get_bounds, is_true_cst, is_false_cst +from ..expressions.utils import is_bool, is_num, is_int, eval_comparison, get_bounds, is_true_cst, is_false_cst, is_boolexpr from ..expressions.variables import _BoolVarImpl, boolvar, NegBoolView, _NumVarImpl from .int2bool import _encode_int_var @@ -100,9 +100,10 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum","->"}, reified=Fal for cpm_expr in lst_of_expr: # Boolean literals are handled as trivial linears or unit clauses depending on `supported` if isinstance(cpm_expr, _BoolVarImpl): + # TODO gurobi specifically cannot do reified or's if "or" in supported: # post clause explicitly (don't use cp.any, which will just return the BoolVar) - newlist.append(Operator("or", [cpm_expr])) + newlist.append(cpm_expr) elif isinstance(cpm_expr, NegBoolView): # might as well remove the negation newlist.append(sum([~cpm_expr]) <= 0) @@ -128,8 +129,11 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum","->"}, reified=Fal f"calling `linearize_constraint`" if isinstance(cond, _BoolVarImpl) and isinstance(sub_expr, _BoolVarImpl): - # shortcut for BV -> BV, convert to disjunction and apply linearize on it - newlist.append(1 * cond + -1 * sub_expr <= 0) + # shortcut for BV -> BV, convert to disjunction and linearize it (if unsupported) + if "->" in supported: + newlist.append(cond.implies(cp.all(linearize_constraint([sub_expr], supported=supported, reified=True, csemap=csemap)))) + else: + newlist.append(1 * cond + -1 * sub_expr <= 0) # BV -> LinExpr elif isinstance(cond, _BoolVarImpl): @@ -186,6 +190,11 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum","->"}, reified=Fal elif isinstance(cpm_expr, Comparison): lhs, rhs = cpm_expr.args + # BV == and([a,b,c]) + if cpm_expr.name == "==" and hasattr(rhs, "name") and rhs.name in supported: + newlist.append(cpm_expr) + continue + if lhs.name == "sub": # convert to wsum lhs = Operator("wsum", [[1, -1], [lhs.args[0], lhs.args[1]]]) @@ -211,10 +220,11 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum","->"}, reified=Fal if lhs.name == "sum" and len(lhs.args) == 1 and isinstance(lhs.args[0], _BoolVarImpl) and "or" in supported: # very special case, avoid writing as sum of 1 argument new_expr = simplify_boolean([eval_comparison(cpm_expr.name,lhs.args[0], rhs)]) + new_expr = linearize_constraint(new_expr, supported=supported, reified=reified, csemap=csemap) assert len(new_expr) == 1 if isinstance(new_expr[0], BoolVal) and new_expr[0].value() is True: continue # skip or([BoolVal(True)]) - newlist.append(Operator("or", new_expr)) + newlist.extend(new_expr) continue @@ -284,6 +294,24 @@ def linearize_constraint(lst_of_expr, supported={"sum","wsum","->"}, reified=Fal return newlist +def only_positive_bv_args(args, csemap=None): + # other operators in comparison such as "min", "max" + nbv_sel = [isinstance(a, NegBoolView) for a in args] + new_cons = [] + if any(nbv_sel): + new_args = [] + for i, nbv in enumerate(nbv_sel): + if nbv: + aux = cp.boolvar() # TODO not added to CSE?? + new_args.append(aux) + new_cons += [aux + args[i]._bv == 1] # aux == 1 - arg._bv + else: + new_args.append(args[i]) + + return new_args, new_cons + else: + return args, [] + def only_positive_bv(lst_of_expr, csemap=None): """ Replaces :class:`~cpmpy.expressions.comparison.Comparison` containing :class:`~cpmpy.expressions.variables.NegBoolView` with equivalent expression using only :class:`~cpmpy.expressions.variables.BoolVar`. @@ -294,32 +322,43 @@ def only_positive_bv(lst_of_expr, csemap=None): newlist = [] for cpm_expr in lst_of_expr: - if isinstance(cpm_expr, Comparison): + if isinstance(cpm_expr, Operator) and cpm_expr.name in {"or", "and"}: + new_args, new_cons = only_positive_bv_args(cpm_expr.args) + cpm_expr_ = copy.copy(cpm_expr) + cpm_expr_.update_args(new_args) + newlist.extend([cpm_expr_] + new_cons) + elif isinstance(cpm_expr, Comparison): lhs, rhs = cpm_expr.args new_lhs = lhs - new_cons = [] - if isinstance(lhs, _NumVarImpl) or lhs.name in {"sum","wsum"}: + new_cons = [] + if (isinstance(lhs, _NumVarImpl) or lhs.name in {"sum","wsum"}) and not is_boolexpr(rhs): new_lhs, const = only_positive_bv_wsum_const(lhs) rhs -= const + elif isinstance(lhs, _BoolVarImpl): + (new_lhs,), new_cons_ = only_positive_bv_args([lhs]) + new_cons += new_cons_ + else: + new_args, new_cons_ = only_positive_bv_args(lhs.args) + new_lhs = copy.copy(lhs) + new_lhs.update_args(new_args) + new_cons += new_cons_ + + if isinstance(rhs, _BoolVarImpl): + new_args, new_cons_ = only_positive_bv_args([rhs]) + new_rhs = copy.copy(rhs) + new_rhs.update_args(new_args) + new_cons += new_cons_ + elif is_boolexpr(rhs): + new_args, new_cons_ = only_positive_bv_args(rhs.args) + new_rhs = copy.copy(rhs) + new_rhs.update_args(new_args) + new_cons += new_cons_ else: - # other operators in comparison such as "min", "max" - nbv_sel = [isinstance(a, NegBoolView) for a in lhs.args] - if any(nbv_sel): - new_args = [] - for i, nbv in enumerate(nbv_sel): - if nbv: - aux = cp.boolvar() - new_args.append(aux) - new_cons += [aux + lhs.args[i]._bv == 1] # aux == 1 - arg._bv - else: - new_args.append(lhs.args[i]) - - new_lhs = copy.copy(lhs) - new_lhs.update_args(new_args) - - if new_lhs is not lhs: - newlist.append(eval_comparison(cpm_expr.name, new_lhs, rhs)) + new_rhs = rhs + + if new_lhs is not lhs or new_rhs is not rhs: + newlist.append(eval_comparison(cpm_expr.name, new_lhs, new_rhs)) newlist += new_cons # already linear else: newlist.append(cpm_expr) @@ -333,9 +372,7 @@ def only_positive_bv(lst_of_expr, csemap=None): subexpr = only_positive_bv([subexpr], csemap=csemap) newlist += [cond.implies(expr) for expr in subexpr] - elif isinstance(cpm_expr, _BoolVarImpl): - raise ValueError(f"Unreachable: unexpected Boolean literal (`_BoolVarImpl`) in expression {cpm_expr}, perhaps `linearize_constraint` was not called before this `only_positive_bv `call") - elif isinstance(cpm_expr, (GlobalConstraint, BoolVal, DirectConstraint)): + elif isinstance(cpm_expr, (GlobalConstraint, BoolVal, DirectConstraint, _BoolVarImpl)): newlist.append(cpm_expr) else: raise Exception(f"{cpm_expr} is not linear or is not supported. Please report on github") diff --git a/cpmpy/transformations/reification.py b/cpmpy/transformations/reification.py index 2d65c14ab..efb309f39 100644 --- a/cpmpy/transformations/reification.py +++ b/cpmpy/transformations/reification.py @@ -30,6 +30,7 @@ from .negation import recurse_negation def only_bv_reifies(constraints, csemap=None): + """Transforms all reifications to ``BV -> BE`` or ``BV == BE``""" newcons = [] for cpm_expr in constraints: @@ -52,7 +53,7 @@ def only_bv_reifies(constraints, csemap=None): newcons.append(cpm_expr) return newcons -def only_implies(constraints, csemap=None): +def only_implies(constraints, csemap=None, is_supported=None): """ Transforms all reifications to ``BV -> BE`` form @@ -73,7 +74,9 @@ def only_implies(constraints, csemap=None): for cpm_expr in constraints: # Operators: check BE -> BV - if cpm_expr.name == '->' and cpm_expr.args[1].name == '==': + if is_supported and is_supported(cpm_expr): + newcons.append(cpm_expr) + elif cpm_expr.name == '->' and cpm_expr.args[1].name == '==': a0,a1 = cpm_expr.args if a1.args[0].is_bool() and a1.args[1].is_bool(): # BV0 -> BV2 == BV3 :: BV0 -> (BV2->BV3 & BV3->BV2) @@ -172,10 +175,20 @@ def reify_rewrite(constraints, supported=frozenset(), csemap=None): else: # reification, check for rewrite boolexpr = cpm_expr.args[boolexpr_index] if isinstance(boolexpr, Operator): - # Case 1, BE is Operator (and, or, ->) - # assume supported, return as is - newcons.append(cpm_expr) - # could actually rewrite into list of clauses like to_cnf() does... not for here + if boolexpr.name in supported or cpm_expr.name == "==": + # Case 1a, BE is Operator (and, or, ->) + newcons.append(cpm_expr) + # could actually rewrite into list of clauses like to_cnf() does... not for here + elif cpm_expr.name == "->": + # Case 1b, BE is an unflattened expression (TODO duplicated from below) + # We have BV -> BE, create BV -> auxvar, auxvar == BE + (auxvar, cons) = get_or_make_var(boolexpr, csemap=csemap) + newcons += cons + reifexpr = copy.copy(cpm_expr) + args = list(reifexpr.args) + args[boolexpr_index] = auxvar + reifexpr.update_args(tuple(args)) + newcons.append(reifexpr) elif isinstance(boolexpr, GlobalConstraint): # Case 2, BE is a GlobalConstraint # replace BE by its decomposition, then flatten diff --git a/tests/test_constraints.py b/tests/test_constraints.py index c733a3bb3..2089f1e9e 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -17,12 +17,13 @@ # also add exclusions to the 3 EXCLUDE_* below as needed SOLVERNAMES = [name for name, solver in SolverLookup.base_solvers() if solver.supported()] ALL_SOLS = False # test whether all solutions returned by the solver satisfy the constraint +# ALL_SOLS = True # test whether all solutions returned by the solver satisfy the constraint # Exclude some global constraints for solvers NUM_GLOBAL = { "AllEqual", "AllDifferent", "AllDifferentExcept0", "AllDifferentExceptN", "AllEqualExceptN", - "GlobalCardinalityCount", "InDomain", "Inverse","Circuit", + "GlobalCardinalityCount", "InDomain", "Inverse", "Circuit", "Table", 'NegativeTable', "ShortTable", "Regular", "Increasing", "IncreasingStrict", "Decreasing", "DecreasingStrict", "Precedence", "Cumulative", "NoOverlap", "CumulativeOptional", "NoOverlapOptional", @@ -37,7 +38,6 @@ EXCLUDE_GLOBAL = { "pysdd": NUM_GLOBAL | {"Xor"}, "minizinc": {"IncreasingStrict"}, # bug #813 reported on libminizinc - } # Exclude certain operators for solvers. @@ -286,6 +286,14 @@ def global_functions(solver): else: yield cls(NUM_ARGS) +def generate_cases(solver): + yield cp.boolvar(name="x") >= 0 # issue #736 + x, y = cp.intvar(1, 3,shape=2, name=["x", "y"]) + yield x ** 2 - 2*x*y + y**2 <= 3 + + # p, q = cp.intvar(shape=2, name=["p", "q"]) + # yield p + # yield ((cp.boolvar(name="x") >= 0) | (cp.boolvar(name="y") >= 0)) # issue #736 def reify_imply_exprs(solver): """ @@ -306,48 +314,49 @@ def reify_imply_exprs(solver): def verify(cons): - assert argval(cons) - assert cons.value() - -@pytest.mark.generate_constraints.with_args(bool_exprs) + from cpmpy.transformations.get_variables import get_variables + vars_ = get_variables(cons) + assignment = {v.name: v.value() for v in sorted(vars_, key=lambda v: v.name)} + assert argval(cons), f"argval failed for {cons} with assignment {assignment}" + assert cons.value(), f"value() failed for {cons} with assignment {assignment}" + +def all_constraints(solver): + """Combined generator for all constraint types.""" + # yield from bool_exprs(solver) + # yield from comp_constraints(solver) + # yield from reify_imply_exprs(solver) + yield from generate_cases(solver) + +@pytest.mark.generate_constraints.with_args(all_constraints) @skip_on_missing_pblib(skip_on_exception_only=True) -def test_bool_constraints(solver, constraint): +def test_constraints(solver, constraint): """ - Tests boolean constraint by posting it to the solver and checking the value after solve. + Tests constraint by posting it to the solver and checking the value after solve. """ if ALL_SOLS: - n_sols = SolverLookup.get(solver, Model(constraint)).solveAll(display=lambda: verify(constraint)) + n_sols = SolverLookup.get(solver, Model(constraint)).solveAll(display=lambda: verify(constraint), solution_limit=100) assert n_sols >= 1 else: assert SolverLookup.get(solver, Model(constraint)).solve() assert argval(constraint) assert constraint.value() -@pytest.mark.generate_constraints.with_args(comp_constraints) -@skip_on_missing_pblib(skip_on_exception_only=True) -def test_comparison_constraints(solver, constraint): - """ - Tests comparison constraint by posting it to the solver and checking the value after solve. - """ - if ALL_SOLS: - n_sols = SolverLookup.get(solver, Model(constraint)).solveAll(display= lambda: verify(constraint)) - assert n_sols >= 1 - else: - assert SolverLookup.get(solver,Model(constraint)).solve() - assert argval(constraint) - assert constraint.value() - +if __name__ == "__main__": + solver = None # Use None for no solver-specific exclusions + + generators = [ + ("Boolean expressions", bool_exprs), + ("Comparison constraints", comp_constraints), + ("Global constraints", global_constraints), + ("Global functions", global_functions), + ("Reify/imply expressions", reify_imply_exprs), + ] + + for name, gen in generators: + print(f"\n{'='*60}") + print(f"{name}") + print('='*60) + for i, expr in enumerate(gen(solver)): + model = Model(expr) + print(f"{i+1}. {model}") -@pytest.mark.generate_constraints.with_args(reify_imply_exprs) -@skip_on_missing_pblib(skip_on_exception_only=True) -def test_reify_imply_constraints(solver, constraint): - """ - Tests boolean expression by posting it to solver and checking the value after solve. - """ - if ALL_SOLS: - n_sols = SolverLookup.get(solver, Model(constraint)).solveAll(display=lambda: verify(constraint)) - assert n_sols >= 1 - else: - assert SolverLookup.get(solver, Model(constraint)).solve() - assert argval(constraint) - assert constraint.value() diff --git a/tests/test_gurobi.py b/tests/test_gurobi.py new file mode 100644 index 000000000..170749819 --- /dev/null +++ b/tests/test_gurobi.py @@ -0,0 +1,252 @@ +""" +Tests for Gurobi solver transformations and expression tree support. + +These tests verify that CPMpy correctly transforms constraints to Gurobi's +expression tree format by comparing the generated LP file output. + +https://docs.gurobi.com/projects/optimizer/en/current/features/nonlinear.html +""" + +import pytest +import tempfile +import cpmpy as cp +from cpmpy.solvers.gurobi import CPM_gurobi +from cpmpy.expressions.variables import _IntVarImpl, _BoolVarImpl + + +def get_lp_string(solver): + """Write the Gurobi model to LP format and return as string.""" + solver.grb_model.update() + with tempfile.NamedTemporaryFile(suffix=".lp", delete=False, mode="w") as f: + solver.grb_model.write(f.name) + with open(f.name) as rf: + return rf.read() + + +def extract_constraints(lp_string): + """Extract constraints from 'Subject To' and 'General Constraints' sections as a list.""" + lines = lp_string.split("\n") + in_section = False + constraints = [] + for line in lines: + if line.strip() in ("Subject To", "General Constraints"): + in_section = True + continue + if line.strip() in ("Bounds", "Binaries", "Generals", "End"): + in_section = False + if in_section and line.strip(): + constraints.append(line.strip()) + return constraints + + +def reset_counters(): + _IntVarImpl.counter = 0 + _BoolVarImpl.counter = 0 + + +def expression_tree_cases(): + for c in expression_tree_cases_(): + reset_counters() + yield c + + +def expression_tree_cases_(): + """Generator yielding (name, constraint_func, expected_lp) tuples.""" + + x, y, z = [cp.intvar(-2, 2, name=name) for name in "xyz"] + p, q, r = [cp.boolvar(name=name) for name in "pqr"] + + yield ( + "BV", + p, + ["p"], + ["R0: p >= 1"], + ) + + yield ( + "True", + cp.BoolVal(True), + ["boolval(True)"], + ["R0: C0 = 1"], + ) + + yield ( + "False", + cp.BoolVal(False), + ["boolval(False)"], + ["R0: C0 = 0"], + ) + + yield ( + "pow", + x**2 + y == 9, + ["(pow(x,2)) + (y) == 9"], + ["qc0: y + [ x ^2 ] = 9"], + ) + + """Positive implications""" + yield ( + "positive_implication", + p.implies(x + y <= 3), + ["(p) -> ((x) + (y) <= 3)"], + ["GC0: p = 1 -> x + y <= 3"], + ) + + """Negative implications""" + yield ( + "negative_implication", + (~p).implies(x + y <= 3), + ["(~p) -> ((x) + (y) <= 3)"], + ["GC0: p = 0 -> x + y <= 3"], + ) + + """While NL constraint pow can be a expression tree node, it cannot be reified""" + yield ( + "imp_quad", + p.implies(x * y == 3), + ["((x) * (y)) == (IV0)", "(p) -> (sum(IV0) == 3)"], + ["qc0: IV0 + [ - x * y ] = 0", "GC0: p = 1 -> IV0 = 3"], + ) + + yield ( + "pow_bool", + p**2 + q == 2, + ["(pow(p,2)) + (q) == 2"], + ["qc0: q + [ p ^2 ] = 2"], + ) + + yield ( + "multiplication", + z + x * y == 12, + ["(z) + ((x) * (y)) == 12"], + ["qc0: z + [ x * y ] = 12"], + ) + + yield ( + "maximum", + z + cp.Maximum([x, y]) == 12, + ["(z) + (IV0) == 12", "(max(x,y)) == (IV0)"], + ["R0: z + IV0 = 12", "GC0: IV0 = MAX ( x , y )"], + ) + + yield ( + "nested", + z + (x - 3) * ((-y) ** 2) - 3 == 12, + ["(z) + (((x) + -3) * (pow(sum([-1] * [y]),2))) == 15"], + [ + "\\ C3 = z + (sqr(y) * (-3 + x))", + "GC0: C3 = NL : ( PLUS , -1 , -1 ) ( VARIABLE , z , 0 )", + # TODO not totally clean MULTIPLY node? + "( MULTIPLY , -1 , 0 ) ( SQUARE , -1 , 2 ) ( VARIABLE , y , 3 )", + "( PLUS , -1 , 2 ) ( CONSTANT , -3 , 5 ) ( VARIABLE , x , 5 )", + ], + ) + + # # TODO needlessly reifying + # yield ( + # "subtract", + # -(x * y) == 12, + # ["(z) + ((x) * (y)) == 12"], + # ["qc0: z + [ x * y ] = 12"], + # ) + + # TODO divide (semantic may be slightly different from gurobi?) + + yield ( + "abs", + cp.Abs(x) + y == 3, + ["(IV0) + (y) == 3", "(abs(x)) == (IV0)"], + ["R0: IV0 + y = 3", "GC0: IV0 = ABS ( x )"], + ) + + """Mul is supported in expression tree, but not abs""" + yield ( + "abs_in_mul", + cp.Abs(x) * y + z == 3, + ["((IV0) * (y)) + (z) == 3", "(abs(x)) == (IV0)"], + ["qc0: z + [ IV0 * y ] = 3", "GC0: IV0 = ABS ( x )"], + ) + + yield ( + "mul_in_abs", + cp.Abs(x * y) + z == 3, + ["(IV1) + (z) == 3", "(abs(IV0)) == (IV1)", "((x) * (y)) == (IV0)"], + ["R0: IV1 + z = 3", "qc0: IV0 + [ - x * y ] = 0", "GC0: IV1 = ABS ( IV0 )"], + ) + + # TODO keep as operator? + yield ( + "minus_in", + z * (x - y) == 1, + ["(z) * (sum([1, -1] * [x, y])) == 1"], + ["qc0: [ z * x - z * y ] = 1"], # TODO not sure how it did this, but happy with it + ) + + yield ( + "minus_out", + z * -(x + y) == 1, + ["(z) * (sum([-1, -1] * [x, y])) == 1"], + ["qc0: [ - z * x - z * y ] = 1"], + ) + + yield ( + "reification", + z * (x == 2) == 1, + [ + "(z) * (BV0) == 1", + "(BV0) -> (sum(x) == 2)", + "(~BV0) -> (sum([1, -1] * [x, BV1]) <= 1)", + "(~BV0) -> (sum([1, -5] * [x, BV1]) >= -2)", + "(BV0) -> (sum([-1] * [BV1]) >= 0)", # TODO ? + ], + [ + "qc0: [ z * BV0 ] = 1", + "GC0: BV0 = 1 -> x = 2", + "GC1: BV0 = 0 -> x - BV1 <= 1", + "GC2: BV0 = 0 -> x - 5 BV1 >= -2", + "GC3: BV0 = 1 -> - BV1 >= 0", + ], + ) + + yield ( + "disjunction", + p | q, + ["(p) or (q)"], + ["GC0: C2 = OR ( p , q )"], + ) + + # yield ( + # "conjunction", + # p & q, + # ["(p) and (q)"], + # ["R0: BV0 >= 1", "GC0: BV0 = OR ( p , q )"], + # ) + + yield ( + "conjunction_in_disjunction", + (p | (q & r)), + ["(p) or (BV0)", "(BV0) == ((q) and (r))"], + ["GC0: C2 = OR ( p , BV0 )", "GC1: BV0 = AND ( q , r )"], + ) + + +@pytest.mark.requires_solver("gurobi") +@pytest.mark.parametrize( + "name,constraint,expected_tf,expected_lp", list(expression_tree_cases()), ids=[c[0] for c in expression_tree_cases()] +) +def test_gurobi_expression_tree(name, constraint, expected_tf, expected_lp): + """Test that Gurobi transformation generates expected LP output.""" + reset_counters() + solver = CPM_gurobi() + transformed = [str(c) for c in CPM_gurobi().transform(constraint)] + print("TF", ", ".join(transformed)) + reset_counters() + + solver += constraint + + lp = get_lp_string(solver) + print(lp) + + constraints = extract_constraints(lp) + assert transformed == expected_tf, f"Generated transformation:\n{transformed}" + assert constraints == expected_lp, f"Generated constraints:\n{constraints}\n\nFull LP:\n{lp}" diff --git a/tests/test_trans_linearize.py b/tests/test_trans_linearize.py index 358e1f332..0d3144cf7 100644 --- a/tests/test_trans_linearize.py +++ b/tests/test_trans_linearize.py @@ -34,7 +34,7 @@ def test_linearize(self): # implies cons = linearize_constraint([a.implies(b)])[0] - assert "sum([1, -1] * [a, b]) <= 0" == str(cons) + assert "(a) -> (b >= 1)" == str(cons) def test_bug_168(self): from cpmpy.solvers import CPM_gurobi @@ -69,8 +69,8 @@ def test_constraint(self): assert str(linearize_constraint([a | b | c])) == "[sum(a, b, c) >= 1]" assert str(linearize_constraint([a | b | (~c)])) == "[sum(a, b, ~c) >= 1]" # test implies - assert str(linearize_constraint([a.implies(b)])) == "[sum([1, -1] * [a, b]) <= 0]" - assert str(linearize_constraint([a.implies(~b)])) == "[sum([1, -1] * [a, ~b]) <= 0]" + assert str(linearize_constraint([a.implies(b)])) == "[(a) -> (b >= 1)]" + assert str(linearize_constraint([a.implies(~b)])) == "[(a) -> (b <= 0)]" assert str(linearize_constraint([a.implies(x+y+z >= 0)])) == str([]) assert str(linearize_constraint([a.implies(x+y+z >= 2)])) == "[(a) -> (sum(x, y, z) >= 2)]" assert str(linearize_constraint([a.implies(x+y+z > 0)])) == "[(a) -> (sum(x, y, z) >= 1)]" @@ -87,7 +87,7 @@ def test_constraint(self): c1, c2, c3 = linearize_constraint([a.implies(x != y)]) assert str(c1) == "(a) -> (sum([1, -1, -6] * [x, y, BV4]) <= -1)" assert str(c2) == "(a) -> (sum([1, -1, -6] * [x, y, BV4]) >= -5)" - assert str(c3) == "sum([1, -1] * [~a, ~BV4]) <= 0" + assert str(c3) == "(~a) -> (BV4 <= 0)" def test_single_boolvar(self): @@ -95,8 +95,8 @@ def test_single_boolvar(self): p = cp.boolvar(name="p") assert str([p >= 1]) == str(linearize_constraint([p])) assert str([p <= 0]) == str(linearize_constraint([~p])) - assert str([Operator("or", [p])]) == str(linearize_constraint([p], supported={"or"})) - assert str([Operator("or", [~p])]) == str(linearize_constraint([~p], supported={"or"})) + assert str([p]) == str(linearize_constraint([p], supported={"or"})) + assert str([~p]) == str(linearize_constraint([~p], supported={"or"})) def test_neq(self): # not equals is a tricky constraint to linearize, do some extra tests on it here diff --git a/tests/test_transf_reif.py b/tests/test_transf_reif.py index 6bc206d67..ebc5a7b1e 100644 --- a/tests/test_transf_reif.py +++ b/tests/test_transf_reif.py @@ -82,7 +82,7 @@ def test_reif_rewrite(self): rv = cp.boolvar(name="rv") arr = cp.cpm_array([0,1,2]) - f = lambda expr : str(reify_rewrite(flatten_constraint(expr))) + f = lambda expr : str(reify_rewrite(flatten_constraint(expr), supported={"or"})) fd = lambda expr : str(reify_rewrite(flatten_constraint(decompose_in_tree(expr))))