diff --git a/pumpkin-crates/constraints/src/constraints/boolean.rs b/pumpkin-crates/constraints/src/constraints/boolean.rs index 32539f39b..ddc6bf689 100644 --- a/pumpkin-crates/constraints/src/constraints/boolean.rs +++ b/pumpkin-crates/constraints/src/constraints/boolean.rs @@ -4,11 +4,12 @@ use pumpkin_core::constraints::Constraint; use pumpkin_core::proof::ConstraintTag; use pumpkin_core::variables::AffineView; use pumpkin_core::variables::DomainId; +use pumpkin_core::variables::IntegerVariableEnum; use pumpkin_core::variables::Literal; use pumpkin_core::variables::TransformableVariable; -use super::equals; use super::less_than_or_equals; +use crate::equals; /// Creates the [`Constraint`] `∑ weights_i * bools_i <= rhs`. pub fn boolean_less_than_or_equals( @@ -67,11 +68,11 @@ impl Constraint for BooleanLessThanOrEqual { } impl BooleanLessThanOrEqual { - fn create_domains(&self) -> Vec> { + fn create_domains(&self) -> Vec> { self.bools .iter() .enumerate() - .map(|(index, bool)| bool.get_integer_variable().scaled(self.weights[index])) + .map(|(index, bool)| bool.scaled(self.weights[index])) .collect() } } @@ -102,12 +103,12 @@ impl Constraint for BooleanEqual { } impl BooleanEqual { - fn create_domains(&self) -> Vec> { + fn create_domains(&self) -> Vec { self.bools .iter() .enumerate() - .map(|(index, bool)| bool.get_integer_variable().scaled(self.weights[index])) - .chain(std::iter::once(self.rhs.scaled(-1))) + .map(|(index, bool)| bool.scaled(self.weights[index]).into()) + .chain(std::iter::once(self.rhs.scaled(-1).into())) .collect() } } diff --git a/pumpkin-crates/core/src/api/mod.rs b/pumpkin-crates/core/src/api/mod.rs index 22f3a90d7..da5dc5c99 100644 --- a/pumpkin-crates/core/src/api/mod.rs +++ b/pumpkin-crates/core/src/api/mod.rs @@ -52,6 +52,7 @@ pub mod variables { pub use crate::engine::variables::AffineView; pub use crate::engine::variables::DomainId; pub use crate::engine::variables::IntegerVariable; + pub use crate::engine::variables::IntegerVariableEnum; pub use crate::engine::variables::Literal; pub use crate::engine::variables::TransformableVariable; } diff --git a/pumpkin-crates/core/src/api/solver.rs b/pumpkin-crates/core/src/api/solver.rs index e5f866dd3..757cbef19 100644 --- a/pumpkin-crates/core/src/api/solver.rs +++ b/pumpkin-crates/core/src/api/solver.rs @@ -102,7 +102,7 @@ pub struct Solver { impl Default for Solver { fn default() -> Self { let satisfaction_solver = ConstraintSatisfactionSolver::default(); - let true_literal = Literal::new(Predicate::trivially_true().get_domain()); + let true_literal = Literal::new(Predicate::trivially_true()); Self { satisfaction_solver, true_literal, @@ -114,7 +114,7 @@ impl Solver { /// Creates a solver with the provided [`SolverOptions`]. pub fn with_options(solver_options: SolverOptions) -> Self { let satisfaction_solver = ConstraintSatisfactionSolver::new(solver_options); - let true_literal = Literal::new(Predicate::trivially_true().get_domain()); + let true_literal = Literal::new(Predicate::trivially_true()); Self { satisfaction_solver, true_literal, diff --git a/pumpkin-crates/core/src/engine/cp/assignments.rs b/pumpkin-crates/core/src/engine/cp/assignments.rs index d0dad662e..904d207f5 100644 --- a/pumpkin-crates/core/src/engine/cp/assignments.rs +++ b/pumpkin-crates/core/src/engine/cp/assignments.rs @@ -288,6 +288,15 @@ impl Assignments { self.is_domain_assigned(var).then(|| var.lower_bound(self)) } + pub(crate) fn get_assigned_value_at_trail_position( + &self, + var: &Var, + trail_position: usize, + ) -> Option { + self.is_domain_assigned_at_trail_position(var, trail_position) + .then(|| var.lower_bound(self)) + } + pub(crate) fn is_decision_predicate(&self, predicate: &Predicate) -> bool { let domain = predicate.get_domain(); if let Some(trail_position) = self.get_trail_position(predicate) @@ -368,6 +377,15 @@ impl Assignments { var.lower_bound(self) == var.upper_bound(self) } + pub(crate) fn is_domain_assigned_at_trail_position( + &self, + var: &Var, + trail_position: usize, + ) -> bool { + var.lower_bound_at_trail_position(self, trail_position) + == var.upper_bound_at_trail_position(self, trail_position) + } + /// Returns the index of the trail entry at which point the given predicate became true. /// In case the predicate is not true, then the function returns None. /// Note that it is not necessary for the predicate to be explicitly present on the trail, @@ -617,6 +635,66 @@ impl Assignments { Ok(update_took_place) } + /// Determines whether the provided [`Predicate`] holds at the provided trail position. In case + /// the predicate is not assigned yet (neither true nor false), returns None. + pub(crate) fn evaluate_predicate_at_trail_position( + &self, + predicate: Predicate, + trail_position: usize, + ) -> Option { + let domain_id = predicate.get_domain(); + let value = predicate.get_right_hand_side(); + + match predicate.get_predicate_type() { + PredicateType::LowerBound => { + if self.get_lower_bound_at_trail_position(domain_id, trail_position) >= value { + Some(true) + } else if self.get_upper_bound_at_trail_position(domain_id, trail_position) < value + { + Some(false) + } else { + None + } + } + PredicateType::UpperBound => { + if self.get_upper_bound_at_trail_position(domain_id, trail_position) <= value { + Some(true) + } else if self.get_lower_bound_at_trail_position(domain_id, trail_position) > value + { + Some(false) + } else { + None + } + } + PredicateType::NotEqual => { + if !self.is_value_in_domain_at_trail_position(domain_id, value, trail_position) { + Some(true) + } else if let Some(assigned_value) = + self.get_assigned_value_at_trail_position(&domain_id, trail_position) + { + // Previous branch concluded the value is not in the domain, so if the variable + // is assigned, then it is assigned to the not equals value. + pumpkin_assert_simple!(assigned_value == value); + Some(false) + } else { + None + } + } + PredicateType::Equal => { + if !self.is_value_in_domain_at_trail_position(domain_id, value, trail_position) { + Some(false) + } else if let Some(assigned_value) = + self.get_assigned_value_at_trail_position(&domain_id, trail_position) + { + pumpkin_assert_moderate!(assigned_value == value); + Some(true) + } else { + None + } + } + } + } + /// Determines whether the provided [`Predicate`] holds in the current state of the /// [`Assignments`]. In case the predicate is not assigned yet (neither true nor false), /// returns None. @@ -673,6 +751,15 @@ impl Assignments { .is_some_and(|truth_value| truth_value) } + pub(crate) fn is_predicate_satisfied_at_trail_position( + &self, + predicate: Predicate, + trail_position: usize, + ) -> bool { + self.evaluate_predicate_at_trail_position(predicate, trail_position) + .is_some_and(|truth_value| truth_value) + } + #[allow(unused, reason = "makes sense to have in this API")] pub(crate) fn is_predicate_falsified(&self, predicate: Predicate) -> bool { self.evaluate_predicate(predicate) diff --git a/pumpkin-crates/core/src/engine/cp/test_solver.rs b/pumpkin-crates/core/src/engine/cp/test_solver.rs index 4bbcdf795..368c0208b 100644 --- a/pumpkin-crates/core/src/engine/cp/test_solver.rs +++ b/pumpkin-crates/core/src/engine/cp/test_solver.rs @@ -86,7 +86,7 @@ impl TestSolver { pub fn new_literal(&mut self) -> Literal { let domain_id = self.new_variable(0, 1); - Literal::new(domain_id) + Literal::new(predicate!(domain_id >= 1)) } pub fn new_propagator( diff --git a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs index 9d3220c2c..b3a95e227 100644 --- a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs +++ b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/domain_event_watch_list.rs @@ -4,9 +4,13 @@ use enumset::EnumSet; use enumset::EnumSetType; use crate::containers::KeyedVec; +use crate::engine::Assignments; +use crate::engine::TrailedValues; use crate::engine::notifications::NotificationEngine; use crate::engine::variables::DomainId; +use crate::predicates::Predicate; use crate::propagation::PropagatorVarId; +use crate::variables::Literal; #[derive(Default, Debug, Clone)] pub(crate) struct WatchListDomainEvents { @@ -22,6 +26,40 @@ pub(crate) struct WatchListDomainEvents { pub struct Watchers<'a> { propagator_var: PropagatorVarId, notification_engine: &'a mut NotificationEngine, + trailed_values: &'a mut TrailedValues, + assignments: &'a Assignments, +} + +impl<'a> Watchers<'a> { + pub(crate) fn watch_literal(&mut self, literal: Literal, events: EnumSet) { + self.notification_engine.watch_literal( + literal, + events, + self.propagator_var, + self.trailed_values, + self.assignments, + ) + } + + pub(crate) fn watch_literal_backtrack( + &mut self, + literal: Literal, + events: EnumSet, + ) { + self.notification_engine.watch_literal_backtrack( + literal, + events, + self.propagator_var, + self.trailed_values, + self.assignments, + ) + } + + pub(crate) fn unwatch_predicate(&mut self, predicate: Predicate) { + let predicate_id = self.notification_engine.get_id(predicate); + self.notification_engine + .unwatch_predicate(predicate_id, self.propagator_var.propagator); + } } /// A description of the kinds of events that can happen on a domain variable. @@ -95,10 +133,14 @@ impl<'a> Watchers<'a> { pub(crate) fn new( propagator_var: PropagatorVarId, notification_engine: &'a mut NotificationEngine, + trailed_values: &'a mut TrailedValues, + assignments: &'a Assignments, ) -> Self { Watchers { propagator_var, notification_engine, + trailed_values, + assignments, } } diff --git a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/mod.rs b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/mod.rs index ddf607046..baac4a23d 100644 --- a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/mod.rs +++ b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/mod.rs @@ -6,3 +6,5 @@ pub use domain_event_watch_list::DomainEvent; pub(crate) use domain_event_watch_list::WatchListDomainEvents; pub(crate) use domain_event_watch_list::Watchers; pub(crate) use event_sink::*; +pub(crate) use watch_list_predicate_id::PredicateWatchList; +mod watch_list_predicate_id; diff --git a/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs new file mode 100644 index 000000000..539f5b20a --- /dev/null +++ b/pumpkin-crates/core/src/engine/notifications/domain_event_notification/watch_list_predicate_id.rs @@ -0,0 +1,92 @@ +use enumset::EnumSet; + +use crate::basic_types::PredicateId; +use crate::containers::HashMap; +use crate::containers::KeyedVec; +use crate::propagation::DomainEvent; +use crate::propagation::LocalId; +use crate::propagation::PropagatorVarId; +use crate::state::PropagatorId; +use crate::variables::Literal; + +#[derive(Debug, Default, Clone)] +pub(crate) struct PredicateWatchList { + /// The watch list from predicates to propagators. + pub(crate) watch_list_predicate_id: KeyedVec>, + // TODO: Should use direct hashing + pub(crate) literal_watch_list: + HashMap)>>, + pub(crate) literal_watch_list_backtrack: + HashMap)>>, +} + +impl PredicateWatchList { + pub(crate) fn watch_predicate_id( + &mut self, + predicate_id: PredicateId, + propagator_id: PropagatorId, + ) { + self.watch_list_predicate_id + .accomodate(predicate_id, vec![]); + self.watch_list_predicate_id[predicate_id].push(propagator_id); + } + + pub(crate) fn watchers_predicate_id_mut( + &mut self, + predicate_id: PredicateId, + ) -> Option<&mut Vec> { + self.watch_list_predicate_id.get_mut(predicate_id) + } + + pub(crate) fn watchers_predicate_id( + &self, + predicate_id: PredicateId, + ) -> Option<&Vec> { + self.watch_list_predicate_id.get(predicate_id) + } + + pub(crate) fn insert_literal_watcher( + &mut self, + literal: Literal, + propagator_var: PropagatorVarId, + events: EnumSet, + ) { + let entry = self + .literal_watch_list + .entry(literal) + .or_default() + .entry(propagator_var.propagator) + .or_insert((propagator_var.variable, events)); + entry.1 |= events; + } + + pub(crate) fn insert_literal_watcher_backtrack( + &mut self, + literal: Literal, + propagator_var: PropagatorVarId, + events: EnumSet, + ) { + let entry = self + .literal_watch_list_backtrack + .entry(literal) + .or_default() + .entry(propagator_var.propagator) + .or_insert((propagator_var.variable, events)); + entry.1 |= events; + } + + pub(crate) fn watchers_literal_propagator( + &self, + literal: Literal, + propagator_id: PropagatorId, + ) -> Option<&(LocalId, EnumSet)> { + self.literal_watch_list + .get(&literal) + .and_then(|inner| inner.get(&propagator_id)) + .or_else(|| { + self.literal_watch_list + .get(&(!literal)) + .and_then(|inner| inner.get(&propagator_id)) + }) + } +} diff --git a/pumpkin-crates/core/src/engine/notifications/mod.rs b/pumpkin-crates/core/src/engine/notifications/mod.rs index ba6a5e64e..f7c523e62 100644 --- a/pumpkin-crates/core/src/engine/notifications/mod.rs +++ b/pumpkin-crates/core/src/engine/notifications/mod.rs @@ -11,10 +11,11 @@ use enumset::EnumSet; pub(crate) use predicate_notification::PredicateNotifier; use crate::basic_types::PredicateId; -use crate::containers::KeyedVec; +use crate::basic_types::Trail; use crate::engine::Assignments; use crate::engine::PropagatorQueue; use crate::engine::TrailedValues; +use crate::engine::notifications::domain_event_notification::PredicateWatchList; use crate::predicates::Predicate; use crate::propagation::Domains; use crate::propagation::EnqueueDecision; @@ -26,6 +27,7 @@ use crate::propagation::store::PropagatorStore; use crate::pumpkin_assert_extreme; use crate::pumpkin_assert_simple; use crate::variables::DomainId; +use crate::variables::Literal; #[derive(Debug, Clone)] pub(crate) struct NotificationEngine { @@ -36,24 +38,25 @@ pub(crate) struct NotificationEngine { /// Contains information on which propagator to notify upon /// integer events, e.g., lower or upper bound change of a variable. watch_list_domain_events: WatchListDomainEvents, - /// The watch list from predicates to propagators. - pub(crate) watch_list_predicate_id: KeyedVec>, + watch_list_predicate_ids: PredicateWatchList, /// Events which have occurred since the last round of notifications have taken place events: EventSink, /// Backtrack events which have occurred since the last of backtrack notifications have taken /// place backtrack_events: EventSink, + backtrack_events_literals: Trail<(Literal, PropagatorId)>, } impl Default for NotificationEngine { fn default() -> Self { let mut result = Self { watch_list_domain_events: Default::default(), - watch_list_predicate_id: Default::default(), + watch_list_predicate_ids: Default::default(), predicate_notifier: Default::default(), last_notified_trail_index: 0, events: Default::default(), backtrack_events: Default::default(), + backtrack_events_literals: Default::default(), }; // Grow for the dummy predicate result.grow(); @@ -72,11 +75,12 @@ impl NotificationEngine { let mut result = Self { watch_list_domain_events, - watch_list_predicate_id: Default::default(), + watch_list_predicate_ids: Default::default(), predicate_notifier: Default::default(), last_notified_trail_index: usize::MAX, events: Default::default(), backtrack_events: Default::default(), + backtrack_events_literals: Default::default(), }; // Grow for the dummy predicate result.grow(); @@ -169,9 +173,8 @@ impl NotificationEngine { trailed_values: &mut TrailedValues, assignments: &Assignments, ) { - self.watch_list_predicate_id - .accomodate(predicate_id, vec![]); - self.watch_list_predicate_id[predicate_id].push(propagator_id); + self.watch_list_predicate_ids + .watch_predicate_id(predicate_id, propagator_id); self.predicate_notifier .track_predicate(predicate_id, trailed_values, assignments); @@ -182,18 +185,143 @@ impl NotificationEngine { predicate_id: PredicateId, propagator_to_unwatch: PropagatorId, ) { - let watch_list = &mut self.watch_list_predicate_id[predicate_id]; - - let index = watch_list - .iter() - .position(|&watched_propagator| watched_propagator == propagator_to_unwatch) - .expect("cannot unwatch a (predicate, propagator) pair if it was not watched"); + if let Some(watch_list) = self + .watch_list_predicate_ids + .watchers_predicate_id_mut(predicate_id) + { + let index = watch_list + .iter() + .position(|&watched_propagator| watched_propagator == propagator_to_unwatch) + .expect("cannot unwatch a (predicate, propagator) pair if it was not watched"); - let _ = watch_list.swap_remove(index); + let _ = watch_list.swap_remove(index); + } // TODO: Can we remove the predicate from being tracked if it does not have watchers? } + pub(crate) fn watch_literal( + &mut self, + literal: Literal, + events: EnumSet, + propagator_var: PropagatorVarId, + trailed_values: &mut TrailedValues, + assignments: &Assignments, + ) { + self.watch_list_predicate_ids + .insert_literal_watcher(literal, propagator_var, events); + + for event in events { + match event { + DomainEvent::Assign => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::LowerBound => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::UpperBound => { + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::Removal => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + }; + } + } + + pub(crate) fn watch_literal_backtrack( + &mut self, + literal: Literal, + events: EnumSet, + propagator_var: PropagatorVarId, + trailed_values: &mut TrailedValues, + assignments: &Assignments, + ) { + self.watch_list_predicate_ids + .insert_literal_watcher_backtrack(literal, propagator_var, events); + + for event in events { + match event { + DomainEvent::Assign => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::LowerBound => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::UpperBound => { + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + DomainEvent::Removal => { + let _ = self.watch_predicate( + literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + let _ = self.watch_predicate( + !literal.inner, + propagator_var.propagator, + trailed_values, + assignments, + ); + } + }; + } + } + pub(crate) fn watch_all_backtrack( &mut self, domain: DomainId, @@ -351,6 +479,7 @@ impl NotificationEngine { pub(crate) fn process_backtrack_events( &mut self, + new_checkpoint: usize, assignments: &mut Assignments, trailed_values: &mut TrailedValues, propagators: &mut PropagatorStore, @@ -377,6 +506,24 @@ impl NotificationEngine { } } } + + for (literal, propagator_id) in self + .backtrack_events_literals + .synchronise(new_checkpoint) + .collect::>() + { + if let Some((var_id, events)) = self + .watch_list_predicate_ids + .watchers_literal_propagator(literal, propagator_id) + { + let propagator = &mut propagators[propagator_id]; + for event in events.iter() { + let mut context = NotificationContext::new(trailed_values, assignments); + + propagator.notify_backtrack(context.domains(), *var_id, event.into()) + } + } + } true } @@ -388,19 +535,55 @@ impl NotificationEngine { trailed_values: &mut TrailedValues, assignments: &Assignments, ) { - for predicate_id in self.predicate_notifier.drain_satisfied_predicates() { - if let Some(watch_list) = self.watch_list_predicate_id.get(predicate_id) { - let propagators_to_notify = watch_list.iter().copied(); - - for propagator_id in propagators_to_notify { - let mut context = NotificationContext::new(trailed_values, assignments); - - let propagator = &mut propagators[propagator_id]; - let enqueue_decision = - propagator.notify_predicate_id_satisfied(context.reborrow(), predicate_id); - - if enqueue_decision == EnqueueDecision::Enqueue { - propagator_queue.enqueue_propagator(propagator_id, propagator.priority()); + for predicate_id in self + .predicate_notifier + .drain_satisfied_predicates() + .collect::>() + { + if let Some(watchers) = self + .watch_list_predicate_ids + .watchers_predicate_id(predicate_id) + { + for propagator_id in watchers { + let predicate = self.predicate_notifier.get_predicate(predicate_id); + let literal = Literal::new(predicate); + if let Some((var_id, events)) = self + .watch_list_predicate_ids + .watchers_literal_propagator(literal, *propagator_id) + { + if events.is_empty() + || self + .backtrack_events_literals + .contains(&(literal, *propagator_id)) + { + continue; + } + self.backtrack_events_literals + .push((literal, *propagator_id)); + + let propagator = &mut propagators[*propagator_id]; + for event in events.iter() { + let mut context = NotificationContext::new(trailed_values, assignments); + + let enqueue_decision = + propagator.notify(context.reborrow(), *var_id, event.into()); + + if enqueue_decision == EnqueueDecision::Enqueue { + propagator_queue + .enqueue_propagator(*propagator_id, propagator.priority()); + } + } + } else { + let mut context = NotificationContext::new(trailed_values, assignments); + + let propagator = &mut propagators[*propagator_id]; + let enqueue_decision = propagator + .notify_predicate_id_satisfied(context.reborrow(), predicate_id); + + if enqueue_decision == EnqueueDecision::Enqueue { + propagator_queue + .enqueue_propagator(*propagator_id, propagator.priority()); + } } } } @@ -504,6 +687,7 @@ impl NotificationEngine { self.predicate_notifier .predicate_id_assignments .new_checkpoint(); + self.backtrack_events_literals.new_checkpoint(); } pub(crate) fn debug_create_from_assignments(&mut self, assignments: &Assignments) { diff --git a/pumpkin-crates/core/src/engine/state.rs b/pumpkin-crates/core/src/engine/state.rs index 998741fb4..3824d1913 100644 --- a/pumpkin-crates/core/src/engine/state.rs +++ b/pumpkin-crates/core/src/engine/state.rs @@ -166,7 +166,7 @@ impl State { /// when backtracking past the checkpoint where the domain was created. pub fn new_literal(&mut self, name: Option>) -> Literal { let domain_id = self.new_interval_variable(0, 1, name); - Literal::new(domain_id) + Literal::new(predicate!(domain_id >= 1)) } /// Creates a new interval variable with the given lower and upper bound. @@ -572,6 +572,7 @@ impl State { } let _ = self.notification_engine.process_backtrack_events( + checkpoint, &mut self.assignments, &mut self.trailed_values, &mut self.propagators, diff --git a/pumpkin-crates/core/src/engine/variables/integer_variable.rs b/pumpkin-crates/core/src/engine/variables/integer_variable.rs index 09badb92e..ab7d19fa9 100644 --- a/pumpkin-crates/core/src/engine/variables/integer_variable.rs +++ b/pumpkin-crates/core/src/engine/variables/integer_variable.rs @@ -10,6 +10,9 @@ use crate::engine::notifications::OpaqueDomainEvent; use crate::engine::notifications::Watchers; use crate::engine::predicates::predicate_constructor::PredicateConstructor; use crate::predicates::Predicate; +use crate::variables::AffineView; +use crate::variables::DomainId; +use crate::variables::Literal; /// A trait specifying the required behaviour of an integer variable such as retrieving a /// lower-bound ([`IntegerVariable::lower_bound`]). @@ -70,3 +73,350 @@ pub trait IntegerVariable: /// Returns all of the holes in the domain fn get_holes(&self, assignments: &Assignments) -> impl Iterator; } + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum IntegerVariableEnum { + DomainId(AffineView), + Literal(AffineView), +} + +impl From> for IntegerVariableEnum { + fn from(value: AffineView) -> Self { + IntegerVariableEnum::DomainId(value) + } +} + +impl From for IntegerVariableEnum { + fn from(value: DomainId) -> Self { + IntegerVariableEnum::DomainId(value.scaled(1)) + } +} + +impl From> for IntegerVariableEnum { + fn from(value: AffineView) -> Self { + IntegerVariableEnum::Literal(value) + } +} + +impl From for IntegerVariableEnum { + fn from(value: Literal) -> Self { + IntegerVariableEnum::Literal(value.scaled(1)) + } +} + +impl PredicateConstructor for IntegerVariableEnum { + type Value = i32; + + fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.lower_bound_predicate(bound), + IntegerVariableEnum::Literal(literal) => literal.lower_bound_predicate(bound), + } + } + + fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.upper_bound_predicate(bound), + IntegerVariableEnum::Literal(literal) => literal.upper_bound_predicate(bound), + } + } + + fn equality_predicate(&self, bound: Self::Value) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.equality_predicate(bound), + IntegerVariableEnum::Literal(literal) => literal.equality_predicate(bound), + } + } + + fn disequality_predicate(&self, bound: Self::Value) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.disequality_predicate(bound), + IntegerVariableEnum::Literal(literal) => literal.disequality_predicate(bound), + } + } +} + +impl TransformableVariable for IntegerVariableEnum { + fn scaled(&self, scale: i32) -> IntegerVariableEnum { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.scaled(scale).into(), + IntegerVariableEnum::Literal(literal) => literal.scaled(scale).into(), + } + } + + fn offset(&self, offset: i32) -> IntegerVariableEnum { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.offset(offset).into(), + IntegerVariableEnum::Literal(literal) => literal.offset(offset).into(), + } + } +} + +impl IntegerVariable for IntegerVariableEnum { + type AffineView = IntegerVariableEnum; + + fn lower_bound(&self, assignment: &Assignments) -> i32 { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.lower_bound(assignment), + IntegerVariableEnum::Literal(literal) => literal.lower_bound(assignment), + } + } + + fn lower_bound_at_trail_position( + &self, + assignment: &Assignments, + trail_position: usize, + ) -> i32 { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.lower_bound_at_trail_position(assignment, trail_position) + } + IntegerVariableEnum::Literal(literal) => { + literal.lower_bound_at_trail_position(assignment, trail_position) + } + } + } + + fn upper_bound(&self, assignment: &Assignments) -> i32 { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.upper_bound(assignment), + IntegerVariableEnum::Literal(literal) => literal.upper_bound(assignment), + } + } + + fn upper_bound_at_trail_position( + &self, + assignment: &Assignments, + trail_position: usize, + ) -> i32 { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.upper_bound_at_trail_position(assignment, trail_position) + } + IntegerVariableEnum::Literal(literal) => { + literal.upper_bound_at_trail_position(assignment, trail_position) + } + } + } + + fn contains(&self, assignment: &Assignments, value: i32) -> bool { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.contains(assignment, value), + IntegerVariableEnum::Literal(literal) => literal.contains(assignment, value), + } + } + + fn contains_at_trail_position( + &self, + assignment: &Assignments, + value: i32, + trail_position: usize, + ) -> bool { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.contains_at_trail_position(assignment, value, trail_position) + } + IntegerVariableEnum::Literal(literal) => { + literal.contains_at_trail_position(assignment, value, trail_position) + } + } + } + + fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .iterate_domain(assignment) + .collect::>() + .into_iter(), + IntegerVariableEnum::Literal(literal) => literal + .iterate_domain(assignment) + .collect::>() + .into_iter(), + } + } + + fn watch_all(&self, watchers: &mut Watchers<'_>, events: EnumSet) { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.watch_all(watchers, events), + IntegerVariableEnum::Literal(literal) => literal.watch_all(watchers, events), + } + } + + fn unwatch_all(&self, watchers: &mut Watchers<'_>) { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.unwatch_all(watchers), + IntegerVariableEnum::Literal(literal) => literal.unwatch_all(watchers), + } + } + + fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, events: EnumSet) { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.watch_all_backtrack(watchers, events) + } + IntegerVariableEnum::Literal(literal) => literal.watch_all_backtrack(watchers, events), + } + } + + fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.unpack_event(event), + IntegerVariableEnum::Literal(literal) => literal.unpack_event(event), + } + } + + fn get_holes_at_current_checkpoint( + &self, + assignments: &Assignments, + ) -> impl Iterator { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .get_holes_at_current_checkpoint(assignments) + .collect::>() + .into_iter(), + IntegerVariableEnum::Literal(literal) => literal + .get_holes_at_current_checkpoint(assignments) + .collect::>() + .into_iter(), + } + } + + fn get_holes(&self, assignments: &Assignments) -> impl Iterator { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .get_holes(assignments) + .collect::>() + .into_iter(), + IntegerVariableEnum::Literal(literal) => literal + .get_holes(assignments) + .collect::>() + .into_iter(), + } + } +} + +impl CheckerVariable for IntegerVariableEnum { + fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.does_atomic_constrain_self(atomic) + } + IntegerVariableEnum::Literal(literal) => literal.does_atomic_constrain_self(atomic), + } + } + + fn atomic_less_than(&self, value: i32) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.atomic_less_than(value), + IntegerVariableEnum::Literal(literal) => literal.atomic_less_than(value), + } + } + + fn atomic_greater_than(&self, value: i32) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.atomic_greater_than(value), + IntegerVariableEnum::Literal(literal) => literal.atomic_greater_than(value), + } + } + + fn atomic_equal(&self, value: i32) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.atomic_equal(value), + IntegerVariableEnum::Literal(literal) => literal.atomic_equal(value), + } + } + + fn atomic_not_equal(&self, value: i32) -> Predicate { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id.atomic_not_equal(value), + IntegerVariableEnum::Literal(literal) => literal.atomic_not_equal(value), + } + } + + fn induced_lower_bound( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> pumpkin_checking::IntExt { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.induced_lower_bound(variable_state) + } + IntegerVariableEnum::Literal(literal) => literal.induced_lower_bound(variable_state), + } + } + + fn induced_upper_bound( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> pumpkin_checking::IntExt { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.induced_upper_bound(variable_state) + } + IntegerVariableEnum::Literal(literal) => literal.induced_upper_bound(variable_state), + } + } + + fn induced_fixed_value( + &self, + variable_state: &pumpkin_checking::VariableState, + ) -> Option { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.induced_fixed_value(variable_state) + } + IntegerVariableEnum::Literal(literal) => literal.induced_fixed_value(variable_state), + } + } + + fn induced_domain_contains( + &self, + variable_state: &pumpkin_checking::VariableState, + value: i32, + ) -> bool { + match self { + IntegerVariableEnum::DomainId(domain_id) => { + domain_id.induced_domain_contains(variable_state, value) + } + IntegerVariableEnum::Literal(literal) => { + literal.induced_domain_contains(variable_state, value) + } + } + } + + fn induced_holes<'this, 'state>( + &'this self, + variable_state: &'state pumpkin_checking::VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state, + { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .induced_holes(variable_state) + .collect::>() + .into_iter(), + IntegerVariableEnum::Literal(literal) => literal + .induced_holes(variable_state) + .collect::>() + .into_iter(), + } + } + + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state pumpkin_checking::VariableState, + ) -> Option + 'state> + where + 'this: 'state, + { + match self { + IntegerVariableEnum::DomainId(domain_id) => domain_id + .iter_induced_domain(variable_state) + .map(|iter| iter.collect::>().into_iter()), + IntegerVariableEnum::Literal(literal) => literal + .iter_induced_domain(variable_state) + .map(|iter| iter.collect::>().into_iter()), + } + } +} diff --git a/pumpkin-crates/core/src/engine/variables/literal.rs b/pumpkin-crates/core/src/engine/variables/literal.rs index 980ffa737..290a07999 100644 --- a/pumpkin-crates/core/src/engine/variables/literal.rs +++ b/pumpkin-crates/core/src/engine/variables/literal.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use std::ops::Not; use enumset::EnumSet; @@ -5,7 +6,6 @@ use pumpkin_checking::CheckerVariable; use pumpkin_checking::IntExt; use pumpkin_checking::VariableState; -use super::DomainId; use super::IntegerVariable; use super::TransformableVariable; use crate::engine::Assignments; @@ -16,31 +16,31 @@ use crate::engine::predicates::predicate::Predicate; use crate::engine::predicates::predicate_constructor::PredicateConstructor; use crate::engine::variables::AffineView; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct Literal { - integer_variable: AffineView, + pub(crate) inner: Predicate, +} + +impl Debug for Literal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.inner) + } } impl Literal { - /// Creates a new literal wrapping the provided [`DomainId`]. + /// Creates a new literal wrapping the provided [`Predicate`]. /// /// Note: the provided `domain_id` should have a domain between 0 and 1. - pub fn new(domain_id: DomainId) -> Literal { - Literal { - integer_variable: domain_id.scaled(1), - } - } - - pub fn get_integer_variable(&self) -> AffineView { - self.integer_variable + pub fn new(predicate: Predicate) -> Literal { + Literal { inner: predicate } } pub fn get_true_predicate(&self) -> Predicate { - self.lower_bound_predicate(1) + self.inner } pub fn get_false_predicate(&self) -> Predicate { - self.upper_bound_predicate(0) + !self.inner } } @@ -48,60 +48,120 @@ impl Not for Literal { type Output = Literal; fn not(self) -> Self::Output { - Literal { - integer_variable: self.integer_variable.scaled(-1).offset(1), - } + Literal { inner: !self.inner } } } -/// Forwards a function implementation to the field on self. -macro_rules! forward { - ( - $field:ident, - fn $(<$($lt:lifetime),+>)? $name:ident( - & $($lt_self:lifetime)? self, - $($param_name:ident : $param_type:ty),* - ) -> $return_type:ty - $(where $($where_clause:tt)*)? - ) => { - fn $name$(<$($lt),+>)?( - & $($lt_self)? self, - $($param_name: $param_type),* - ) -> $return_type $(where $($where_clause)*)? { - self.$field.$name($($param_name),*) +impl CheckerVariable for Literal { + fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool { + atomic.get_domain() == self.inner.get_domain() + } + + fn atomic_less_than(&self, value: i32) -> Predicate { + if value == 0 { + !self.inner + } else if value >= 1 { + Predicate::trivially_true() + } else { + Predicate::trivially_false() } } -} -impl CheckerVariable for Literal { - forward!(integer_variable, fn does_atomic_constrain_self(&self, atomic: &Predicate) -> bool); - forward!(integer_variable, fn atomic_less_than(&self, value: i32) -> Predicate); - forward!(integer_variable, fn atomic_greater_than(&self, value: i32) -> Predicate); - forward!(integer_variable, fn atomic_not_equal(&self, value: i32) -> Predicate); - forward!(integer_variable, fn atomic_equal(&self, value: i32) -> Predicate); - - forward!(integer_variable, fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt); - forward!(integer_variable, fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt); - forward!(integer_variable, fn induced_fixed_value(&self, variable_state: &VariableState) -> Option); - forward!(integer_variable, fn induced_domain_contains(&self, variable_state: &VariableState, value: i32) -> bool); - forward!( - integer_variable, - fn <'this, 'state> induced_holes( - &'this self, - variable_state: &'state VariableState - ) -> impl Iterator + 'state - where - 'this: 'state, - ); - forward!( - integer_variable, - fn <'this, 'state> iter_induced_domain( - &'this self, - variable_state: &'state VariableState - ) -> Option + 'state> - where - 'this: 'state, - ); + fn atomic_greater_than(&self, value: i32) -> Predicate { + if value == 1 { + !self.inner + } else if value <= 0 { + Predicate::trivially_true() + } else { + Predicate::trivially_false() + } + } + + fn atomic_equal(&self, value: i32) -> Predicate { + if value == 1 { + self.inner + } else if value == 0 { + !self.inner + } else { + Predicate::trivially_false() + } + } + + fn atomic_not_equal(&self, value: i32) -> Predicate { + if value == 1 { + !self.inner + } else if value == 0 { + self.inner + } else { + Predicate::trivially_true() + } + } + + fn induced_lower_bound(&self, variable_state: &VariableState) -> IntExt { + IntExt::Int(if variable_state.is_true(&self.inner) { + 1 + } else { + Default::default() + }) + } + + fn induced_upper_bound(&self, variable_state: &VariableState) -> IntExt { + IntExt::Int(if variable_state.is_true(&!self.inner) { + 0 + } else { + 1 + }) + } + + fn induced_fixed_value(&self, variable_state: &VariableState) -> Option { + if variable_state.is_true(&self.inner) { + Some(1) + } else if variable_state.is_true(&!self.inner) { + Some(0) + } else { + None + } + } + + fn induced_domain_contains( + &self, + variable_state: &VariableState, + value: i32, + ) -> bool { + if value == 1 && !variable_state.is_true(&!self.inner) { + true + } else { + value == 0 && !variable_state.is_true(&self.inner) + } + } + + fn induced_holes<'this, 'state>( + &'this self, + _variable_state: &'state VariableState, + ) -> impl Iterator + 'state + where + 'this: 'state, + { + std::iter::empty() + } + + fn iter_induced_domain<'this, 'state>( + &'this self, + variable_state: &'state VariableState, + ) -> Option + 'state> + where + 'this: 'state, + { + Some((0..=1).filter(|&value| { + if value == 0 { + !variable_state.is_true(&self.inner) + } else if value == 1 { + !variable_state.is_true(&!self.inner) + } else { + unreachable!() + } + })) + } } impl IntegerVariable for Literal { @@ -112,7 +172,11 @@ impl IntegerVariable for Literal { /// Literal that evaluate to false have a lower bound of 0. /// Unassigned literals have a lower bound of 0. fn lower_bound(&self, assignment: &Assignments) -> i32 { - self.integer_variable.lower_bound(assignment) + if assignment.is_predicate_satisfied(self.inner) { + 1 + } else { + 0 + } } fn lower_bound_at_trail_position( @@ -120,8 +184,11 @@ impl IntegerVariable for Literal { assignment: &Assignments, trail_position: usize, ) -> i32 { - self.integer_variable - .lower_bound_at_trail_position(assignment, trail_position) + if assignment.is_predicate_satisfied_at_trail_position(self.inner, trail_position) { + 1 + } else { + 0 + } } /// Returns the upper bound represented as a 0-1 value. @@ -129,7 +196,11 @@ impl IntegerVariable for Literal { /// Literal that evaluate to false have a upper bound of 0. /// Unassigned literals have a upper bound of 1. fn upper_bound(&self, assignment: &Assignments) -> i32 { - self.integer_variable.upper_bound(assignment) + if assignment.is_predicate_satisfied(!self.inner) { + 0 + } else { + 1 + } } fn upper_bound_at_trail_position( @@ -137,8 +208,11 @@ impl IntegerVariable for Literal { assignment: &Assignments, trail_position: usize, ) -> i32 { - self.integer_variable - .upper_bound_at_trail_position(assignment, trail_position) + if assignment.is_predicate_satisfied_at_trail_position(!self.inner, trail_position) { + 0 + } else { + 1 + } } /// Returns whether the input value, when interpreted as a bool, @@ -147,7 +221,13 @@ impl IntegerVariable for Literal { /// Literals that evaluate to false only contain value 0. /// Unassigned literals contain both values 0 and 1. fn contains(&self, assignment: &Assignments, value: i32) -> bool { - self.integer_variable.contains(assignment, value) + if value == 0 { + !assignment.is_predicate_satisfied(self.inner) + } else if value == 1 { + !assignment.is_predicate_satisfied(!self.inner) + } else { + false + } } fn contains_at_trail_position( @@ -156,40 +236,52 @@ impl IntegerVariable for Literal { value: i32, trail_position: usize, ) -> bool { - self.integer_variable - .contains_at_trail_position(assignment, value, trail_position) + if value == 0 { + !assignment.is_predicate_satisfied_at_trail_position(self.inner, trail_position) + } else if value == 1 { + !assignment.is_predicate_satisfied_at_trail_position(!self.inner, trail_position) + } else { + false + } } fn iterate_domain(&self, assignment: &Assignments) -> impl Iterator { - self.integer_variable.iterate_domain(assignment) + (0..=1).filter(|&value| { + if value == 0 { + !assignment.is_predicate_satisfied(self.inner) + } else if value == 1 { + !assignment.is_predicate_satisfied(!self.inner) + } else { + unreachable!() + } + }) } fn watch_all(&self, watchers: &mut Watchers<'_>, events: EnumSet) { - self.integer_variable.watch_all(watchers, events) + watchers.watch_literal(*self, events) } fn unwatch_all(&self, watchers: &mut Watchers<'_>) { - self.integer_variable.unwatch_all(watchers) + watchers.unwatch_predicate(self.inner); } fn unpack_event(&self, event: OpaqueDomainEvent) -> DomainEvent { - self.integer_variable.unpack_event(event) + event.unwrap() } fn watch_all_backtrack(&self, watchers: &mut Watchers<'_>, events: EnumSet) { - self.integer_variable.watch_all_backtrack(watchers, events) + watchers.watch_literal_backtrack(*self, events) } fn get_holes_at_current_checkpoint( &self, - assignments: &Assignments, + _assignments: &Assignments, ) -> impl Iterator { - self.integer_variable - .get_holes_at_current_checkpoint(assignments) + std::iter::empty() } - fn get_holes(&self, assignments: &Assignments) -> impl Iterator { - self.integer_variable.get_holes(assignments) + fn get_holes(&self, _assignments: &Assignments) -> impl Iterator { + std::iter::empty() } } @@ -197,19 +289,43 @@ impl PredicateConstructor for Literal { type Value = i32; fn lower_bound_predicate(&self, bound: Self::Value) -> Predicate { - self.integer_variable.lower_bound_predicate(bound) + if bound == 1 { + self.inner + } else if bound < 1 { + Predicate::trivially_true() + } else { + Predicate::trivially_false() + } } fn upper_bound_predicate(&self, bound: Self::Value) -> Predicate { - self.integer_variable.upper_bound_predicate(bound) + if bound == 0 { + !self.inner + } else if bound > 0 { + Predicate::trivially_true() + } else { + Predicate::trivially_false() + } } fn equality_predicate(&self, bound: Self::Value) -> Predicate { - self.integer_variable.equality_predicate(bound) + if bound == 0 { + !self.inner + } else if bound == 1 { + self.inner + } else { + Predicate::trivially_true() + } } fn disequality_predicate(&self, bound: Self::Value) -> Predicate { - self.integer_variable.disequality_predicate(bound) + if bound == 0 { + self.inner + } else if bound == 1 { + !self.inner + } else { + Predicate::trivially_true() + } } } diff --git a/pumpkin-crates/core/src/engine/variables/mod.rs b/pumpkin-crates/core/src/engine/variables/mod.rs index 42837e402..2c0d10bb6 100644 --- a/pumpkin-crates/core/src/engine/variables/mod.rs +++ b/pumpkin-crates/core/src/engine/variables/mod.rs @@ -14,5 +14,6 @@ pub use affine_view::AffineView; pub(crate) use domain_generator_iterator::DomainGeneratorIterator; pub use domain_id::DomainId; pub use integer_variable::IntegerVariable; +pub use integer_variable::IntegerVariableEnum; pub use literal::Literal; pub use transformable_variable::TransformableVariable; diff --git a/pumpkin-crates/core/src/propagation/constructor.rs b/pumpkin-crates/core/src/propagation/constructor.rs index 791627880..bf3100bc7 100644 --- a/pumpkin-crates/core/src/propagation/constructor.rs +++ b/pumpkin-crates/core/src/propagation/constructor.rs @@ -164,7 +164,12 @@ impl PropagatorConstructorContext<'_> { self.update_next_local_id(local_id); - let mut watchers = Watchers::new(propagator_var, &mut self.state.notification_engine); + let mut watchers = Watchers::new( + propagator_var, + &mut self.state.notification_engine, + &mut self.state.trailed_values, + &self.state.assignments, + ); var.watch_all(&mut watchers, domain_events.events()); } @@ -207,7 +212,12 @@ impl PropagatorConstructorContext<'_> { self.update_next_local_id(local_id); - let mut watchers = Watchers::new(propagator_var, &mut self.state.notification_engine); + let mut watchers = Watchers::new( + propagator_var, + &mut self.state.notification_engine, + &mut self.state.trailed_values, + &self.state.assignments, + ); var.watch_all_backtrack(&mut watchers, domain_events.events()); } diff --git a/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs b/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs index 5c9d67a48..374b0ebab 100644 --- a/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs +++ b/pumpkin-crates/core/src/propagation/contexts/propagation_context.rs @@ -152,7 +152,12 @@ impl<'a> PropagationContext<'a> { variable: local_id, }; - let mut watchers = Watchers::new(propagator_var, self.notification_engine); + let mut watchers = Watchers::new( + propagator_var, + self.notification_engine, + self.trailed_values, + self.assignments, + ); var.watch_all(&mut watchers, domain_events.events()); } @@ -163,7 +168,12 @@ impl<'a> PropagationContext<'a> { variable: local_id, }; - let mut watchers = Watchers::new(propagator_var, self.notification_engine); + let mut watchers = Watchers::new( + propagator_var, + self.notification_engine, + self.trailed_values, + self.assignments, + ); var.unwatch_all(&mut watchers); } diff --git a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs index b89a6a413..16abf1143 100644 --- a/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs +++ b/pumpkin-crates/core/src/propagators/nogoods/nogood_propagator.rs @@ -317,6 +317,7 @@ impl Propagator for NogoodPropagator { let reason = Reason::DynamicLazy(watcher.nogood_id.id as u64); let predicate = !context.get_predicate(nogood_predicates[0]); + let result = context.post( predicate, reason, diff --git a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs index 9b58d7fbf..300165525 100644 --- a/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs +++ b/pumpkin-crates/propagators/src/propagators/arithmetic/linear_less_or_equal.rs @@ -3,6 +3,7 @@ use pumpkin_checking::CheckerVariable; use pumpkin_checking::InferenceChecker; use pumpkin_checking::IntExt; use pumpkin_checking::VariableState; +use pumpkin_core::asserts::pumpkin_assert_extreme; use pumpkin_core::asserts::pumpkin_assert_simple; use pumpkin_core::declare_inference_label; use pumpkin_core::predicate; @@ -75,6 +76,13 @@ where current_bounds.push(context.new_trailed_integer(context.lower_bound(x_i) as i64)); } + pumpkin_assert_extreme!( + lower_bound_left_hand_side + == x.iter().map(|x| context.lower_bound(x) as i64).sum::(), + "Expected {lower_bound_left_hand_side} to be equal to {}", + x.iter().map(|x| context.lower_bound(x)).sum::() + ); + let lower_bound_left_hand_side = context.new_trailed_integer(lower_bound_left_hand_side); LinearLessOrEqualPropagator { @@ -144,10 +152,7 @@ where let old_bound = context.read_trailed_integer(self.current_bounds[index]); let new_bound = context.lower_bound(x_i) as i64; - pumpkin_assert_simple!( - old_bound < new_bound, - "propagator should only be triggered when lower bounds are tightened, old_bound={old_bound}, new_bound={new_bound}" - ); + pumpkin_assert_simple!(new_bound > old_bound); context.write_trailed_integer( self.lower_bound_left_hand_side, @@ -217,10 +222,19 @@ where return Ok(()); } }; + pumpkin_assert_extreme!( + lower_bound_left_hand_side + == self.x.iter().map(|x| context.lower_bound(x)).sum::(), + "Expected {lower_bound_left_hand_side} to be equal to {}\n{:?}", + self.x.iter().map(|x| context.lower_bound(x)).sum::(), + self.x + .iter() + .map(|x| (x, context.lower_bound(x))) + .collect::>() + ); for (i, x_i) in self.x.iter().enumerate() { let bound = self.c - (lower_bound_left_hand_side - context.lower_bound(x_i)); - if context.upper_bound(x_i) > bound { context.post(predicate![x_i <= bound], i, &self.inference_code)?; } diff --git a/pumpkin-solver-py/src/brancher.rs b/pumpkin-solver-py/src/brancher.rs index 0ae40e269..9f7b307c6 100644 --- a/pumpkin-solver-py/src/brancher.rs +++ b/pumpkin-solver-py/src/brancher.rs @@ -7,13 +7,13 @@ use pumpkin_solver::core::containers::HashMap; use pumpkin_solver::core::predicates::Predicate; use pumpkin_solver::core::results::SolutionReference; use pumpkin_solver::core::statistics::StatisticLogger; -use pumpkin_solver::core::variables::AffineView; use pumpkin_solver::core::variables::DomainId; +use pumpkin_solver::core::variables::IntegerVariableEnum; use crate::variables::IntExpression; pub struct PythonBrancher { - warm_start: WarmStart>, + warm_start: WarmStart, default_brancher: DefaultBrancher, } diff --git a/pumpkin-solver-py/src/constraints/arguments.rs b/pumpkin-solver-py/src/constraints/arguments.rs index ee2425fa8..6257f4cec 100644 --- a/pumpkin-solver-py/src/constraints/arguments.rs +++ b/pumpkin-solver-py/src/constraints/arguments.rs @@ -1,5 +1,4 @@ -use pumpkin_solver::core::variables::AffineView; -use pumpkin_solver::core::variables::DomainId; +use pumpkin_solver::core::variables::IntegerVariableEnum; use pumpkin_solver::core::variables::Literal; use crate::variables::BoolExpression; @@ -13,7 +12,7 @@ pub trait PythonConstraintArg { } impl PythonConstraintArg for IntExpression { - type Output = AffineView; + type Output = IntegerVariableEnum; fn to_solver_constraint_argument(self) -> Self::Output { self.0 diff --git a/pumpkin-solver-py/src/model.rs b/pumpkin-solver-py/src/model.rs index e35b41880..cbf2b35b3 100644 --- a/pumpkin-solver-py/src/model.rs +++ b/pumpkin-solver-py/src/model.rs @@ -156,15 +156,6 @@ impl Model { self.solver.is_inconsistent() } - /// Get an integer variable for the given boolean. - /// - /// The integer is 1 if the boolean is `true`, and 0 if the boolean is `false`. - /// - /// This is deprecated as of 0.3.0. Prefer to use `BoolExpression.as_integer`. - fn boolean_as_integer(&mut self, boolean: BoolExpression) -> IntExpression { - boolean.as_integer() - } - /// Reify a predicate as an explicit boolean expression. /// /// A tag should be provided for this link to be identifiable in the proof. @@ -191,9 +182,6 @@ impl Model { self.solver.new_literal_for_predicate(solver_predicate, tag) }; - self.brancher - .add_domain(*literal.get_integer_variable().inner()); - Ok(literal.into()) } diff --git a/pumpkin-solver-py/src/variables.rs b/pumpkin-solver-py/src/variables.rs index 6e9e81527..d2882a7a8 100644 --- a/pumpkin-solver-py/src/variables.rs +++ b/pumpkin-solver-py/src/variables.rs @@ -1,13 +1,13 @@ use pumpkin_solver::core::predicate; -use pumpkin_solver::core::variables::AffineView; use pumpkin_solver::core::variables::DomainId; +use pumpkin_solver::core::variables::IntegerVariableEnum; use pumpkin_solver::core::variables::Literal; use pumpkin_solver::core::variables::TransformableVariable; use pyo3::prelude::*; #[pyclass(eq, hash, frozen)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct IntExpression(pub AffineView); +pub struct IntExpression(pub IntegerVariableEnum); impl From for IntExpression { fn from(domain_id: DomainId) -> IntExpression { @@ -15,6 +15,18 @@ impl From for IntExpression { } } +impl From for IntExpression { + fn from(value: Literal) -> Self { + IntExpression(value.into()) + } +} + +impl From for IntExpression { + fn from(value: BoolExpression) -> Self { + IntExpression(value.0.into()) + } +} + #[pymethods] impl IntExpression { fn offset(&self, add_offset: i32) -> IntExpression { @@ -32,13 +44,23 @@ pub struct BoolExpression(pub Literal); #[pymethods] impl BoolExpression { - pub fn as_integer(&self) -> IntExpression { - IntExpression(self.0.get_integer_variable()) + pub fn scaled(&self, scale: i32) -> IntExpression { + let int_expr: IntExpression = (*self).into(); + int_expr.scaled(scale) + } + + pub fn offset(&self, offset: i32) -> IntExpression { + let int_expr: IntExpression = (*self).into(); + int_expr.offset(offset) } pub fn negate(&self) -> BoolExpression { BoolExpression(!self.0) } + + pub fn as_expression(&self) -> IntExpression { + (*self).into() + } } impl From for BoolExpression { diff --git a/pumpkin-solver-py/tests/test_constraints.py b/pumpkin-solver-py/tests/test_constraints.py index 8d26295a8..674064084 100644 --- a/pumpkin-solver-py/tests/test_constraints.py +++ b/pumpkin-solver-py/tests/test_constraints.py @@ -1,15 +1,14 @@ -""" -Generate constraints and expressions based on the grammar supported by the API +"""Generate constraints and expressions based on the grammar supported by the API. -Generates linear constraints, special operators and global constraints. -Whenever possible, the script also generates 'boolean as integer' versions of the arguments +Generates linear constraints, special operators and global constraints. Whenever possible, the script also generates +'boolean as integer' versions of the arguments """ from random import randint +import pumpkin_solver import pytest from pumpkin_solver import constraints -import pumpkin_solver def chain(*iterables): @@ -68,15 +67,11 @@ def create_linear_model(request): model = pumpkin_solver.Model() if bool: - args = [ - model.new_boolean_variable(name=f"x[{i}]").as_integer() for i in range(3) - ] + args = [model.new_boolean_variable(name=f"x[{i}]").as_expression() for i in range(3)] else: args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled: # do scaling (0, -2, 4,...) - args = [ - a.scaled(-2 * i + 1) for i, a in enumerate(args) - ] # TODO: div by zero when scale = 0, fixed with +1 + args = [a.scaled(-2 * i + 1) for i, a in enumerate(args)] # TODO: div by zero when scale = 0, fixed with +1 rhs = 1 cons = None @@ -104,15 +99,11 @@ def create_operator_model(request): model = pumpkin_solver.Model() if bool: - args = [ - model.new_boolean_variable(name=f"x[{i}]").as_integer() for i in range(3) - ] + args = [model.new_boolean_variable(name=f"x[{i}]").as_expression() for i in range(3)] else: args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled: # do scaling (0, -2, 4,...) - args = [ - a.scaled(-2 * i + 1) for i, a in enumerate(args) - ] # TODO: div by zero when scale = 0, fixed with +1 + args = [a.scaled(-2 * i + 1) for i, a in enumerate(args)] # TODO: div by zero when scale = 0, fixed with +1 rhs = model.new_integer_variable(-3, 5, name="rhs") cons = None @@ -128,9 +119,7 @@ def create_operator_model(request): if name == "max": cons = constraints.Maximum(args, rhs, model.new_constraint_tag()) if name == "element": - idx = model.new_integer_variable( - -1, 5, name="idx" - ) # sneaky, idx can be out of bounds + idx = model.new_integer_variable(-1, 5, name="idx") # sneaky, idx can be out of bounds cons = constraints.Element(idx, args, rhs, model.new_constraint_tag()) if not cons: @@ -151,16 +140,11 @@ def create_global_model(request): if name == "alldiff": if bool: - args = [ - model.new_boolean_variable(name=f"x[{i}]").as_integer() - for i in range(3) - ] + args = [model.new_boolean_variable(name=f"x[{i}]").as_expression() for i in range(3)] else: args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled or bool: # do scaling (0, -2, 4,...) - args = [ - a.scaled(-2 * i + 1) for i, a in enumerate(args) - ] # TODO: div by zero when scale = 0, fixed with +1 + args = [a.scaled(-2 * i + 1) for i, a in enumerate(args)] # TODO: div by zero when scale = 0, fixed with +1 cons = constraints.AllDifferent(args, model.new_constraint_tag()) @@ -182,9 +166,7 @@ def create_global_model(request): start = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] if scaled: start = [a.scaled(-2 * i) for i, a in enumerate(start)] - cons = constraints.Cumulative( - start, duration, demand, capacity, model.new_constraint_tag() - ) + cons = constraints.Cumulative(start, duration, demand, capacity, model.new_constraint_tag()) else: assert False, f"unknown global {name}" @@ -200,9 +182,7 @@ def global_model(request): def make_id(args): name, scaled, bool = args - return " ".join( - ["Scaled" if scaled else "Unscaled", "Boolean" if bool else "Integer", name] - ) + return " ".join(["Scaled" if scaled else "Unscaled", "Boolean" if bool else "Integer", name]) @pytest.mark.parametrize("linear_model", generate_linear(), indirect=True, ids=make_id) diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs index 46b2ac0f4..105dc579c 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/collect_domains.rs @@ -5,6 +5,7 @@ use std::rc::Rc; use flatzinc::Annotation; use pumpkin_core::Solver; use pumpkin_core::containers::HashMap; +use pumpkin_core::predicate; use pumpkin_core::variables::DomainId; use pumpkin_solver::core::variables::Literal; @@ -39,7 +40,7 @@ pub(crate) fn run( ) }); - let literal = Literal::new(domain_id); + let literal = Literal::new(predicate!(domain_id >= 1)); if is_output_variable(annos) { context.outputs.push(Output::bool(id, literal)); diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs index 8cc79b4a5..8a4803198 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/context.rs @@ -4,6 +4,7 @@ use std::collections::BTreeSet; use std::rc::Rc; use log::warn; +use pumpkin_core::predicate; use pumpkin_solver::Solver; use pumpkin_solver::core::containers::HashMap; use pumpkin_solver::core::containers::HashSet; @@ -139,7 +140,7 @@ impl CompilationContext<'_> { .variable_map .get(&self.equivalences.representative(identifier)) { - Ok(Literal::new(*domain_id)) + Ok(Literal::new(predicate!(domain_id >= 1))) } else { self.boolean_parameters .get(&self.equivalences.representative(identifier)) diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs index 52d1fc609..04a2ec2a2 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/create_search_strategy.rs @@ -146,10 +146,7 @@ fn create_from_search_strategy( } AnnExpr::Expr(expr) => { let bool_variable_array = context - .resolve_bool_variable_array(expr)? - .iter() - .map(|literal| literal.get_integer_variable()) - .collect::>(); + .resolve_bool_variable_array(expr)?; match values { AnnExpr::Expr(expr) => { diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_variable_arrays.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_variable_arrays.rs index 2b023d80d..cda074f82 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_variable_arrays.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/define_variable_arrays.rs @@ -11,6 +11,7 @@ use flatzinc::Expr; use flatzinc::IntExpr; use flatzinc::SetExpr; use flatzinc::SetLiteralExpr; +use pumpkin_core::predicate; use pumpkin_core::variables::Literal; use super::context::CompilationContext; @@ -49,7 +50,7 @@ pub(crate) fn run( .copied() .expect("referencing undefined boolean variable"); - Literal::new(domain_id) + Literal::new(predicate!(domain_id >= 1)) } }) .collect(), diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs index e72963c14..ae444b171 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs @@ -2,6 +2,7 @@ use std::rc::Rc; +use pumpkin_core::variables::IntegerVariableEnum; use pumpkin_propagators::disjunctive::ArgDisjunctiveTask; use pumpkin_solver::core::constraints::Constraint; use pumpkin_solver::core::constraints::NegatableConstraint; @@ -648,11 +649,13 @@ fn compile_bool2int( let a = context.resolve_bool_variable(&exprs[0])?; let b = context.resolve_integer_variable(&exprs[1])?; - Ok( - pumpkin_constraints::binary_equals(a.get_integer_variable(), b.scaled(1), constraint_tag) - .post(context.solver) - .is_ok(), + Ok(pumpkin_constraints::binary_equals( + IntegerVariableEnum::Literal(a.scaled(1)), + IntegerVariableEnum::DomainId(b.scaled(1)), + constraint_tag, ) + .post(context.solver) + .is_ok()) } fn compile_bool_or(