From ce4da5724c9f14d57195780a6e21ca6634f65f59 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Tue, 24 Mar 2026 15:41:24 +0100 Subject: [PATCH 01/12] feat: initial setup for making literal wrapper around atomic --- .../constraints/src/constraints/boolean.rs | 28 +- pumpkin-crates/core/src/api/solver.rs | 4 +- .../core/src/engine/cp/assignments.rs | 87 ++++++ .../core/src/engine/cp/test_solver.rs | 2 +- .../domain_event_watch_list.rs | 29 ++ .../core/src/engine/notifications/mod.rs | 99 +++++- pumpkin-crates/core/src/engine/state.rs | 2 +- .../core/src/engine/variables/literal.rs | 284 ++++++++++++------ .../core/src/propagation/constructor.rs | 14 +- .../contexts/propagation_context.rs | 14 +- .../propagators/nogoods/nogood_propagator.rs | 1 + .../arithmetic/linear_less_or_equal.rs | 8 +- pumpkin-solver-py/src/model.rs | 12 - pumpkin-solver-py/src/variables.rs | 4 - .../flatzinc/compiler/collect_domains.rs | 3 +- .../flatzinc/compiler/context.rs | 3 +- .../compiler/create_search_strategy.rs | 5 +- .../compiler/define_variable_arrays.rs | 3 +- .../flatzinc/compiler/post_constraints.rs | 11 +- 19 files changed, 468 insertions(+), 145 deletions(-) diff --git a/pumpkin-crates/constraints/src/constraints/boolean.rs b/pumpkin-crates/constraints/src/constraints/boolean.rs index 32539f39b..b9ee45186 100644 --- a/pumpkin-crates/constraints/src/constraints/boolean.rs +++ b/pumpkin-crates/constraints/src/constraints/boolean.rs @@ -67,11 +67,11 @@ impl Constraint for BooleanLessThanOrEqual { } impl BooleanLessThanOrEqual { - fn create_domains(&self) -> Vec> { + fn create_domains(&self) -> Vec> { self.bools .iter() .enumerate() - .map(|(index, bool)| bool.get_integer_variable().scaled(self.weights[index])) + .map(|(index, bool)| bool.scaled(self.weights[index])) .collect() } } @@ -85,7 +85,9 @@ struct BooleanEqual { impl Constraint for BooleanEqual { fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> { - let domains = self.create_domains(); + let (domains, rhs_domain) = self.create_domains(); + + todo!(); equals(domains, 0, self.constraint_tag).post(solver) } @@ -95,19 +97,23 @@ impl Constraint for BooleanEqual { solver: &mut Solver, reification_literal: Literal, ) -> Result<(), ConstraintOperationError> { - let domains = self.create_domains(); + let (domains, rhs_domain) = self.create_domains(); + + todo!(); equals(domains, 0, self.constraint_tag).implied_by(solver, reification_literal) } } impl BooleanEqual { - fn create_domains(&self) -> Vec> { - self.bools - .iter() - .enumerate() - .map(|(index, bool)| bool.get_integer_variable().scaled(self.weights[index])) - .chain(std::iter::once(self.rhs.scaled(-1))) - .collect() + fn create_domains(&self) -> (Vec>, AffineView) { + ( + self.bools + .iter() + .enumerate() + .map(|(index, bool)| bool.scaled(self.weights[index])) + .collect(), + self.rhs.scaled(-1), + ) } } diff --git a/pumpkin-crates/core/src/api/solver.rs b/pumpkin-crates/core/src/api/solver.rs index e5f866dd3..757cbef19 100644 --- a/pumpkin-crates/core/src/api/solver.rs +++ b/pumpkin-crates/core/src/api/solver.rs @@ -102,7 +102,7 @@ pub struct Solver { impl Default for Solver { fn default() -> Self { let satisfaction_solver = ConstraintSatisfactionSolver::default(); - let true_literal = Literal::new(Predicate::trivially_true().get_domain()); + let true_literal = Literal::new(Predicate::trivially_true()); Self { satisfaction_solver, true_literal, @@ -114,7 +114,7 @@ impl Solver { /// Creates a solver with the provided [`SolverOptions`]. pub fn with_options(solver_options: SolverOptions) -> Self { let satisfaction_solver = ConstraintSatisfactionSolver::new(solver_options); - let true_literal = Literal::new(Predicate::trivially_true().get_domain()); + let true_literal = Literal::new(Predicate::trivially_true()); Self { satisfaction_solver, true_literal, diff --git a/pumpkin-crates/core/src/engine/cp/assignments.rs b/pumpkin-crates/core/src/engine/cp/assignments.rs index d0dad662e..904d207f5 100644 --- a/pumpkin-crates/core/src/engine/cp/assignments.rs +++ b/pumpkin-crates/core/src/engine/cp/assignments.rs @@ -288,6 +288,15 @@ impl Assignments { self.is_domain_assigned(var).then(|| var.lower_bound(self)) } + pub(crate) fn get_assigned_value_at_trail_position( + &self, + var: &Var, + trail_position: usize, + ) -> Option { + self.is_domain_assigned_at_trail_position(var, trail_position) + .then(|| var.lower_bound(self)) + } + pub(crate) fn is_decision_predicate(&self, predicate: &Predicate) -> bool { let domain = predicate.get_domain(); if let Some(trail_position) = self.get_trail_position(predicate) @@ -368,6 +377,15 @@ impl Assignments { var.lower_bound(self) == var.upper_bound(self) } + pub(crate) fn is_domain_assigned_at_trail_position( + &self, + var: &Var, + trail_position: usize, + ) -> bool { + var.lower_bound_at_trail_position(self, trail_position) + == var.upper_bound_at_trail_position(self, trail_position) + } + /// Returns the index of the trail entry at which point the given predicate became true. /// In case the predicate is not true, then the function returns None. /// Note that it is not necessary for the predicate to be explicitly present on the trail, @@ -617,6 +635,66 @@ impl Assignments { Ok(update_took_place) } + /// Determines whether the provided [`Predicate`] holds at the provided trail position. In case + /// the predicate is not assigned yet (neither true nor false), returns None. + pub(crate) fn evaluate_predicate_at_trail_position( + &self, + predicate: Predicate, + trail_position: usize, + ) -> Option { + let domain_id = predicate.get_domain(); + let value = predicate.get_right_hand_side(); + + match predicate.get_predicate_type() { + PredicateType::LowerBound => { + if self.get_lower_bound_at_trail_position(domain_id, trail_position) >= value { + Some(true) + } else if self.get_upper_bound_at_trail_position(domain_id, trail_position) < value + { + Some(false) + } else { + None + } + } + PredicateType::UpperBound => { + if self.get_upper_bound_at_trail_position(domain_id, trail_position) <= value { + Some(true) + } else if self.get_lower_bound_at_trail_position(domain_id, trail_position) > value + { + Some(false) + } else { + None + } + } + PredicateType::NotEqual => { + if !self.is_value_in_domain_at_trail_position(domain_id, value, trail_position) { + Some(true) + } else if let Some(assigned_value) = + self.get_assigned_value_at_trail_position(&domain_id, trail_position) + { + // Previous branch concluded the value is not in the domain, so if the variable + // is assigned, then it is assigned to the not equals value. + pumpkin_assert_simple!(assigned_value == value); + Some(false) + } else { + None + } + } + PredicateType::Equal => { + if !self.is_value_in_domain_at_trail_position(domain_id, value, trail_position) { + Some(false) + } else if let Some(assigned_value) = + self.get_assigned_value_at_trail_position(&domain_id, trail_position) + { + pumpkin_assert_moderate!(assigned_value == value); + Some(true) + } else { + None + } + } + } + } + /// Determines whether the provided [`Predicate`] holds in the current state of the /// [`Assignments`]. In case the predicate is not assigned yet (neither true nor false), /// returns None. @@ -673,6 +751,15 @@ impl Assignments { .is_some_and(|truth_value| truth_value) } + pub(crate) fn is_predicate_satisfied_at_trail_position( + &self, + predicate: Predicate, + trail_position: usize, + ) -> bool { + self.evaluate_predicate_at_trail_position(predicate, trail_position) + .is_some_and(|truth_value| truth_value) + } + #[allow(unused, reason = "makes sense to have in this API")] pub(crate) fn is_predicate_falsified(&self, predicate: Predicate) -> bool { self.evaluate_predicate(predicate) diff --git a/pumpkin-crates/core/src/engine/cp/test_solver.rs b/pumpkin-crates/core/src/engine/cp/test_solver.rs index 4bbcdf795..368c0208b 100644 --- a/pumpkin-crates/core/src/engine/cp/test_solver.rs +++ b/pumpkin-crates/core/src/engine/cp/test_solver.rs @@ -86,7 +86,7 @@ impl TestSolver { pub fn new_literal(&mut self) -> Literal { let domain_id = self.new_variable(0, 1); - Literal::new(domain_id) + Literal::new(predicate!(domain_id >= 1)) } pub fn new_propagator( diff --git a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs index 9d3220c2c..ff4dd07b8 100644 --- a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs +++ b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs @@ -3,10 +3,15 @@ use std::fmt::Display; use enumset::EnumSet; use enumset::EnumSetType; +use crate::basic_types::PredicateId; use crate::containers::KeyedVec; +use crate::engine::Assignments; +use crate::engine::TrailedValues; use crate::engine::notifications::NotificationEngine; use crate::engine::variables::DomainId; +use crate::predicates::Predicate; use crate::propagation::PropagatorVarId; +use crate::variables::Literal; #[derive(Default, Debug, Clone)] pub(crate) struct WatchListDomainEvents { @@ -22,6 +27,26 @@ pub(crate) struct WatchListDomainEvents { pub struct Watchers<'a> { propagator_var: PropagatorVarId, notification_engine: &'a mut NotificationEngine, + trailed_values: &'a mut TrailedValues, + assignments: &'a Assignments, +} + +impl<'a> Watchers<'a> { + pub(crate) fn watch_literal(&mut self, literal: Literal, events: EnumSet) { + self.notification_engine.watch_literal( + literal, + events, + self.propagator_var, + self.trailed_values, + self.assignments, + ) + } + + pub(crate) fn unwatch_predicate(&mut self, predicate: Predicate) { + let predicate_id = self.notification_engine.get_id(predicate); + self.notification_engine + .unwatch_predicate(predicate_id, self.propagator_var.propagator); + } } /// A description of the kinds of events that can happen on a domain variable. @@ -95,10 +120,14 @@ impl<'a> Watchers<'a> { pub(crate) fn new( propagator_var: PropagatorVarId, notification_engine: &'a mut NotificationEngine, + trailed_values: &'a mut TrailedValues, + assignments: &'a Assignments, ) -> Self { Watchers { propagator_var, notification_engine, + trailed_values, + assignments, } } diff --git a/pumpkin-crates/core/src/engine/notifications/mod.rs b/pumpkin-crates/core/src/engine/notifications/mod.rs index ba6a5e64e..1c1e7ba92 100644 --- a/pumpkin-crates/core/src/engine/notifications/mod.rs +++ b/pumpkin-crates/core/src/engine/notifications/mod.rs @@ -11,7 +11,9 @@ use enumset::EnumSet; pub(crate) use predicate_notification::PredicateNotifier; use crate::basic_types::PredicateId; +use crate::containers::HashMap; use crate::containers::KeyedVec; +use crate::containers::StorageKey; use crate::engine::Assignments; use crate::engine::PropagatorQueue; use crate::engine::TrailedValues; @@ -26,6 +28,7 @@ use crate::propagation::store::PropagatorStore; use crate::pumpkin_assert_extreme; use crate::pumpkin_assert_simple; use crate::variables::DomainId; +use crate::variables::Literal; #[derive(Debug, Clone)] pub(crate) struct NotificationEngine { @@ -38,6 +41,8 @@ pub(crate) struct NotificationEngine { watch_list_domain_events: WatchListDomainEvents, /// The watch list from predicates to propagators. pub(crate) watch_list_predicate_id: KeyedVec>, + // TODO: Should use direct hashing + pub(crate) literal_watch_list: HashMap)>, /// Events which have occurred since the last round of notifications have taken place events: EventSink, /// Backtrack events which have occurred since the last of backtrack notifications have taken @@ -50,6 +55,7 @@ impl Default for NotificationEngine { let mut result = Self { watch_list_domain_events: Default::default(), watch_list_predicate_id: Default::default(), + literal_watch_list: Default::default(), predicate_notifier: Default::default(), last_notified_trail_index: 0, events: Default::default(), @@ -194,6 +200,57 @@ impl NotificationEngine { // TODO: Can we remove the predicate from being tracked if it does not have watchers? } + pub(crate) fn watch_literal( + &mut self, + literal: Literal, + events: EnumSet, + propagator_var: PropagatorVarId, + trailed_values: &mut TrailedValues, + assignments: &Assignments, + ) { + let entry = self + .literal_watch_list + .entry(literal) + .or_insert((propagator_var.variable, events)); + entry.1 |= events; + + for event in events { + match event { + DomainEvent::Assign => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::LowerBound => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::UpperBound => { + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::Removal => {} + }; + } + } + pub(crate) fn watch_all_backtrack( &mut self, domain: DomainId, @@ -388,19 +445,43 @@ impl NotificationEngine { trailed_values: &mut TrailedValues, assignments: &Assignments, ) { - for predicate_id in self.predicate_notifier.drain_satisfied_predicates() { + for predicate_id in self + .predicate_notifier + .drain_satisfied_predicates() + .collect::>() + { if let Some(watch_list) = self.watch_list_predicate_id.get(predicate_id) { let propagators_to_notify = watch_list.iter().copied(); for propagator_id in propagators_to_notify { - let mut context = NotificationContext::new(trailed_values, assignments); - - let propagator = &mut propagators[propagator_id]; - let enqueue_decision = - propagator.notify_predicate_id_satisfied(context.reborrow(), predicate_id); - - if enqueue_decision == EnqueueDecision::Enqueue { - propagator_queue.enqueue_propagator(propagator_id, propagator.priority()); + let predicate = self.predicate_notifier.get_predicate(predicate_id); + let literal = Literal::new(predicate); + if let Some((var_id, events)) = self.literal_watch_list.get(&literal) + && !events.is_empty() + { + let propagator = &mut propagators[propagator_id]; + for event in events.iter() { + let mut context = NotificationContext::new(trailed_values, assignments); + + let enqueue_decision = + propagator.notify(context.reborrow(), *var_id, event.into()); + + if enqueue_decision == EnqueueDecision::Enqueue { + propagator_queue + .enqueue_propagator(propagator_id, propagator.priority()); + } + } + } else { + let mut context = NotificationContext::new(trailed_values, assignments); + + let propagator = &mut propagators[propagator_id]; + let enqueue_decision = propagator + .notify_predicate_id_satisfied(context.reborrow(), predicate_id); + + if enqueue_decision == EnqueueDecision::Enqueue { + propagator_queue + .enqueue_propagator(propagator_id, propagator.priority()); + } } } } diff --git a/pumpkin-crates/core/src/engine/state.rs b/pumpkin-crates/core/src/engine/state.rs index 711561a7d..a468b965a 100644 --- a/pumpkin-crates/core/src/engine/state.rs +++ b/pumpkin-crates/core/src/engine/state.rs @@ -237,7 +237,7 @@ impl State { /// when backtracking past the checkpoint where the domain was created. pub fn new_literal(&mut self, name: Option>) -> Literal { let domain_id = self.new_interval_variable(0, 1, name); - Literal::new(domain_id) + Literal::new(predicate!(domain_id >= 1)) } /// Creates a new interval variable with the given lower and upper bound. diff --git a/pumpkin-crates/core/src/engine/variables/literal.rs b/pumpkin-crates/core/src/engine/variables/literal.rs index 980ffa737..c75e5bfbe 100644 --- a/pumpkin-crates/core/src/engine/variables/literal.rs +++ b/pumpkin-crates/core/src/engine/variables/literal.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use std::ops::Not; use enumset::EnumSet; @@ -16,31 +17,31 @@ use crate::engine::predicates::predicate::Predicate; use crate::engine::predicates::predicate_constructor::PredicateConstructor; use crate::engine::variables::AffineView; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct Literal { - integer_variable: AffineView, + pub(crate) inner: Predicate, +} + +impl Debug for Literal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.inner) + } } impl Literal { /// Creates a new literal wrapping the provided [`DomainId`]. /// /// Note: the provided `domain_id` should have a domain between 0 and 1. - pub fn new(domain_id: DomainId) -> Literal { - Literal { - integer_variable: domain_id.scaled(1), - } - } - - pub fn get_integer_variable(&self) -> AffineView { - self.integer_variable + pub fn new(predicate: Predicate) -> Literal { + Literal { inner: predicate } } pub fn get_true_predicate(&self) -> Predicate { - self.lower_bound_predicate(1) + self.inner } pub fn get_false_predicate(&self) -> Predicate { - self.upper_bound_predicate(0) + !self.inner } } @@ -48,60 +49,119 @@ impl Not for Literal { type Output = Literal; fn not(self) -> Self::Output { - Literal { - integer_variable: self.integer_variable.scaled(-1).offset(1), - } + Literal { inner: !self.inner } } } -/// Forwards a function implementation to the field on self. -macro_rules! forward { - ( - $field:ident, - fn $(<$($lt:lifetime),+>)? $name:ident( - & $($lt_self:lifetime)? self, - $($param_name:ident : $param_type:ty),* - ) -> $return_type:ty - $(where $($where_clause:tt)*)? - ) => { - fn $name$(<$($lt),+>)?( - & $($lt_self)? self, - $($param_name: $param_type),* - ) -> $return_type $(where $($where_clause)*)? { - self.$field.$name($($param_name),*) +impl CheckerVariable for Literal { + fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool { + atomic.get_domain() == self.inner.get_domain() + } + + fn atomic_less_than(&self, value: i32) -> Predicate { + if value == 0 { + !self.inner + } else if value >= 1 { + Predicate::trivially_true() + } else { + Predicate::trivially_false() } } -} -impl CheckerVariable for Literal { - forward!(integer_variable, fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool); - forward!(integer_variable, fn atomic_less_than(&self, value: i32) -> Predicate); - forward!(integer_variable, fn atomic_greater_than(&self, value: i32) -> Predicate); - forward!(integer_variable, fn atomic_not_equal(&self, value: i32) -> Predicate); - forward!(integer_variable, fn atomic_equal(&self, value: i32) -> Predicate); - - forward!(integer_variable, fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt); - forward!(integer_variable, fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt); - forward!(integer_variable, fn induced_fixed_value(&self, variable_state: &VariableState) -> Option); - forward!(integer_variable, fn induced_domain_contains(&self, variable_state: &VariableState, value: i32) -> bool); - forward!( - integer_variable, - fn <'this, 'state> induced_holes( - &'this self, - variable_state: &'state VariableState - ) -> impl Iterator + 'state - where - 'this: 'state, - ); - forward!( - integer_variable, - fn <'this, 'state> iter_induced_domain( - &'this self, - variable_state: &'state VariableState - ) -> Option + 'state> - where - 'this: 'state, - ); + fn atomic_greater_than(&self, value: i32) -> Predicate { + if value == 1 { + !self.inner + } else if value <= 0 { + Predicate::trivially_true() + } else { + Predicate::trivially_false() + } + } + + fn atomic_equal(&self, value: i32) -> Predicate { + if value == 1 { + self.inner + } else if value == 0 { + !self.inner + } else { + Predicate::trivially_false() + } + } + + fn atomic_not_equal(&self, value: i32) -> Predicate { + if value == 1 { + !self.inner + } else if value == 0 { + self.inner + } else { + Predicate::trivially_true() + } + } + + fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt { + IntExt::Int( + variable_state + .is_true(&self.inner) + .then(|| 1) + .unwrap_or_default(), + ) + } + + fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt { + IntExt::Int(variable_state.is_true(&!self.inner).then(|| 0).unwrap_or(1)) + } + + fn induced_fixed_value(&self, variable_state: &VariableState) -> Option { + if variable_state.is_true(&self.inner) { + Some(1) + } else if variable_state.is_true(&!self.inner) { + Some(0) + } else { + None + } + } + + fn induced_domain_contains( + &self, + variable_state: &VariableState, + value: i32, + ) -> bool { + if value == 1 && !variable_state.is_true(&!self.inner) { + true + } else if value == 0 && !variable_state.is_true(&self.inner) { + true + } else { + false + } + } + + fn induced_holes<'this, 'state>( + &'this self, + _variable_state: &'state VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state, + { + std::iter::empty() + } + + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state VariableState, + ) -> Option + 'state> + where + 'this: 'state, + { + Some((0..=1).filter(|&value| { + if value == 0 { + !variable_state.is_true(&self.inner) + } else if value == 1 { + !variable_state.is_true(&!self.inner) + } else { + unreachable!() + } + })) + } } impl IntegerVariable for Literal { @@ -112,7 +172,11 @@ impl IntegerVariable for Literal { /// Literal that evaluate to false have a lower bound of 0. /// Unassigned literals have a lower bound of 0. fn lower_bound(&self, assignment: &Assignments) -> i32 { - self.integer_variable.lower_bound(assignment) + if assignment.is_predicate_satisfied(self.inner) { + 1 + } else { + 0 + } } fn lower_bound_at_trail_position( @@ -120,8 +184,11 @@ impl IntegerVariable for Literal { assignment: &Assignments, trail_position: usize, ) -> i32 { - self.integer_variable - .lower_bound_at_trail_position(assignment, trail_position) + if assignment.is_predicate_satisfied_at_trail_position(self.inner, trail_position) { + 1 + } else { + 0 + } } /// Returns the upper bound represented as a 0-1 value. @@ -129,7 +196,11 @@ impl IntegerVariable for Literal { /// Literal that evaluate to false have a upper bound of 0. /// Unassigned literals have a upper bound of 1. fn upper_bound(&self, assignment: &Assignments) -> i32 { - self.integer_variable.upper_bound(assignment) + if assignment.is_predicate_satisfied(!self.inner) { + 0 + } else { + 1 + } } fn upper_bound_at_trail_position( @@ -137,8 +208,11 @@ impl IntegerVariable for Literal { assignment: &Assignments, trail_position: usize, ) -> i32 { - self.integer_variable - .upper_bound_at_trail_position(assignment, trail_position) + if assignment.is_predicate_satisfied_at_trail_position(!self.inner, trail_position) { + 0 + } else { + 1 + } } /// Returns whether the input value, when interpreted as a bool, @@ -147,7 +221,13 @@ impl IntegerVariable for Literal { /// Literals that evaluate to false only contain value 0. /// Unassigned literals contain both values 0 and 1. fn contains(&self, assignment: &Assignments, value: i32) -> bool { - self.integer_variable.contains(assignment, value) + if value == 0 { + !assignment.is_predicate_satisfied(self.inner) + } else if value == 1 { + !assignment.is_predicate_satisfied(!self.inner) + } else { + false + } } fn contains_at_trail_position( @@ -156,40 +236,52 @@ impl IntegerVariable for Literal { value: i32, trail_position: usize, ) -> bool { - self.integer_variable - .contains_at_trail_position(assignment, value, trail_position) + if value == 0 { + !assignment.is_predicate_satisfied_at_trail_position(self.inner, trail_position) + } else if value == 1 { + !assignment.is_predicate_satisfied_at_trail_position(!self.inner, trail_position) + } else { + false + } } fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator { - self.integer_variable.iterate_domain(assignment) + (0..=1).filter(|&value| { + if value == 0 { + !assignment.is_predicate_satisfied(self.inner) + } else if value == 1 { + !assignment.is_predicate_satisfied(!self.inner) + } else { + unreachable!() + } + }) } fn watch_all(&self, watchers: &mut Watchers<'_>, events: EnumSet) { - self.integer_variable.watch_all(watchers, events) + watchers.watch_literal(*self, events) } fn unwatch_all(&self, watchers: &mut Watchers<'_>) { - self.integer_variable.unwatch_all(watchers) + watchers.unwatch_predicate(self.inner); } fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent { - self.integer_variable.unpack_event(event) + event.unwrap() } - fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, events: EnumSet) { - self.integer_variable.watch_all_backtrack(watchers, events) + fn watch_all_backtrack(&self, _watchers: &mut Watchers<'_>, _events: EnumSet) { + todo!() } fn get_holes_at_current_checkpoint( &self, - assignments: &Assignments, + _assignments: &Assignments, ) -> impl Iterator { - self.integer_variable - .get_holes_at_current_checkpoint(assignments) + std::iter::empty() } - fn get_holes(&self, assignments: &Assignments) -> impl Iterator { - self.integer_variable.get_holes(assignments) + fn get_holes(&self, _assignments: &Assignments) -> impl Iterator { + std::iter::empty() } } @@ -197,19 +289,43 @@ impl PredicateConstructor for Literal { type Value = i32; fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate { - self.integer_variable.lower_bound_predicate(bound) + if bound == 1 { + self.inner + } else if bound < 1 { + Predicate::trivially_true() + } else { + Predicate::trivially_false() + } } fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate { - self.integer_variable.upper_bound_predicate(bound) + if bound == 0 { + !self.inner + } else if bound > 0 { + Predicate::trivially_true() + } else { + Predicate::trivially_false() + } } fn equality_predicate(&self, bound: Self::Value) -> Predicate { - self.integer_variable.equality_predicate(bound) + if bound == 0 { + !self.inner + } else if bound == 1 { + self.inner + } else { + Predicate::trivially_true() + } } fn disequality_predicate(&self, bound: Self::Value) -> Predicate { - self.integer_variable.disequality_predicate(bound) + if bound == 0 { + self.inner + } else if bound == 1 { + !self.inner + } else { + Predicate::trivially_true() + } } } diff --git a/pumpkin-crates/core/src/propagation/constructor.rs b/pumpkin-crates/core/src/propagation/constructor.rs index 791627880..bf3100bc7 100644 --- a/pumpkin-crates/core/src/propagation/constructor.rs +++ b/pumpkin-crates/core/src/propagation/constructor.rs @@ -164,7 +164,12 @@ impl PropagatorConstructorContext<'_> { self.update_next_local_id(local_id); - let mut watchers = Watchers::new(propagator_var, &mut self.state.notification_engine); + let mut watchers = Watchers::new( + propagator_var, + &mut self.state.notification_engine, + &mut self.state.trailed_values, + &self.state.assignments, + ); var.watch_all(&mut watchers, domain_events.events()); } @@ -207,7 +212,12 @@ impl PropagatorConstructorContext<'_> { self.update_next_local_id(local_id); - let mut watchers = Watchers::new(propagator_var, &mut self.state.notification_engine); + let mut watchers = Watchers::new( + propagator_var, + &mut self.state.notification_engine, + &mut self.state.trailed_values, + &self.state.assignments, + ); var.watch_all_backtrack(&mut watchers, domain_events.events()); } diff --git a/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs b/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs index 5c9d67a48..374b0ebab 100644 --- a/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs +++ b/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs @@ -152,7 +152,12 @@ impl<'a> PropagationContext<'a> { variable: local_id, }; - let mut watchers = Watchers::new(propagator_var, self.notification_engine); + let mut watchers = Watchers::new( + propagator_var, + self.notification_engine, + self.trailed_values, + self.assignments, + ); var.watch_all(&mut watchers, domain_events.events()); } @@ -163,7 +168,12 @@ impl<'a> PropagationContext<'a> { variable: local_id, }; - let mut watchers = Watchers::new(propagator_var, self.notification_engine); + let mut watchers = Watchers::new( + propagator_var, + self.notification_engine, + self.trailed_values, + self.assignments, + ); var.unwatch_all(&mut watchers); } diff --git a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs index 4409d4b7a..caff9253a 100644 --- a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs +++ b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs @@ -317,6 +317,7 @@ impl Propagator for NogoodPropagator { let reason = Reason::DynamicLazy(watcher.nogood_id.id as u64); let predicate = !context.get_predicate(nogood_predicates[0]); + let result = context.post( predicate, reason, diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs index e1e616574..8f5a71b6f 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs @@ -144,10 +144,9 @@ where let old_bound = context.read_trailed_integer(self.current_bounds[index]); let new_bound = context.lower_bound(x_i) as i64; - pumpkin_assert_simple!( - old_bound < new_bound, - "propagator should only be triggered when lower bounds are tightened, old_bound={old_bound}, new_bound={new_bound}" - ); + if old_bound == new_bound { + return EnqueueDecision::Skip; + } context.write_trailed_integer( self.lower_bound_left_hand_side, @@ -220,7 +219,6 @@ where for (i, x_i) in self.x.iter().enumerate() { let bound = self.c - (lower_bound_left_hand_side - context.lower_bound(x_i)); - if context.upper_bound(x_i) > bound { context.post(predicate![x_i <= bound], i, &self.inference_code)?; } diff --git a/pumpkin-solver-py/src/model.rs b/pumpkin-solver-py/src/model.rs index e35b41880..cbf2b35b3 100644 --- a/pumpkin-solver-py/src/model.rs +++ b/pumpkin-solver-py/src/model.rs @@ -156,15 +156,6 @@ impl Model { self.solver.is_inconsistent() } - /// Get an integer variable for the given boolean. - /// - /// The integer is 1 if the boolean is `true`, and 0 if the boolean is `false`. - /// - /// This is deprecated as of 0.3.0. Prefer to use `BoolExpression.as_integer`. - fn boolean_as_integer(&mut self, boolean: BoolExpression) -> IntExpression { - boolean.as_integer() - } - /// Reify a predicate as an explicit boolean expression. /// /// A tag should be provided for this link to be identifiable in the proof. @@ -191,9 +182,6 @@ impl Model { self.solver.new_literal_for_predicate(solver_predicate, tag) }; - self.brancher - .add_domain(*literal.get_integer_variable().inner()); - Ok(literal.into()) } diff --git a/pumpkin-solver-py/src/variables.rs b/pumpkin-solver-py/src/variables.rs index 6e9e81527..ee52118b8 100644 --- a/pumpkin-solver-py/src/variables.rs +++ b/pumpkin-solver-py/src/variables.rs @@ -32,10 +32,6 @@ pub struct BoolExpression(pub Literal); #[pymethods] impl BoolExpression { - pub fn as_integer(&self) -> IntExpression { - IntExpression(self.0.get_integer_variable()) - } - pub fn negate(&self) -> BoolExpression { BoolExpression(!self.0) } diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs index 46b2ac0f4..105dc579c 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs @@ -5,6 +5,7 @@ use std::rc::Rc; use flatzinc::Annotation; use pumpkin_core::Solver; use pumpkin_core::containers::HashMap; +use pumpkin_core::predicate; use pumpkin_core::variables::DomainId; use pumpkin_solver::core::variables::Literal; @@ -39,7 +40,7 @@ pub(crate) fn run( ) }); - let literal = Literal::new(domain_id); + let literal = Literal::new(predicate!(domain_id >= 1)); if is_output_variable(annos) { context.outputs.push(Output::bool(id, literal)); diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs index 8cc79b4a5..8a4803198 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs @@ -4,6 +4,7 @@ use std::collections::BTreeSet; use std::rc::Rc; use log::warn; +use pumpkin_core::predicate; use pumpkin_solver::Solver; use pumpkin_solver::core::containers::HashMap; use pumpkin_solver::core::containers::HashSet; @@ -139,7 +140,7 @@ impl CompilationContext<'_> { .variable_map .get(&self.equivalences.representative(identifier)) { - Ok(Literal::new(*domain_id)) + Ok(Literal::new(predicate!(domain_id >= 1))) } else { self.boolean_parameters .get(&self.equivalences.representative(identifier)) diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs index 52d1fc609..04a2ec2a2 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs @@ -146,10 +146,7 @@ fn create_from_search_strategy( } AnnExpr::Expr(expr) => { let bool_variable_array = context - .resolve_bool_variable_array(expr)? - .iter() - .map(|literal| literal.get_integer_variable()) - .collect::>(); + .resolve_bool_variable_array(expr)?; match values { AnnExpr::Expr(expr) => { diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_variable_arrays.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_variable_arrays.rs index 2b023d80d..cda074f82 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_variable_arrays.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_variable_arrays.rs @@ -11,6 +11,7 @@ use flatzinc::Expr; use flatzinc::IntExpr; use flatzinc::SetExpr; use flatzinc::SetLiteralExpr; +use pumpkin_core::predicate; use pumpkin_core::variables::Literal; use super::context::CompilationContext; @@ -49,7 +50,7 @@ pub(crate) fn run( .copied() .expect("referencing undefined boolean variable"); - Literal::new(domain_id) + Literal::new(predicate!(domain_id >= 1)) } }) .collect(), diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs index e72963c14..40ae0d86e 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs @@ -648,11 +648,12 @@ fn compile_bool2int( let a = context.resolve_bool_variable(&exprs[0])?; let b = context.resolve_integer_variable(&exprs[1])?; - Ok( - pumpkin_constraints::binary_equals(a.get_integer_variable(), b.scaled(1), constraint_tag) - .post(context.solver) - .is_ok(), - ) + todo!() + // Ok( + // pumpkin_constraints::binary_equals(a, b.scaled(1), constraint_tag) + // .post(context.solver) + // .is_ok(), + // ) } fn compile_bool_or( From 2363a7e4d8156610ec30e4311c44aef1cf895db4 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Thu, 26 Mar 2026 14:30:23 +0100 Subject: [PATCH 02/12] chore: cleanup --- .../constraints/src/constraints/boolean.rs | 19 +++-- .../domain_event_watch_list.rs | 15 +++- .../core/src/engine/notifications/mod.rs | 73 ++++++++++++++++++- .../core/src/engine/variables/literal.rs | 28 +++---- .../arithmetic/linear_less_or_equal.rs | 1 - .../flatzinc/compiler/post_constraints.rs | 6 +- 6 files changed, 112 insertions(+), 30 deletions(-) diff --git a/pumpkin-crates/constraints/src/constraints/boolean.rs b/pumpkin-crates/constraints/src/constraints/boolean.rs index b9ee45186..990b12354 100644 --- a/pumpkin-crates/constraints/src/constraints/boolean.rs +++ b/pumpkin-crates/constraints/src/constraints/boolean.rs @@ -7,7 +7,6 @@ use pumpkin_core::variables::DomainId; use pumpkin_core::variables::Literal; use pumpkin_core::variables::TransformableVariable; -use super::equals; use super::less_than_or_equals; /// Creates the [`Constraint`] `∑ weights_i * bools_i <= rhs`. @@ -36,7 +35,7 @@ pub fn boolean_equals( weights: weights.into(), bools: bools.into(), rhs, - constraint_tag, + _constraint_tag: constraint_tag, } } @@ -80,28 +79,28 @@ struct BooleanEqual { weights: Box<[i32]>, bools: Box<[Literal]>, rhs: DomainId, - constraint_tag: ConstraintTag, + _constraint_tag: ConstraintTag, } impl Constraint for BooleanEqual { - fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> { - let (domains, rhs_domain) = self.create_domains(); + fn post(self, _solver: &mut Solver) -> Result<(), ConstraintOperationError> { + let (_domains, _rhs_domain) = self.create_domains(); todo!(); - equals(domains, 0, self.constraint_tag).post(solver) + // equals(domains, 0, self.constraint_tag).post(solver) } fn implied_by( self, - solver: &mut Solver, - reification_literal: Literal, + _solver: &mut Solver, + _reification_literal: Literal, ) -> Result<(), ConstraintOperationError> { - let (domains, rhs_domain) = self.create_domains(); + let (_domains, _rhs_domain) = self.create_domains(); todo!(); - equals(domains, 0, self.constraint_tag).implied_by(solver, reification_literal) + // equals(domains, 0, self.constraint_tag).implied_by(solver, reification_literal) } } diff --git a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs index ff4dd07b8..b3a95e227 100644 --- a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs +++ b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs @@ -3,7 +3,6 @@ use std::fmt::Display; use enumset::EnumSet; use enumset::EnumSetType; -use crate::basic_types::PredicateId; use crate::containers::KeyedVec; use crate::engine::Assignments; use crate::engine::TrailedValues; @@ -42,6 +41,20 @@ impl<'a> Watchers<'a> { ) } + pub(crate) fn watch_literal_backtrack( + &mut self, + literal: Literal, + events: EnumSet, + ) { + self.notification_engine.watch_literal_backtrack( + literal, + events, + self.propagator_var, + self.trailed_values, + self.assignments, + ) + } + pub(crate) fn unwatch_predicate(&mut self, predicate: Predicate) { let predicate_id = self.notification_engine.get_id(predicate); self.notification_engine diff --git a/pumpkin-crates/core/src/engine/notifications/mod.rs b/pumpkin-crates/core/src/engine/notifications/mod.rs index 1c1e7ba92..e41717452 100644 --- a/pumpkin-crates/core/src/engine/notifications/mod.rs +++ b/pumpkin-crates/core/src/engine/notifications/mod.rs @@ -13,7 +13,6 @@ pub(crate) use predicate_notification::PredicateNotifier; use crate::basic_types::PredicateId; use crate::containers::HashMap; use crate::containers::KeyedVec; -use crate::containers::StorageKey; use crate::engine::Assignments; use crate::engine::PropagatorQueue; use crate::engine::TrailedValues; @@ -43,11 +42,13 @@ pub(crate) struct NotificationEngine { pub(crate) watch_list_predicate_id: KeyedVec>, // TODO: Should use direct hashing pub(crate) literal_watch_list: HashMap)>, + pub(crate) literal_watch_list_backtrack: HashMap)>, /// Events which have occurred since the last round of notifications have taken place events: EventSink, /// Backtrack events which have occurred since the last of backtrack notifications have taken /// place backtrack_events: EventSink, + backtrack_events_literals: Vec<(Literal, PropagatorId)>, } impl Default for NotificationEngine { @@ -56,10 +57,12 @@ impl Default for NotificationEngine { watch_list_domain_events: Default::default(), watch_list_predicate_id: Default::default(), literal_watch_list: Default::default(), + literal_watch_list_backtrack: Default::default(), predicate_notifier: Default::default(), last_notified_trail_index: 0, events: Default::default(), backtrack_events: Default::default(), + backtrack_events_literals: Default::default(), }; // Grow for the dummy predicate result.grow(); @@ -83,6 +86,9 @@ impl NotificationEngine { last_notified_trail_index: usize::MAX, events: Default::default(), backtrack_events: Default::default(), + literal_watch_list: Default::default(), + literal_watch_list_backtrack: Default::default(), + backtrack_events_literals: Default::default(), }; // Grow for the dummy predicate result.grow(); @@ -251,6 +257,57 @@ impl NotificationEngine { } } + pub(crate) fn watch_literal_backtrack( + &mut self, + literal: Literal, + events: EnumSet, + propagator_var: PropagatorVarId, + trailed_values: &mut TrailedValues, + assignments: &Assignments, + ) { + let entry = self + .literal_watch_list_backtrack + .entry(literal) + .or_insert((propagator_var.variable, events)); + entry.1 |= events; + + for event in events { + match event { + DomainEvent::Assign => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::LowerBound => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::UpperBound => { + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::Removal => {} + }; + } + } + pub(crate) fn watch_all_backtrack( &mut self, domain: DomainId, @@ -434,6 +491,18 @@ impl NotificationEngine { } } } + + for (literal, propagator_id) in self.backtrack_events_literals.drain(..).collect::>() + { + if let Some((var_id, events)) = self.literal_watch_list_backtrack.get(&literal) { + let propagator = &mut propagators[propagator_id]; + for event in events.iter() { + let mut context = NotificationContext::new(trailed_values, assignments); + + propagator.notify_backtrack(context.domains(), *var_id, event.into()) + } + } + } true } @@ -459,6 +528,8 @@ impl NotificationEngine { if let Some((var_id, events)) = self.literal_watch_list.get(&literal) && !events.is_empty() { + self.backtrack_events_literals + .push((literal, propagator_id)); let propagator = &mut propagators[propagator_id]; for event in events.iter() { let mut context = NotificationContext::new(trailed_values, assignments); diff --git a/pumpkin-crates/core/src/engine/variables/literal.rs b/pumpkin-crates/core/src/engine/variables/literal.rs index c75e5bfbe..290a07999 100644 --- a/pumpkin-crates/core/src/engine/variables/literal.rs +++ b/pumpkin-crates/core/src/engine/variables/literal.rs @@ -6,7 +6,6 @@ use pumpkin_checking::CheckerVariable; use pumpkin_checking::IntExt; use pumpkin_checking::VariableState; -use super::DomainId; use super::IntegerVariable; use super::TransformableVariable; use crate::engine::Assignments; @@ -29,7 +28,7 @@ impl Debug for Literal { } impl Literal { - /// Creates a new literal wrapping the provided [`DomainId`]. + /// Creates a new literal wrapping the provided [`Predicate`]. /// /// Note: the provided `domain_id` should have a domain between 0 and 1. pub fn new(predicate: Predicate) -> Literal { @@ -99,16 +98,19 @@ impl CheckerVariable for Literal { } fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt { - IntExt::Int( - variable_state - .is_true(&self.inner) - .then(|| 1) - .unwrap_or_default(), - ) + IntExt::Int(if variable_state.is_true(&self.inner) { + 1 + } else { + Default::default() + }) } fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt { - IntExt::Int(variable_state.is_true(&!self.inner).then(|| 0).unwrap_or(1)) + IntExt::Int(if variable_state.is_true(&!self.inner) { + 0 + } else { + 1 + }) } fn induced_fixed_value(&self, variable_state: &VariableState) -> Option { @@ -128,10 +130,8 @@ impl CheckerVariable for Literal { ) -> bool { if value == 1 && !variable_state.is_true(&!self.inner) { true - } else if value == 0 && !variable_state.is_true(&self.inner) { - true } else { - false + value == 0 && !variable_state.is_true(&self.inner) } } @@ -269,8 +269,8 @@ impl IntegerVariable for Literal { event.unwrap() } - fn watch_all_backtrack(&self, _watchers: &mut Watchers<'_>, _events: EnumSet) { - todo!() + fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, events: EnumSet) { + watchers.watch_literal_backtrack(*self, events) } fn get_holes_at_current_checkpoint( diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs index 8f5a71b6f..e90fe9a4d 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs @@ -3,7 +3,6 @@ use pumpkin_checking::CheckerVariable; use pumpkin_checking::InferenceChecker; use pumpkin_checking::IntExt; use pumpkin_checking::VariableState; -use pumpkin_core::asserts::pumpkin_assert_simple; use pumpkin_core::declare_inference_label; use pumpkin_core::predicate; use pumpkin_core::predicates::Predicate; diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs index 40ae0d86e..b7295a3b4 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs @@ -637,7 +637,7 @@ fn compile_bool_and( fn compile_bool2int( context: &mut CompilationContext<'_>, exprs: &[flatzinc::Expr], - constraint_tag: ConstraintTag, + _constraint_tag: ConstraintTag, ) -> Result { // TODO: Perhaps we want to add a phase in the compiler that directly uses the literal // corresponding to the predicate [b = 1] for the boolean parameter in this constraint. @@ -645,8 +645,8 @@ fn compile_bool2int( check_parameters!(exprs, 2, "bool2int"); - let a = context.resolve_bool_variable(&exprs[0])?; - let b = context.resolve_integer_variable(&exprs[1])?; + let _a = context.resolve_bool_variable(&exprs[0])?; + let _b = context.resolve_integer_variable(&exprs[1])?; todo!() // Ok( From 5224ac3c61c57eaaed637e496e6b19938b2503e7 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Thu, 26 Mar 2026 14:37:47 +0100 Subject: [PATCH 03/12] chore: place todo in bool2int --- .../bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs index b7295a3b4..0f2eb7e1f 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs @@ -648,7 +648,8 @@ fn compile_bool2int( let _a = context.resolve_bool_variable(&exprs[0])?; let _b = context.resolve_integer_variable(&exprs[1])?; - todo!() + todo!(); + // Ok( // pumpkin_constraints::binary_equals(a, b.scaled(1), constraint_tag) // .post(context.solver) From 856e0256f5ce8e914a52535773ec7c634abf6f39 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Thu, 26 Mar 2026 15:24:27 +0100 Subject: [PATCH 04/12] refactor: avoid duplicate notification + properly notifying or removal --- .../core/src/engine/notifications/mod.rs | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/pumpkin-crates/core/src/engine/notifications/mod.rs b/pumpkin-crates/core/src/engine/notifications/mod.rs index e41717452..71797cdde 100644 --- a/pumpkin-crates/core/src/engine/notifications/mod.rs +++ b/pumpkin-crates/core/src/engine/notifications/mod.rs @@ -252,7 +252,20 @@ impl NotificationEngine { assignments, ); } - DomainEvent::Removal => {} + DomainEvent::Removal => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } }; } } @@ -303,7 +316,20 @@ impl NotificationEngine { assignments, ); } - DomainEvent::Removal => {} + DomainEvent::Removal => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } }; } } @@ -527,6 +553,9 @@ impl NotificationEngine { let literal = Literal::new(predicate); if let Some((var_id, events)) = self.literal_watch_list.get(&literal) && !events.is_empty() + && !self + .backtrack_events_literals + .contains(&(literal, propagator_id)) { self.backtrack_events_literals .push((literal, propagator_id)); From ff2138984720f3f088f84b00d5f9b796ed360876 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Thu, 26 Mar 2026 15:34:00 +0100 Subject: [PATCH 05/12] fix: bug in notifications not working properly --- .../core/src/engine/notifications/mod.rs | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/pumpkin-crates/core/src/engine/notifications/mod.rs b/pumpkin-crates/core/src/engine/notifications/mod.rs index 71797cdde..3ab43bd8f 100644 --- a/pumpkin-crates/core/src/engine/notifications/mod.rs +++ b/pumpkin-crates/core/src/engine/notifications/mod.rs @@ -41,8 +41,10 @@ pub(crate) struct NotificationEngine { /// The watch list from predicates to propagators. pub(crate) watch_list_predicate_id: KeyedVec>, // TODO: Should use direct hashing - pub(crate) literal_watch_list: HashMap)>, - pub(crate) literal_watch_list_backtrack: HashMap)>, + pub(crate) literal_watch_list: + HashMap)>>, + pub(crate) literal_watch_list_backtrack: + HashMap)>>, /// Events which have occurred since the last round of notifications have taken place events: EventSink, /// Backtrack events which have occurred since the last of backtrack notifications have taken @@ -217,6 +219,8 @@ impl NotificationEngine { let entry = self .literal_watch_list .entry(literal) + .or_default() + .entry(propagator_var.propagator) .or_insert((propagator_var.variable, events)); entry.1 |= events; @@ -281,6 +285,8 @@ impl NotificationEngine { let entry = self .literal_watch_list_backtrack .entry(literal) + .or_default() + .entry(propagator_var.propagator) .or_insert((propagator_var.variable, events)); entry.1 |= events; @@ -520,7 +526,11 @@ impl NotificationEngine { for (literal, propagator_id) in self.backtrack_events_literals.drain(..).collect::>() { - if let Some((var_id, events)) = self.literal_watch_list_backtrack.get(&literal) { + if let Some(Some((var_id, events))) = self + .literal_watch_list_backtrack + .get(&literal) + .map(|inner| inner.get(&propagator_id)) + { let propagator = &mut propagators[propagator_id]; for event in events.iter() { let mut context = NotificationContext::new(trailed_values, assignments); @@ -551,12 +561,18 @@ impl NotificationEngine { for propagator_id in propagators_to_notify { let predicate = self.predicate_notifier.get_predicate(predicate_id); let literal = Literal::new(predicate); - if let Some((var_id, events)) = self.literal_watch_list.get(&literal) - && !events.is_empty() - && !self - .backtrack_events_literals - .contains(&(literal, propagator_id)) + if let Some(Some((var_id, events))) = self + .literal_watch_list + .get(&literal) + .map(|inner| inner.get(&propagator_id)) { + if events.is_empty() + || self + .backtrack_events_literals + .contains(&(literal, propagator_id)) + { + continue; + } self.backtrack_events_literals .push((literal, propagator_id)); let propagator = &mut propagators[propagator_id]; From 233f391abbb7565baf645e945377d0442915a1ee Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Thu, 26 Mar 2026 15:34:57 +0100 Subject: [PATCH 06/12] chore: remove check for duplicate notifications in system --- .../src/propagators/arithmetic/linear_less_or_equal.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs index e90fe9a4d..3eefa9b6c 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs @@ -3,6 +3,7 @@ use pumpkin_checking::CheckerVariable; use pumpkin_checking::InferenceChecker; use pumpkin_checking::IntExt; use pumpkin_checking::VariableState; +use pumpkin_core::asserts::pumpkin_assert_simple; use pumpkin_core::declare_inference_label; use pumpkin_core::predicate; use pumpkin_core::predicates::Predicate; @@ -143,9 +144,7 @@ where let old_bound = context.read_trailed_integer(self.current_bounds[index]); let new_bound = context.lower_bound(x_i) as i64; - if old_bound == new_bound { - return EnqueueDecision::Skip; - } + pumpkin_assert_simple!(new_bound > old_bound); context.write_trailed_integer( self.lower_bound_left_hand_side, From 22561ace9976c56e416f908f6fcdfa1f827735b3 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Fri, 27 Mar 2026 12:29:46 +0100 Subject: [PATCH 07/12] refactor: create structure for watch list literals and predicate ids --- .../domain_event_notification/mod.rs | 2 + .../watch_list_predicate_id.rs | 87 +++++++++++++++++ .../core/src/engine/notifications/mod.rs | 95 ++++++++----------- 3 files changed, 127 insertions(+), 57 deletions(-) create mode 100644 pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs diff --git a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/mod.rs b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/mod.rs index ddf607046..baac4a23d 100644 --- a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/mod.rs +++ b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/mod.rs @@ -6,3 +6,5 @@ pub use domain_event_watch_list::DomainEvent; pub(crate) use domain_event_watch_list::WatchListDomainEvents; pub(crate) use domain_event_watch_list::Watchers; pub(crate) use event_sink::*; +pub(crate) use watch_list_predicate_id::PredicateWatchList; +mod watch_list_predicate_id; diff --git a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs new file mode 100644 index 000000000..b5b719987 --- /dev/null +++ b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs @@ -0,0 +1,87 @@ +use enumset::EnumSet; + +use crate::basic_types::PredicateId; +use crate::containers::HashMap; +use crate::containers::KeyedVec; +use crate::propagation::DomainEvent; +use crate::propagation::LocalId; +use crate::propagation::PropagatorVarId; +use crate::state::PropagatorId; +use crate::variables::Literal; + +#[derive(Debug, Default, Clone)] +pub(crate) struct PredicateWatchList { + /// The watch list from predicates to propagators. + pub(crate) watch_list_predicate_id: KeyedVec>, + // TODO: Should use direct hashing + pub(crate) literal_watch_list: + HashMap)>>, + pub(crate) literal_watch_list_backtrack: + HashMap)>>, +} + +impl PredicateWatchList { + pub(crate) fn watch_predicate_id( + &mut self, + predicate_id: PredicateId, + propagator_id: PropagatorId, + ) { + self.watch_list_predicate_id + .accomodate(predicate_id, vec![]); + self.watch_list_predicate_id[predicate_id].push(propagator_id); + } + + pub(crate) fn watchers_predicate_id_mut( + &mut self, + predicate_id: PredicateId, + ) -> Option<&mut Vec> { + self.watch_list_predicate_id.get_mut(predicate_id) + } + + pub(crate) fn watchers_predicate_id( + &self, + predicate_id: PredicateId, + ) -> Option<&Vec> { + self.watch_list_predicate_id.get(predicate_id) + } + + pub(crate) fn insert_literal_watcher( + &mut self, + literal: Literal, + propagator_var: PropagatorVarId, + events: EnumSet, + ) { + let entry = self + .literal_watch_list + .entry(literal) + .or_default() + .entry(propagator_var.propagator) + .or_insert((propagator_var.variable, events)); + entry.1 |= events; + } + + pub(crate) fn insert_literal_watcher_backtrack( + &mut self, + literal: Literal, + propagator_var: PropagatorVarId, + events: EnumSet, + ) { + let entry = self + .literal_watch_list_backtrack + .entry(literal) + .or_default() + .entry(propagator_var.propagator) + .or_insert((propagator_var.variable, events)); + entry.1 |= events; + } + + pub(crate) fn watchers_literal_propagator( + &self, + literal: Literal, + propagator_id: PropagatorId, + ) -> Option<&(LocalId, EnumSet)> { + self.literal_watch_list + .get(&literal) + .and_then(|inner| inner.get(&propagator_id)) + } +} diff --git a/pumpkin-crates/core/src/engine/notifications/mod.rs b/pumpkin-crates/core/src/engine/notifications/mod.rs index 3ab43bd8f..dde1278b1 100644 --- a/pumpkin-crates/core/src/engine/notifications/mod.rs +++ b/pumpkin-crates/core/src/engine/notifications/mod.rs @@ -11,11 +11,10 @@ use enumset::EnumSet; pub(crate) use predicate_notification::PredicateNotifier; use crate::basic_types::PredicateId; -use crate::containers::HashMap; -use crate::containers::KeyedVec; use crate::engine::Assignments; use crate::engine::PropagatorQueue; use crate::engine::TrailedValues; +use crate::engine::notifications::domain_event_notification::PredicateWatchList; use crate::predicates::Predicate; use crate::propagation::Domains; use crate::propagation::EnqueueDecision; @@ -38,13 +37,7 @@ pub(crate) struct NotificationEngine { /// Contains information on which propagator to notify upon /// integer events, e.g., lower or upper bound change of a variable. watch_list_domain_events: WatchListDomainEvents, - /// The watch list from predicates to propagators. - pub(crate) watch_list_predicate_id: KeyedVec>, - // TODO: Should use direct hashing - pub(crate) literal_watch_list: - HashMap)>>, - pub(crate) literal_watch_list_backtrack: - HashMap)>>, + watch_list_predicate_ids: PredicateWatchList, /// Events which have occurred since the last round of notifications have taken place events: EventSink, /// Backtrack events which have occurred since the last of backtrack notifications have taken @@ -57,9 +50,7 @@ impl Default for NotificationEngine { fn default() -> Self { let mut result = Self { watch_list_domain_events: Default::default(), - watch_list_predicate_id: Default::default(), - literal_watch_list: Default::default(), - literal_watch_list_backtrack: Default::default(), + watch_list_predicate_ids: Default::default(), predicate_notifier: Default::default(), last_notified_trail_index: 0, events: Default::default(), @@ -83,13 +74,11 @@ impl NotificationEngine { let mut result = Self { watch_list_domain_events, - watch_list_predicate_id: Default::default(), + watch_list_predicate_ids: Default::default(), predicate_notifier: Default::default(), last_notified_trail_index: usize::MAX, events: Default::default(), backtrack_events: Default::default(), - literal_watch_list: Default::default(), - literal_watch_list_backtrack: Default::default(), backtrack_events_literals: Default::default(), }; // Grow for the dummy predicate @@ -183,9 +172,8 @@ impl NotificationEngine { trailed_values: &mut TrailedValues, assignments: &Assignments, ) { - self.watch_list_predicate_id - .accomodate(predicate_id, vec![]); - self.watch_list_predicate_id[predicate_id].push(propagator_id); + self.watch_list_predicate_ids + .watch_predicate_id(predicate_id, propagator_id); self.predicate_notifier .track_predicate(predicate_id, trailed_values, assignments); @@ -196,14 +184,17 @@ impl NotificationEngine { predicate_id: PredicateId, propagator_to_unwatch: PropagatorId, ) { - let watch_list = &mut self.watch_list_predicate_id[predicate_id]; - - let index = watch_list - .iter() - .position(|&watched_propagator| watched_propagator == propagator_to_unwatch) - .expect("cannot unwatch a (predicate, propagator) pair if it was not watched"); + if let Some(watch_list) = self + .watch_list_predicate_ids + .watchers_predicate_id_mut(predicate_id) + { + let index = watch_list + .iter() + .position(|&watched_propagator| watched_propagator == propagator_to_unwatch) + .expect("cannot unwatch a (predicate, propagator) pair if it was not watched"); - let _ = watch_list.swap_remove(index); + let _ = watch_list.swap_remove(index); + } // TODO: Can we remove the predicate from being tracked if it does not have watchers? } @@ -216,13 +207,8 @@ impl NotificationEngine { trailed_values: &mut TrailedValues, assignments: &Assignments, ) { - let entry = self - .literal_watch_list - .entry(literal) - .or_default() - .entry(propagator_var.propagator) - .or_insert((propagator_var.variable, events)); - entry.1 |= events; + self.watch_list_predicate_ids + .insert_literal_watcher(literal, propagator_var, events); for event in events { match event { @@ -282,13 +268,8 @@ impl NotificationEngine { trailed_values: &mut TrailedValues, assignments: &Assignments, ) { - let entry = self - .literal_watch_list_backtrack - .entry(literal) - .or_default() - .entry(propagator_var.propagator) - .or_insert((propagator_var.variable, events)); - entry.1 |= events; + self.watch_list_predicate_ids + .insert_literal_watcher_backtrack(literal, propagator_var, events); for event in events { match event { @@ -526,10 +507,9 @@ impl NotificationEngine { for (literal, propagator_id) in self.backtrack_events_literals.drain(..).collect::>() { - if let Some(Some((var_id, events))) = self - .literal_watch_list_backtrack - .get(&literal) - .map(|inner| inner.get(&propagator_id)) + if let Some((var_id, events)) = self + .watch_list_predicate_ids + .watchers_literal_propagator(literal, propagator_id) { let propagator = &mut propagators[propagator_id]; for event in events.iter() { @@ -555,27 +535,28 @@ impl NotificationEngine { .drain_satisfied_predicates() .collect::>() { - if let Some(watch_list) = self.watch_list_predicate_id.get(predicate_id) { - let propagators_to_notify = watch_list.iter().copied(); - - for propagator_id in propagators_to_notify { + if let Some(watchers) = self + .watch_list_predicate_ids + .watchers_predicate_id(predicate_id) + { + for propagator_id in watchers { let predicate = self.predicate_notifier.get_predicate(predicate_id); let literal = Literal::new(predicate); - if let Some(Some((var_id, events))) = self - .literal_watch_list - .get(&literal) - .map(|inner| inner.get(&propagator_id)) + if let Some((var_id, events)) = self + .watch_list_predicate_ids + .watchers_literal_propagator(literal, *propagator_id) { if events.is_empty() || self .backtrack_events_literals - .contains(&(literal, propagator_id)) + .contains(&(literal, *propagator_id)) { continue; } self.backtrack_events_literals - .push((literal, propagator_id)); - let propagator = &mut propagators[propagator_id]; + .push((literal, *propagator_id)); + + let propagator = &mut propagators[*propagator_id]; for event in events.iter() { let mut context = NotificationContext::new(trailed_values, assignments); @@ -584,19 +565,19 @@ impl NotificationEngine { if enqueue_decision == EnqueueDecision::Enqueue { propagator_queue - .enqueue_propagator(propagator_id, propagator.priority()); + .enqueue_propagator(*propagator_id, propagator.priority()); } } } else { let mut context = NotificationContext::new(trailed_values, assignments); - let propagator = &mut propagators[propagator_id]; + let propagator = &mut propagators[*propagator_id]; let enqueue_decision = propagator .notify_predicate_id_satisfied(context.reborrow(), predicate_id); if enqueue_decision == EnqueueDecision::Enqueue { propagator_queue - .enqueue_propagator(propagator_id, propagator.priority()); + .enqueue_propagator(*propagator_id, propagator.priority()); } } } From 0657c5760d661c02261cd43a271ab907df30c231 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Wed, 1 Apr 2026 15:02:11 +0200 Subject: [PATCH 08/12] fix: add enum for integer variable to be able to post elements --- .../constraints/src/constraints/boolean.rs | 40 +- pumpkin-crates/core/src/api/mod.rs | 1 + .../src/engine/variables/integer_variable.rs | 350 ++++++++++++++++++ .../core/src/engine/variables/mod.rs | 1 + .../flatzinc/compiler/post_constraints.rs | 21 +- 5 files changed, 381 insertions(+), 32 deletions(-) diff --git a/pumpkin-crates/constraints/src/constraints/boolean.rs b/pumpkin-crates/constraints/src/constraints/boolean.rs index 990b12354..ddc6bf689 100644 --- a/pumpkin-crates/constraints/src/constraints/boolean.rs +++ b/pumpkin-crates/constraints/src/constraints/boolean.rs @@ -4,10 +4,12 @@ use pumpkin_core::constraints::Constraint; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::variables::AffineView; use pumpkin_core::variables::DomainId; +use pumpkin_core::variables::IntegerVariableEnum; use pumpkin_core::variables::Literal; use pumpkin_core::variables::TransformableVariable; use super::less_than_or_equals; +use crate::equals; /// Creates the [`Constraint`] `∑ weights_i * bools_i <= rhs`. pub fn boolean_less_than_or_equals( @@ -35,7 +37,7 @@ pub fn boolean_equals( weights: weights.into(), bools: bools.into(), rhs, - _constraint_tag: constraint_tag, + constraint_tag, } } @@ -79,40 +81,34 @@ struct BooleanEqual { weights: Box<[i32]>, bools: Box<[Literal]>, rhs: DomainId, - _constraint_tag: ConstraintTag, + constraint_tag: ConstraintTag, } impl Constraint for BooleanEqual { - fn post(self, _solver: &mut Solver) -> Result<(), ConstraintOperationError> { - let (_domains, _rhs_domain) = self.create_domains(); - - todo!(); + fn post(self, solver: &mut Solver) -> Result<(), ConstraintOperationError> { + let domains = self.create_domains(); - // equals(domains, 0, self.constraint_tag).post(solver) + equals(domains, 0, self.constraint_tag).post(solver) } fn implied_by( self, - _solver: &mut Solver, - _reification_literal: Literal, + solver: &mut Solver, + reification_literal: Literal, ) -> Result<(), ConstraintOperationError> { - let (_domains, _rhs_domain) = self.create_domains(); - - todo!(); + let domains = self.create_domains(); - // equals(domains, 0, self.constraint_tag).implied_by(solver, reification_literal) + equals(domains, 0, self.constraint_tag).implied_by(solver, reification_literal) } } impl BooleanEqual { - fn create_domains(&self) -> (Vec>, AffineView) { - ( - self.bools - .iter() - .enumerate() - .map(|(index, bool)| bool.scaled(self.weights[index])) - .collect(), - self.rhs.scaled(-1), - ) + fn create_domains(&self) -> Vec { + self.bools + .iter() + .enumerate() + .map(|(index, bool)| bool.scaled(self.weights[index]).into()) + .chain(std::iter::once(self.rhs.scaled(-1).into())) + .collect() } } diff --git a/pumpkin-crates/core/src/api/mod.rs b/pumpkin-crates/core/src/api/mod.rs index 22f3a90d7..da5dc5c99 100644 --- a/pumpkin-crates/core/src/api/mod.rs +++ b/pumpkin-crates/core/src/api/mod.rs @@ -52,6 +52,7 @@ pub mod variables { pub use crate::engine::variables::AffineView; pub use crate::engine::variables::DomainId; pub use crate::engine::variables::IntegerVariable; + pub use crate::engine::variables::IntegerVariableEnum; pub use crate::engine::variables::Literal; pub use crate::engine::variables::TransformableVariable; } diff --git a/pumpkin-crates/core/src/engine/variables/integer_variable.rs b/pumpkin-crates/core/src/engine/variables/integer_variable.rs index 09badb92e..e37d3bab2 100644 --- a/pumpkin-crates/core/src/engine/variables/integer_variable.rs +++ b/pumpkin-crates/core/src/engine/variables/integer_variable.rs @@ -10,6 +10,9 @@ use crate::engine::notifications::OpaqueDomainEvent; use crate::engine::notifications::Watchers; use crate::engine::predicates::predicate_constructor::PredicateConstructor; use crate::predicates::Predicate; +use crate::variables::AffineView; +use crate::variables::DomainId; +use crate::variables::Literal; /// A trait specifying the required behaviour of an integer variable such as retrieving a /// lower-bound ([`IntegerVariable::lower_bound`]). @@ -70,3 +73,350 @@ pub trait IntegerVariable: /// Returns all of the holes in the domain fn get_holes(&self, assignments: &Assignments) -> impl Iterator; } + +#[derive(Debug, Clone, Copy)] +pub enum IntegerVariableEnum { + DomainId(AffineView), + Literal(AffineView), +} + +impl From> for IntegerVariableEnum { + fn from(value: AffineView) -> Self { + IntegerVariableEnum::DomainId(value) + } +} + +impl From for IntegerVariableEnum { + fn from(value: DomainId) -> Self { + IntegerVariableEnum::DomainId(value.scaled(1)) + } +} + +impl From> for IntegerVariableEnum { + fn from(value: AffineView) -> Self { + IntegerVariableEnum::Literal(value) + } +} + +impl From for IntegerVariableEnum { + fn from(value: Literal) -> Self { + IntegerVariableEnum::Literal(value.scaled(1)) + } +} + +impl PredicateConstructor for IntegerVariableEnum { + type Value = i32; + + fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.lower_bound_predicate(bound), + IntegerVariableEnum::Literal(literal) => literal.lower_bound_predicate(bound), + } + } + + fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.upper_bound_predicate(bound), + IntegerVariableEnum::Literal(literal) => literal.upper_bound_predicate(bound), + } + } + + fn equality_predicate(&self, bound: Self::Value) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.equality_predicate(bound), + IntegerVariableEnum::Literal(literal) => literal.equality_predicate(bound), + } + } + + fn disequality_predicate(&self, bound: Self::Value) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.disequality_predicate(bound), + IntegerVariableEnum::Literal(literal) => literal.disequality_predicate(bound), + } + } +} + +impl TransformableVariable for IntegerVariableEnum { + fn scaled(&self, scale: i32) -> IntegerVariableEnum { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.scaled(scale).into(), + IntegerVariableEnum::Literal(literal) => literal.scaled(scale).into(), + } + } + + fn offset(&self, offset: i32) -> IntegerVariableEnum { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.offset(offset).into(), + IntegerVariableEnum::Literal(literal) => literal.offset(offset).into(), + } + } +} + +impl IntegerVariable for IntegerVariableEnum { + type AffineView = IntegerVariableEnum; + + fn lower_bound(&self, assignment: &Assignments) -> i32 { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.lower_bound(assignment), + IntegerVariableEnum::Literal(literal) => literal.lower_bound(assignment), + } + } + + fn lower_bound_at_trail_position( + &self, + assignment: &Assignments, + trail_position: usize, + ) -> i32 { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.lower_bound_at_trail_position(assignment, trail_position) + } + IntegerVariableEnum::Literal(literal) => { + literal.lower_bound_at_trail_position(assignment, trail_position) + } + } + } + + fn upper_bound(&self, assignment: &Assignments) -> i32 { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.upper_bound(assignment), + IntegerVariableEnum::Literal(literal) => literal.upper_bound(assignment), + } + } + + fn upper_bound_at_trail_position( + &self, + assignment: &Assignments, + trail_position: usize, + ) -> i32 { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.upper_bound_at_trail_position(assignment, trail_position) + } + IntegerVariableEnum::Literal(literal) => { + literal.upper_bound_at_trail_position(assignment, trail_position) + } + } + } + + fn contains(&self, assignment: &Assignments, value: i32) -> bool { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.contains(assignment, value), + IntegerVariableEnum::Literal(literal) => literal.contains(assignment, value), + } + } + + fn contains_at_trail_position( + &self, + assignment: &Assignments, + value: i32, + trail_position: usize, + ) -> bool { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.contains_at_trail_position(assignment, value, trail_position) + } + IntegerVariableEnum::Literal(literal) => { + literal.contains_at_trail_position(assignment, value, trail_position) + } + } + } + + fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .iterate_domain(assignment) + .collect::>() + .into_iter(), + IntegerVariableEnum::Literal(literal) => literal + .iterate_domain(assignment) + .collect::>() + .into_iter(), + } + } + + fn watch_all(&self, watchers: &mut Watchers<'_>, events: EnumSet) { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.watch_all(watchers, events), + IntegerVariableEnum::Literal(literal) => literal.watch_all(watchers, events), + } + } + + fn unwatch_all(&self, watchers: &mut Watchers<'_>) { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.unwatch_all(watchers), + IntegerVariableEnum::Literal(literal) => literal.unwatch_all(watchers), + } + } + + fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, events: EnumSet) { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.watch_all_backtrack(watchers, events) + } + IntegerVariableEnum::Literal(literal) => literal.watch_all_backtrack(watchers, events), + } + } + + fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.unpack_event(event), + IntegerVariableEnum::Literal(literal) => literal.unpack_event(event), + } + } + + fn get_holes_at_current_checkpoint( + &self, + assignments: &Assignments, + ) -> impl Iterator { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .get_holes_at_current_checkpoint(assignments) + .collect::>() + .into_iter(), + IntegerVariableEnum::Literal(literal) => literal + .get_holes_at_current_checkpoint(assignments) + .collect::>() + .into_iter(), + } + } + + fn get_holes(&self, assignments: &Assignments) -> impl Iterator { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .get_holes(assignments) + .collect::>() + .into_iter(), + IntegerVariableEnum::Literal(literal) => literal + .get_holes(assignments) + .collect::>() + .into_iter(), + } + } +} + +impl CheckerVariable for IntegerVariableEnum { + fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.does_atomic_constrain_self(atomic) + } + IntegerVariableEnum::Literal(literal) => literal.does_atomic_constrain_self(atomic), + } + } + + fn atomic_less_than(&self, value: i32) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.atomic_less_than(value), + IntegerVariableEnum::Literal(literal) => literal.atomic_less_than(value), + } + } + + fn atomic_greater_than(&self, value: i32) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.atomic_greater_than(value), + IntegerVariableEnum::Literal(literal) => literal.atomic_greater_than(value), + } + } + + fn atomic_equal(&self, value: i32) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.atomic_equal(value), + IntegerVariableEnum::Literal(literal) => literal.atomic_equal(value), + } + } + + fn atomic_not_equal(&self, value: i32) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.atomic_not_equal(value), + IntegerVariableEnum::Literal(literal) => literal.atomic_not_equal(value), + } + } + + fn induced_lower_bound( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> pumpkin_checking::IntExt { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.induced_lower_bound(variable_state) + } + IntegerVariableEnum::Literal(literal) => literal.induced_lower_bound(variable_state), + } + } + + fn induced_upper_bound( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> pumpkin_checking::IntExt { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.induced_upper_bound(variable_state) + } + IntegerVariableEnum::Literal(literal) => literal.induced_upper_bound(variable_state), + } + } + + fn induced_fixed_value( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> Option { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.induced_fixed_value(variable_state) + } + IntegerVariableEnum::Literal(literal) => literal.induced_fixed_value(variable_state), + } + } + + fn induced_domain_contains( + &self, + variable_state: &pumpkin_checking::VariableState, + value: i32, + ) -> bool { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.induced_domain_contains(variable_state, value) + } + IntegerVariableEnum::Literal(literal) => { + literal.induced_domain_contains(variable_state, value) + } + } + } + + fn induced_holes<'this, 'state>( + &'this self, + variable_state: &'state pumpkin_checking::VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state, + { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .induced_holes(variable_state) + .collect::>() + .into_iter(), + IntegerVariableEnum::Literal(literal) => literal + .induced_holes(variable_state) + .collect::>() + .into_iter(), + } + } + + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state pumpkin_checking::VariableState, + ) -> Option + 'state> + where + 'this: 'state, + { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .iter_induced_domain(variable_state) + .map(|iter| iter.collect::>().into_iter()), + IntegerVariableEnum::Literal(literal) => literal + .iter_induced_domain(variable_state) + .map(|iter| iter.collect::>().into_iter()), + } + } +} diff --git a/pumpkin-crates/core/src/engine/variables/mod.rs b/pumpkin-crates/core/src/engine/variables/mod.rs index 42837e402..2c0d10bb6 100644 --- a/pumpkin-crates/core/src/engine/variables/mod.rs +++ b/pumpkin-crates/core/src/engine/variables/mod.rs @@ -14,5 +14,6 @@ pub use affine_view::AffineView; pub(crate) use domain_generator_iterator::DomainGeneratorIterator; pub use domain_id::DomainId; pub use integer_variable::IntegerVariable; +pub use integer_variable::IntegerVariableEnum; pub use literal::Literal; pub use transformable_variable::TransformableVariable; diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs index 0f2eb7e1f..ae444b171 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs @@ -2,6 +2,7 @@ use std::rc::Rc; +use pumpkin_core::variables::IntegerVariableEnum; use pumpkin_propagators::disjunctive::ArgDisjunctiveTask; use pumpkin_solver::core::constraints::Constraint; use pumpkin_solver::core::constraints::NegatableConstraint; @@ -637,7 +638,7 @@ fn compile_bool_and( fn compile_bool2int( context: &mut CompilationContext<'_>, exprs: &[flatzinc::Expr], - _constraint_tag: ConstraintTag, + constraint_tag: ConstraintTag, ) -> Result { // TODO: Perhaps we want to add a phase in the compiler that directly uses the literal // corresponding to the predicate [b = 1] for the boolean parameter in this constraint. @@ -645,16 +646,16 @@ fn compile_bool2int( check_parameters!(exprs, 2, "bool2int"); - let _a = context.resolve_bool_variable(&exprs[0])?; - let _b = context.resolve_integer_variable(&exprs[1])?; - - todo!(); + let a = context.resolve_bool_variable(&exprs[0])?; + let b = context.resolve_integer_variable(&exprs[1])?; - // Ok( - // pumpkin_constraints::binary_equals(a, b.scaled(1), constraint_tag) - // .post(context.solver) - // .is_ok(), - // ) + Ok(pumpkin_constraints::binary_equals( + IntegerVariableEnum::Literal(a.scaled(1)), + IntegerVariableEnum::DomainId(b.scaled(1)), + constraint_tag, + ) + .post(context.solver) + .is_ok()) } fn compile_bool_or( From 1e321396071c7f363c7e07902b984b485d1e1f50 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Wed, 1 Apr 2026 15:45:04 +0200 Subject: [PATCH 09/12] fix: introduce enum for boolean + fixing backtrack events + correctly finding watchers for literals when negated --- .../watch_list_predicate_id.rs | 5 +++++ .../core/src/engine/notifications/mod.rs | 10 ++++++++-- pumpkin-crates/core/src/engine/state.rs | 1 + .../arithmetic/linear_less_or_equal.rs | 18 ++++++++++++++++++ 4 files changed, 32 insertions(+), 2 deletions(-) diff --git a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs index b5b719987..539f5b20a 100644 --- a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs +++ b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs @@ -83,5 +83,10 @@ impl PredicateWatchList { self.literal_watch_list .get(&literal) .and_then(|inner| inner.get(&propagator_id)) + .or_else(|| { + self.literal_watch_list + .get(&(!literal)) + .and_then(|inner| inner.get(&propagator_id)) + }) } } diff --git a/pumpkin-crates/core/src/engine/notifications/mod.rs b/pumpkin-crates/core/src/engine/notifications/mod.rs index dde1278b1..f7c523e62 100644 --- a/pumpkin-crates/core/src/engine/notifications/mod.rs +++ b/pumpkin-crates/core/src/engine/notifications/mod.rs @@ -11,6 +11,7 @@ use enumset::EnumSet; pub(crate) use predicate_notification::PredicateNotifier; use crate::basic_types::PredicateId; +use crate::basic_types::Trail; use crate::engine::Assignments; use crate::engine::PropagatorQueue; use crate::engine::TrailedValues; @@ -43,7 +44,7 @@ pub(crate) struct NotificationEngine { /// Backtrack events which have occurred since the last of backtrack notifications have taken /// place backtrack_events: EventSink, - backtrack_events_literals: Vec<(Literal, PropagatorId)>, + backtrack_events_literals: Trail<(Literal, PropagatorId)>, } impl Default for NotificationEngine { @@ -478,6 +479,7 @@ impl NotificationEngine { pub(crate) fn process_backtrack_events( &mut self, + new_checkpoint: usize, assignments: &mut Assignments, trailed_values: &mut TrailedValues, propagators: &mut PropagatorStore, @@ -505,7 +507,10 @@ impl NotificationEngine { } } - for (literal, propagator_id) in self.backtrack_events_literals.drain(..).collect::>() + for (literal, propagator_id) in self + .backtrack_events_literals + .synchronise(new_checkpoint) + .collect::>() { if let Some((var_id, events)) = self .watch_list_predicate_ids @@ -682,6 +687,7 @@ impl NotificationEngine { self.predicate_notifier .predicate_id_assignments .new_checkpoint(); + self.backtrack_events_literals.new_checkpoint(); } pub(crate) fn debug_create_from_assignments(&mut self, assignments: &Assignments) { diff --git a/pumpkin-crates/core/src/engine/state.rs b/pumpkin-crates/core/src/engine/state.rs index d238d0e16..3824d1913 100644 --- a/pumpkin-crates/core/src/engine/state.rs +++ b/pumpkin-crates/core/src/engine/state.rs @@ -572,6 +572,7 @@ impl State { } let _ = self.notification_engine.process_backtrack_events( + checkpoint, &mut self.assignments, &mut self.trailed_values, &mut self.propagators, diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs index 3fc1c066e..300165525 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs @@ -3,6 +3,7 @@ use pumpkin_checking::CheckerVariable; use pumpkin_checking::InferenceChecker; use pumpkin_checking::IntExt; use pumpkin_checking::VariableState; +use pumpkin_core::asserts::pumpkin_assert_extreme; use pumpkin_core::asserts::pumpkin_assert_simple; use pumpkin_core::declare_inference_label; use pumpkin_core::predicate; @@ -75,6 +76,13 @@ where current_bounds.push(context.new_trailed_integer(context.lower_bound(x_i) as i64)); } + pumpkin_assert_extreme!( + lower_bound_left_hand_side + == x.iter().map(|x| context.lower_bound(x) as i64).sum::(), + "Expected {lower_bound_left_hand_side} to be equal to {}", + x.iter().map(|x| context.lower_bound(x)).sum::() + ); + let lower_bound_left_hand_side = context.new_trailed_integer(lower_bound_left_hand_side); LinearLessOrEqualPropagator { @@ -214,6 +222,16 @@ where return Ok(()); } }; + pumpkin_assert_extreme!( + lower_bound_left_hand_side + == self.x.iter().map(|x| context.lower_bound(x)).sum::(), + "Expected {lower_bound_left_hand_side} to be equal to {}\n{:?}", + self.x.iter().map(|x| context.lower_bound(x)).sum::(), + self.x + .iter() + .map(|x| (x, context.lower_bound(x))) + .collect::>() + ); for (i, x_i) in self.x.iter().enumerate() { let bound = self.c - (lower_bound_left_hand_side - context.lower_bound(x_i)); From b8c8b9b43397e14b2f576d0560af9abf1176d953 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Wed, 1 Apr 2026 16:00:10 +0200 Subject: [PATCH 10/12] fix: initial attempt at fix for python interface --- .../src/engine/variables/integer_variable.rs | 2 +- .../src/constraints/arguments.rs | 10 +++ pumpkin-solver-py/src/constraints/globals.rs | 66 +++++++++---------- pumpkin-solver-py/src/variables.rs | 27 ++++++++ pumpkin-solver-py/tests/test_constraints.py | 46 ++++--------- 5 files changed, 84 insertions(+), 67 deletions(-) diff --git a/pumpkin-crates/core/src/engine/variables/integer_variable.rs b/pumpkin-crates/core/src/engine/variables/integer_variable.rs index e37d3bab2..ab7d19fa9 100644 --- a/pumpkin-crates/core/src/engine/variables/integer_variable.rs +++ b/pumpkin-crates/core/src/engine/variables/integer_variable.rs @@ -74,7 +74,7 @@ pub trait IntegerVariable: fn get_holes(&self, assignments: &Assignments) -> impl Iterator; } -#[derive(Debug, Clone, Copy)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum IntegerVariableEnum { DomainId(AffineView), Literal(AffineView), diff --git a/pumpkin-solver-py/src/constraints/arguments.rs b/pumpkin-solver-py/src/constraints/arguments.rs index ee2425fa8..4f2303c12 100644 --- a/pumpkin-solver-py/src/constraints/arguments.rs +++ b/pumpkin-solver-py/src/constraints/arguments.rs @@ -1,9 +1,11 @@ use pumpkin_solver::core::variables::AffineView; use pumpkin_solver::core::variables::DomainId; +use pumpkin_solver::core::variables::IntegerVariableEnum; use pumpkin_solver::core::variables::Literal; use crate::variables::BoolExpression; use crate::variables::IntExpression; +use crate::variables::IntegerVariableWrapper; /// Trait which helps to convert Python API types to the solver types when creating constraints. pub trait PythonConstraintArg { @@ -12,6 +14,14 @@ pub trait PythonConstraintArg { fn to_solver_constraint_argument(self) -> Self::Output; } +impl PythonConstraintArg for IntegerVariableWrapper { + type Output = IntegerVariableEnum; + + fn to_solver_constraint_argument(self) -> Self::Output { + self.inner + } +} + impl PythonConstraintArg for IntExpression { type Output = AffineView; diff --git a/pumpkin-solver-py/src/constraints/globals.rs b/pumpkin-solver-py/src/constraints/globals.rs index f9649c512..edb3f9d45 100644 --- a/pumpkin-solver-py/src/constraints/globals.rs +++ b/pumpkin-solver-py/src/constraints/globals.rs @@ -52,48 +52,48 @@ macro_rules! python_constraint { python_constraint! { Absolute: absolute { - signed: IntExpression, - absolute: IntExpression, + signed: IntegerVariableWrapper, + absolute: IntegerVariableWrapper, } } python_constraint! { AllDifferent: all_different { - variables: Vec, + variables: Vec, } } python_constraint! { BinaryEquals: binary_equals { - lhs: IntExpression, - rhs: IntExpression, + lhs: IntegerVariableWrapper, + rhs: IntegerVariableWrapper, } } python_constraint! { BinaryLessThan: binary_less_than { - lhs: IntExpression, - rhs: IntExpression, + lhs: IntegerVariableWrapper, + rhs: IntegerVariableWrapper, } } python_constraint! { BinaryLessThanEqual: binary_less_than_or_equals { - lhs: IntExpression, - rhs: IntExpression, + lhs: IntegerVariableWrapper, + rhs: IntegerVariableWrapper, } } python_constraint! { BinaryNotEquals: binary_not_equals { - lhs: IntExpression, - rhs: IntExpression, + lhs: IntegerVariableWrapper, + rhs: IntegerVariableWrapper, } } python_constraint! { Cumulative: cumulative { - start_times: Vec, + start_times: Vec, durations: Vec, resource_requirements: Vec, resource_capacity: i32, @@ -102,68 +102,68 @@ python_constraint! { python_constraint! { Division: division { - numerator: IntExpression, - denominator: IntExpression, - rhs: IntExpression, + numerator: IntegerVariableWrapper, + denominator: IntegerVariableWrapper, + rhs: IntegerVariableWrapper, } } python_constraint! { Element: element { - index: IntExpression, - array: Vec, - rhs: IntExpression, + index: IntegerVariableWrapper, + array: Vec, + rhs: IntegerVariableWrapper, } } python_constraint! { Equals: equals { - terms: Vec, + terms: Vec, rhs: i32, } } python_constraint! { LessThanOrEquals: less_than_or_equals { - terms: Vec, + terms: Vec, rhs: i32, } } python_constraint! { Maximum: maximum { - choices: Vec, - rhs: IntExpression, + choices: Vec, + rhs: IntegerVariableWrapper, } } python_constraint! { Minimum: minimum { - choices: Vec, - rhs: IntExpression, + choices: Vec, + rhs: IntegerVariableWrapper, } } python_constraint! { NotEquals: not_equals { - terms: Vec, + terms: Vec, rhs: i32, } } python_constraint! { Plus: plus { - a: IntExpression, - b: IntExpression, - c: IntExpression, + a: IntegerVariableWrapper, + b: IntegerVariableWrapper, + c: IntegerVariableWrapper, } } python_constraint! { Times: times { - a: IntExpression, - b: IntExpression, - c: IntExpression, + a: IntegerVariableWrapper, + b: IntegerVariableWrapper, + c: IntegerVariableWrapper, } } @@ -181,14 +181,14 @@ python_constraint! { python_constraint! { Table: table { - variables: Vec, + variables: Vec, table: Vec>, } } python_constraint! { NegativeTable: negative_table { - variables: Vec, + variables: Vec, table: Vec>, } } diff --git a/pumpkin-solver-py/src/variables.rs b/pumpkin-solver-py/src/variables.rs index ee52118b8..2b92072c8 100644 --- a/pumpkin-solver-py/src/variables.rs +++ b/pumpkin-solver-py/src/variables.rs @@ -1,10 +1,30 @@ use pumpkin_solver::core::predicate; use pumpkin_solver::core::variables::AffineView; use pumpkin_solver::core::variables::DomainId; +use pumpkin_solver::core::variables::IntegerVariableEnum; use pumpkin_solver::core::variables::Literal; use pumpkin_solver::core::variables::TransformableVariable; use pyo3::prelude::*; +#[pyclass(eq, hash, frozen)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct IntegerVariableWrapper { + pub(crate) inner: IntegerVariableEnum, +} + +impl From for IntegerVariableWrapper { + fn from(value: IntegerVariableEnum) -> Self { + Self { inner: value } + } +} + +impl From for IntegerVariableWrapper { + fn from(value: IntExpression) -> Self { + let value: IntegerVariableEnum = value.0.into(); + value.into() + } +} + #[pyclass(eq, hash, frozen)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct IntExpression(pub AffineView); @@ -30,6 +50,13 @@ impl IntExpression { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct BoolExpression(pub Literal); +impl From for IntegerVariableWrapper { + fn from(value: BoolExpression) -> Self { + let value: IntegerVariableEnum = value.0.into(); + value.into() + } +} + #[pymethods] impl BoolExpression { pub fn negate(&self) -> BoolExpression { diff --git a/pumpkin-solver-py/tests/test_constraints.py b/pumpkin-solver-py/tests/test_constraints.py index 8d26295a8..ba085d13e 100644 --- a/pumpkin-solver-py/tests/test_constraints.py +++ b/pumpkin-solver-py/tests/test_constraints.py @@ -1,15 +1,14 @@ -""" -Generate constraints and expressions based on the grammar supported by the API +"""Generate constraints and expressions based on the grammar supported by the API. -Generates linear constraints, special operators and global constraints. -Whenever possible, the script also generates 'boolean as integer' versions of the arguments +Generates linear constraints, special operators and global constraints. Whenever possible, the script also generates +'boolean as integer' versions of the arguments """ from random import randint +import pumpkin_solver import pytest from pumpkin_solver import constraints -import pumpkin_solver def chain(*iterables): @@ -68,15 +67,11 @@ def create_linear_model(request): model = pumpkin_solver.Model() if bool: - args = [ - model.new_boolean_variable(name=f"x[{i}]").as_integer() for i in range(3) - ] + args = [model.new_boolean_variable(name=f"x[{i}]") for i in range(3)] else: args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled: # do scaling (0, -2, 4,...) - args = [ - a.scaled(-2 * i + 1) for i, a in enumerate(args) - ] # TODO: div by zero when scale = 0, fixed with +1 + args = [a.scaled(-2 * i + 1) for i, a in enumerate(args)] # TODO: div by zero when scale = 0, fixed with +1 rhs = 1 cons = None @@ -104,15 +99,11 @@ def create_operator_model(request): model = pumpkin_solver.Model() if bool: - args = [ - model.new_boolean_variable(name=f"x[{i}]").as_integer() for i in range(3) - ] + args = [model.new_boolean_variable(name=f"x[{i}]") for i in range(3)] else: args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled: # do scaling (0, -2, 4,...) - args = [ - a.scaled(-2 * i + 1) for i, a in enumerate(args) - ] # TODO: div by zero when scale = 0, fixed with +1 + args = [a.scaled(-2 * i + 1) for i, a in enumerate(args)] # TODO: div by zero when scale = 0, fixed with +1 rhs = model.new_integer_variable(-3, 5, name="rhs") cons = None @@ -128,9 +119,7 @@ def create_operator_model(request): if name == "max": cons = constraints.Maximum(args, rhs, model.new_constraint_tag()) if name == "element": - idx = model.new_integer_variable( - -1, 5, name="idx" - ) # sneaky, idx can be out of bounds + idx = model.new_integer_variable(-1, 5, name="idx") # sneaky, idx can be out of bounds cons = constraints.Element(idx, args, rhs, model.new_constraint_tag()) if not cons: @@ -151,16 +140,11 @@ def create_global_model(request): if name == "alldiff": if bool: - args = [ - model.new_boolean_variable(name=f"x[{i}]").as_integer() - for i in range(3) - ] + args = [model.new_boolean_variable(name=f"x[{i}]") for i in range(3)] else: args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled or bool: # do scaling (0, -2, 4,...) - args = [ - a.scaled(-2 * i + 1) for i, a in enumerate(args) - ] # TODO: div by zero when scale = 0, fixed with +1 + args = [a.scaled(-2 * i + 1) for i, a in enumerate(args)] # TODO: div by zero when scale = 0, fixed with +1 cons = constraints.AllDifferent(args, model.new_constraint_tag()) @@ -182,9 +166,7 @@ def create_global_model(request): start = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled: start = [a.scaled(-2 * i) for i, a in enumerate(start)] - cons = constraints.Cumulative( - start, duration, demand, capacity, model.new_constraint_tag() - ) + cons = constraints.Cumulative(start, duration, demand, capacity, model.new_constraint_tag()) else: assert False, f"unknown global {name}" @@ -200,9 +182,7 @@ def global_model(request): def make_id(args): name, scaled, bool = args - return " ".join( - ["Scaled" if scaled else "Unscaled", "Boolean" if bool else "Integer", name] - ) + return " ".join(["Scaled" if scaled else "Unscaled", "Boolean" if bool else "Integer", name]) @pytest.mark.parametrize("linear_model", generate_linear(), indirect=True, ids=make_id) From 08122913f30d41219019f1b2d712527f410c2462 Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Thu, 2 Apr 2026 08:10:54 +0200 Subject: [PATCH 11/12] refactor: moving everything to IntExpression --- pumpkin-solver-py/src/brancher.rs | 4 +- .../src/constraints/arguments.rs | 13 +--- pumpkin-solver-py/src/constraints/globals.rs | 66 +++++++++---------- pumpkin-solver-py/src/variables.rs | 35 +++------- 4 files changed, 46 insertions(+), 72 deletions(-) diff --git a/pumpkin-solver-py/src/brancher.rs b/pumpkin-solver-py/src/brancher.rs index 0ae40e269..9f7b307c6 100644 --- a/pumpkin-solver-py/src/brancher.rs +++ b/pumpkin-solver-py/src/brancher.rs @@ -7,13 +7,13 @@ use pumpkin_solver::core::containers::HashMap; use pumpkin_solver::core::predicates::Predicate; use pumpkin_solver::core::results::SolutionReference; use pumpkin_solver::core::statistics::StatisticLogger; -use pumpkin_solver::core::variables::AffineView; use pumpkin_solver::core::variables::DomainId; +use pumpkin_solver::core::variables::IntegerVariableEnum; use crate::variables::IntExpression; pub struct PythonBrancher { - warm_start: WarmStart>, + warm_start: WarmStart, default_brancher: DefaultBrancher, } diff --git a/pumpkin-solver-py/src/constraints/arguments.rs b/pumpkin-solver-py/src/constraints/arguments.rs index 4f2303c12..6257f4cec 100644 --- a/pumpkin-solver-py/src/constraints/arguments.rs +++ b/pumpkin-solver-py/src/constraints/arguments.rs @@ -1,11 +1,8 @@ -use pumpkin_solver::core::variables::AffineView; -use pumpkin_solver::core::variables::DomainId; use pumpkin_solver::core::variables::IntegerVariableEnum; use pumpkin_solver::core::variables::Literal; use crate::variables::BoolExpression; use crate::variables::IntExpression; -use crate::variables::IntegerVariableWrapper; /// Trait which helps to convert Python API types to the solver types when creating constraints. pub trait PythonConstraintArg { @@ -14,16 +11,8 @@ pub trait PythonConstraintArg { fn to_solver_constraint_argument(self) -> Self::Output; } -impl PythonConstraintArg for IntegerVariableWrapper { - type Output = IntegerVariableEnum; - - fn to_solver_constraint_argument(self) -> Self::Output { - self.inner - } -} - impl PythonConstraintArg for IntExpression { - type Output = AffineView; + type Output = IntegerVariableEnum; fn to_solver_constraint_argument(self) -> Self::Output { self.0 diff --git a/pumpkin-solver-py/src/constraints/globals.rs b/pumpkin-solver-py/src/constraints/globals.rs index edb3f9d45..f9649c512 100644 --- a/pumpkin-solver-py/src/constraints/globals.rs +++ b/pumpkin-solver-py/src/constraints/globals.rs @@ -52,48 +52,48 @@ macro_rules! python_constraint { python_constraint! { Absolute: absolute { - signed: IntegerVariableWrapper, - absolute: IntegerVariableWrapper, + signed: IntExpression, + absolute: IntExpression, } } python_constraint! { AllDifferent: all_different { - variables: Vec, + variables: Vec, } } python_constraint! { BinaryEquals: binary_equals { - lhs: IntegerVariableWrapper, - rhs: IntegerVariableWrapper, + lhs: IntExpression, + rhs: IntExpression, } } python_constraint! { BinaryLessThan: binary_less_than { - lhs: IntegerVariableWrapper, - rhs: IntegerVariableWrapper, + lhs: IntExpression, + rhs: IntExpression, } } python_constraint! { BinaryLessThanEqual: binary_less_than_or_equals { - lhs: IntegerVariableWrapper, - rhs: IntegerVariableWrapper, + lhs: IntExpression, + rhs: IntExpression, } } python_constraint! { BinaryNotEquals: binary_not_equals { - lhs: IntegerVariableWrapper, - rhs: IntegerVariableWrapper, + lhs: IntExpression, + rhs: IntExpression, } } python_constraint! { Cumulative: cumulative { - start_times: Vec, + start_times: Vec, durations: Vec, resource_requirements: Vec, resource_capacity: i32, @@ -102,68 +102,68 @@ python_constraint! { python_constraint! { Division: division { - numerator: IntegerVariableWrapper, - denominator: IntegerVariableWrapper, - rhs: IntegerVariableWrapper, + numerator: IntExpression, + denominator: IntExpression, + rhs: IntExpression, } } python_constraint! { Element: element { - index: IntegerVariableWrapper, - array: Vec, - rhs: IntegerVariableWrapper, + index: IntExpression, + array: Vec, + rhs: IntExpression, } } python_constraint! { Equals: equals { - terms: Vec, + terms: Vec, rhs: i32, } } python_constraint! { LessThanOrEquals: less_than_or_equals { - terms: Vec, + terms: Vec, rhs: i32, } } python_constraint! { Maximum: maximum { - choices: Vec, - rhs: IntegerVariableWrapper, + choices: Vec, + rhs: IntExpression, } } python_constraint! { Minimum: minimum { - choices: Vec, - rhs: IntegerVariableWrapper, + choices: Vec, + rhs: IntExpression, } } python_constraint! { NotEquals: not_equals { - terms: Vec, + terms: Vec, rhs: i32, } } python_constraint! { Plus: plus { - a: IntegerVariableWrapper, - b: IntegerVariableWrapper, - c: IntegerVariableWrapper, + a: IntExpression, + b: IntExpression, + c: IntExpression, } } python_constraint! { Times: times { - a: IntegerVariableWrapper, - b: IntegerVariableWrapper, - c: IntegerVariableWrapper, + a: IntExpression, + b: IntExpression, + c: IntExpression, } } @@ -181,14 +181,14 @@ python_constraint! { python_constraint! { Table: table { - variables: Vec, + variables: Vec, table: Vec>, } } python_constraint! { NegativeTable: negative_table { - variables: Vec, + variables: Vec, table: Vec>, } } diff --git a/pumpkin-solver-py/src/variables.rs b/pumpkin-solver-py/src/variables.rs index 2b92072c8..5c0f85e9d 100644 --- a/pumpkin-solver-py/src/variables.rs +++ b/pumpkin-solver-py/src/variables.rs @@ -1,5 +1,4 @@ use pumpkin_solver::core::predicate; -use pumpkin_solver::core::variables::AffineView; use pumpkin_solver::core::variables::DomainId; use pumpkin_solver::core::variables::IntegerVariableEnum; use pumpkin_solver::core::variables::Literal; @@ -8,30 +7,23 @@ use pyo3::prelude::*; #[pyclass(eq, hash, frozen)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct IntegerVariableWrapper { - pub(crate) inner: IntegerVariableEnum, -} +pub struct IntExpression(pub IntegerVariableEnum); -impl From for IntegerVariableWrapper { - fn from(value: IntegerVariableEnum) -> Self { - Self { inner: value } +impl From for IntExpression { + fn from(domain_id: DomainId) -> IntExpression { + IntExpression(domain_id.into()) } } -impl From for IntegerVariableWrapper { - fn from(value: IntExpression) -> Self { - let value: IntegerVariableEnum = value.0.into(); - value.into() +impl From for IntExpression { + fn from(value: Literal) -> Self { + IntExpression(value.into()) } } -#[pyclass(eq, hash, frozen)] -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct IntExpression(pub AffineView); - -impl From for IntExpression { - fn from(domain_id: DomainId) -> IntExpression { - IntExpression(domain_id.into()) +impl From for IntExpression { + fn from(value: BoolExpression) -> Self { + IntExpression(value.0.into()) } } @@ -50,13 +42,6 @@ impl IntExpression { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct BoolExpression(pub Literal); -impl From for IntegerVariableWrapper { - fn from(value: BoolExpression) -> Self { - let value: IntegerVariableEnum = value.0.into(); - value.into() - } -} - #[pymethods] impl BoolExpression { pub fn negate(&self) -> BoolExpression { From 63b8a1178993ab34e291c464775c26241a210a3d Mon Sep 17 00:00:00 2001 From: Imko Marijnissen Date: Thu, 2 Apr 2026 08:24:14 +0200 Subject: [PATCH 12/12] fix: issues with python interface --- pumpkin-solver-py/src/variables.rs | 14 ++++++++++++++ pumpkin-solver-py/tests/test_constraints.py | 6 +++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pumpkin-solver-py/src/variables.rs b/pumpkin-solver-py/src/variables.rs index 5c0f85e9d..d2882a7a8 100644 --- a/pumpkin-solver-py/src/variables.rs +++ b/pumpkin-solver-py/src/variables.rs @@ -44,9 +44,23 @@ pub struct BoolExpression(pub Literal); #[pymethods] impl BoolExpression { + pub fn scaled(&self, scale: i32) -> IntExpression { + let int_expr: IntExpression = (*self).into(); + int_expr.scaled(scale) + } + + pub fn offset(&self, offset: i32) -> IntExpression { + let int_expr: IntExpression = (*self).into(); + int_expr.offset(offset) + } + pub fn negate(&self) -> BoolExpression { BoolExpression(!self.0) } + + pub fn as_expression(&self) -> IntExpression { + (*self).into() + } } impl From for BoolExpression { diff --git a/pumpkin-solver-py/tests/test_constraints.py b/pumpkin-solver-py/tests/test_constraints.py index ba085d13e..674064084 100644 --- a/pumpkin-solver-py/tests/test_constraints.py +++ b/pumpkin-solver-py/tests/test_constraints.py @@ -67,7 +67,7 @@ def create_linear_model(request): model = pumpkin_solver.Model() if bool: - args = [model.new_boolean_variable(name=f"x[{i}]") for i in range(3)] + args = [model.new_boolean_variable(name=f"x[{i}]").as_expression() for i in range(3)] else: args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled: # do scaling (0, -2, 4,...) @@ -99,7 +99,7 @@ def create_operator_model(request): model = pumpkin_solver.Model() if bool: - args = [model.new_boolean_variable(name=f"x[{i}]") for i in range(3)] + args = [model.new_boolean_variable(name=f"x[{i}]").as_expression() for i in range(3)] else: args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled: # do scaling (0, -2, 4,...) @@ -140,7 +140,7 @@ def create_global_model(request): if name == "alldiff": if bool: - args = [model.new_boolean_variable(name=f"x[{i}]") for i in range(3)] + args = [model.new_boolean_variable(name=f"x[{i}]").as_expression() for i in range(3)] else: args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled or bool: # do scaling (0, -2, 4,...)