diff --git a/cpmpy/expressions/variables.py b/cpmpy/expressions/variables.py index ede58a4b7..cbb48f55f 100644 --- a/cpmpy/expressions/variables.py +++ b/cpmpy/expressions/variables.py @@ -68,6 +68,7 @@ _BV_PREFIX = "BV" _IV_PREFIX = "IV" _VAR_ERR = f"Variable names starting with {_IV_PREFIX} or {_BV_PREFIX} are reserved for internal use only, chose a different name" +_VAR_STRICT_NAME_CHECK = True def BoolVar(shape=1, name=None): """ @@ -790,5 +791,58 @@ def _genname(basename, idxs): return f"{basename}[{stridxs}]" # "[,,...]" def _is_invalid_name(name): - return name.startswith(_IV_PREFIX) or name.startswith(_BV_PREFIX) + """ + Check if a variable name is invalid. + + In 'strict' mode, the name is invalid if it starts with {_IV_PREFIX} or {_BV_PREFIX}. + In 'non-strict' mode, the name is invalid if it starts with {_IV_PREFIX} or {_BV_PREFIX} + and the variables' counter is greater than the index, i.e. the name is already in use. + Toggle the strict mode with `_enable_strict_variable_name_check()` and `_disable_strict_variable_name_check()`, + or use the context manager `_ignore_strict_variable_name_check()`. + """ + if name.startswith(_IV_PREFIX): + if _VAR_STRICT_NAME_CHECK: + return True + else: + id = int(name[len(_IV_PREFIX):]) + if _IntVarImpl.counter > id: + return True # TODO: better error message + else: + return False + + elif name.startswith(_BV_PREFIX): + if _VAR_STRICT_NAME_CHECK: + return True + else: + id = int(name[len(_BV_PREFIX):]) + if _BoolVarImpl.counter > id: + return True # TODO: better error message + else: + return False + + else: + return False + +def _enable_strict_variable_name_check(): + global _VAR_STRICT_NAME_CHECK + _VAR_STRICT_NAME_CHECK = True + +def _disable_strict_variable_name_check(): + global _VAR_STRICT_NAME_CHECK + _VAR_STRICT_NAME_CHECK = False + + +def _ignore_strict_variable_name_check(): + """ + Context manager to temporarily disable strict variable name check. + """ + class IgnoreStrictVariableNameCheck: + def __enter__(self): + _disable_strict_variable_name_check() + def __exit__(self, exc_type, exc_value, traceback): + _enable_strict_variable_name_check() + # _update_variable_counters() # TODO: add automatic support for this later (different PR) + return False # propagate exceptions + + return IgnoreStrictVariableNameCheck() \ No newline at end of file diff --git a/cpmpy/model.py b/cpmpy/model.py index eebaa251d..a807c1250 100644 --- a/cpmpy/model.py +++ b/cpmpy/model.py @@ -25,6 +25,7 @@ Model """ +from __future__ import annotations import copy import warnings from typing import Optional @@ -40,6 +41,33 @@ import pickle + +def _update_variable_counters(model: Model): + from cpmpy.transformations.get_variables import get_variables_model # avoid circular import + from cpmpy.expressions.variables import _BoolVarImpl, _IntVarImpl, _BV_PREFIX, _IV_PREFIX # avoid circular import + + vs = get_variables_model(model) + bv_counter = 0 + iv_counter = 0 + for v in vs: + if v.name.startswith(_BV_PREFIX): + try: + bv_counter = max(bv_counter, int(v.name[2:])+1) + except: # When name starts with _BV_PREFIX but is not a valid integer (user created name), ignore + pass + elif v.name.startswith(_IV_PREFIX): + try: + iv_counter = max(iv_counter, int(v.name[2:])+1) + except: # When name starts with _IV_PREFIX but is not a valid integer (user created name), ignore + pass + + if (_BoolVarImpl.counter > 0 and bv_counter > 0) or \ + (_IntVarImpl.counter > 0 and iv_counter > 0): + warnings.warn(f"Model contains auxiliary {_IV_PREFIX}*/{_BV_PREFIX}* variables with the same name as already created. Only add expressions created AFTER loadig this model to avoid issues with duplicate variables.") + _BoolVarImpl.counter = max(_BoolVarImpl.counter, bv_counter) + _IntVarImpl.counter = max(_IntVarImpl.counter, iv_counter) + + class Model(object): """ CPMpy Model object, contains the constraint and objective expressions @@ -284,28 +312,7 @@ def from_file(fname): with open(fname, "rb") as f: m = pickle.load(f) # bug 158, we should increase the boolvar/intvar counters to avoid duplicate names - from cpmpy.transformations.get_variables import get_variables_model # avoid circular import - from cpmpy.expressions.variables import _BoolVarImpl, _IntVarImpl, _BV_PREFIX, _IV_PREFIX # avoid circular import - vs = get_variables_model(m) - bv_counter = 0 - iv_counter = 0 - for v in vs: - if v.name.startswith(_BV_PREFIX): - try: - bv_counter = max(bv_counter, int(v.name[2:])+1) - except: - pass - elif v.name.startswith(_IV_PREFIX): - try: - iv_counter = max(iv_counter, int(v.name[2:])+1) - except: - pass - - if (_BoolVarImpl.counter > 0 and bv_counter > 0) or \ - (_IntVarImpl.counter > 0 and iv_counter > 0): - warnings.warn(f"from_file '{fname}': contains auxiliary {_IV_PREFIX}*/{_BV_PREFIX}* variables with the same name as already created. Only add expressions created AFTER loadig this model to avoid issues with duplicate variables.") - _BoolVarImpl.counter = max(_BoolVarImpl.counter, bv_counter) - _IntVarImpl.counter = max(_IntVarImpl.counter, iv_counter) + _update_variable_counters(m) return m def copy(self):