diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index acfa9b38c0b..fe2ed28cf1f 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include #include #include @@ -20,7 +21,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -138,6 +141,16 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept { std::swap(a.outputs_, b.outputs_); std::swap(a.io_alias_, b.io_alias_); + + // Swap per-Fusion special values (Phase 2) + std::swap(a.zero_val_, b.zero_val_); + std::swap(a.one_val_, b.one_val_); + std::swap(a.true_val_, b.true_val_); + std::swap(a.false_val_, b.false_val_); + std::swap(a.magic_zero_val_, b.magic_zero_val_); + + std::swap(a.axioms_, b.axioms_); + std::swap(a.metadata_, b.metadata_); } std::unique_ptr Fusion::segment( @@ -151,6 +164,24 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container()); + // Remap cached special val pointers through the cloner + if (from->zero_val_) { + to->zero_val_ = ir_cloner.clone(from->zero_val_); + } + if (from->one_val_) { + to->one_val_ = ir_cloner.clone(from->one_val_); + } + if (from->true_val_) { + to->true_val_ = ir_cloner.clone(from->true_val_); + } + if (from->false_val_) { + to->false_val_ = ir_cloner.clone(from->false_val_); + } + if (from->magic_zero_val_) { + to->magic_zero_val_ = + ir_cloner.clone(from->magic_zero_val_)->as(); + } + for (auto val : from->vals()) { ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); @@ -199,6 +230,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_; + if (from->axioms_ != nullptr) { + to->axioms_ = std::make_unique>(); + to->axioms_->reserve(from->axioms_->size()); + for (auto pred : *from->axioms_) { + to->axioms_->push_back(ir_cloner.clone(pred)); + } + } + + for 
(auto& [key, val_expr] : from->metadata_) { + to->metadata_[ir_cloner.clone(key)] = std::make_pair( + ir_cloner.clone(val_expr.first), ir_cloner.clone(val_expr.second)); + } + if (from->all_tvs_ptr_ != nullptr) { to->all_tvs_ptr_ = std::make_unique>(); to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size()); @@ -265,6 +309,17 @@ void Fusion::clear() noexcept { managed_data_.clear(); managed_named_data_.clear(); + // Reset per-Fusion special value caches (the vals themselves are owned by + // ir_container and were already destroyed by ir_container()->clear() above). + zero_val_ = nullptr; + one_val_ = nullptr; + true_val_ = nullptr; + false_val_ = nullptr; + magic_zero_val_ = nullptr; + + axioms_.reset(); + metadata_.clear(); + invalidateTvsAndUses(); is_during_update_uses_ = false; @@ -298,6 +353,12 @@ void Fusion::removeExpr(Expr* expr) { void Fusion::removeVal(Val* val) { assertInContainer(val, "Cannot remove val "); + // Don't remove cached special vals — they are lazily created singletons + if (val == zero_val_ || val == one_val_ || val == true_val_ || + val == false_val_ || val == magic_zero_val_) { + return; + } + NVF_CHECK( !val->isFusionInput(), "Cannot remove val as it is an input of the fusion."); @@ -341,6 +402,57 @@ void Fusion::removeVal(Val* val) { invalidateTvsAndUses(); } +void Fusion::removeStatementsCreatedAfter( + int64_t num_exprs_before, + int64_t num_vals_before) { + auto* c = ir_container(); + + NVF_ERROR( + c->exprs_up_.size() == c->exprs_.size(), + "exprs_up_ (size ", + c->exprs_up_.size(), + ") and exprs_ (size ", + c->exprs_.size(), + ") are out of sync."); + NVF_ERROR( + std::ssize(c->exprs_up_) >= num_exprs_before, + "exprs_up_ size (", + std::ssize(c->exprs_up_), + ") is less than num_exprs_before (", + num_exprs_before, + ")."); + + // Remove expressions before values because we need to change Val::uses_. 
+ while (std::ssize(c->exprs_up_) > num_exprs_before) { + Expr* e = c->exprs_up_.back().get(); + for (Val* in : e->inputs()) { + in->removeUse(e); + } + c->exprs_.erase(e); + c->exprs_up_.pop_back(); + } + + // Null out any special value caches that point to vals about to be destroyed. + // This prevents dangling pointers when special vals are lazily created inside + // a StatementGuard scope. + while (std::ssize(c->vals_up_) > num_vals_before) { + Val* v = c->vals_up_.back().get(); + if (v == zero_val_) { + zero_val_ = nullptr; + } else if (v == one_val_) { + one_val_ = nullptr; + } else if (v == true_val_) { + true_val_ = nullptr; + } else if (v == false_val_) { + false_val_ = nullptr; + } else if (v == magic_zero_val_) { + magic_zero_val_ = nullptr; + } + c->vals_.erase(v); + c->vals_up_.pop_back(); + } +} + void Fusion::addInput(Val* input) { assertInContainer(input, "Cannot register input "); @@ -686,6 +798,105 @@ void Fusion::printTransforms() { t_exprs.handle(this); } +Val* Fusion::zeroVal() { + if (!zero_val_) { + zero_val_ = IrBuilder::createInContainer(this, 0L, DataType::Index); + } + return zero_val_; +} + +Val* Fusion::oneVal() { + if (!one_val_) { + one_val_ = IrBuilder::createInContainer(this, 1L, DataType::Index); + } + return one_val_; +} + +Val* Fusion::falseVal() { + if (!false_val_) { + false_val_ = IrBuilder::createInContainer(this, false, DataType::Bool); + } + return false_val_; +} + +Val* Fusion::trueVal() { + if (!true_val_) { + true_val_ = IrBuilder::createInContainer(this, true, DataType::Bool); + } + return true_val_; +} + +NamedScalar* Fusion::magicZeroVal() { + if (!magic_zero_val_) { + magic_zero_val_ = IrBuilder::createInContainer( + this, kMagicZeroName, DataType::Index); + } + return magic_zero_val_; +} + +Val* Fusion::zeroVal(DataType dtype) { + if (dtype == DataType::Index) { + return zeroVal(); + } else if (isBooleanType(dtype)) { + return falseVal(); + } else { + // NOTE: this does not cache values + return 
IrBuilder::createInContainer(this, 0L, dtype); + } +} + +Val* Fusion::oneVal(DataType dtype) { + if (dtype == DataType::Index) { + return oneVal(); + } else if (isBooleanType(dtype)) { + return trueVal(); + } else { + // NOTE: this does not cache values + return IrBuilder::createInContainer(this, 1L, dtype); + } +} + +Val* Fusion::metadataOf(Val* v) { + if (metadata_.count(v) == 0) { + auto metadata_val = + IrBuilder::createInContainer(this, metaDataTypeOf(v)); + auto metadata_expr = + IrBuilder::createInContainer(this, metadata_val, v); + metadata_[v] = std::make_pair(metadata_val, metadata_expr); + } + return metadata_.at(v).first; +} + +const std::vector& Fusion::axioms() { + if (!axioms_) { + axioms_ = std::make_unique>(); + axioms_->reserve(kParallelTypeThreads.size() * 3); + auto zero = zeroVal(); + for (auto p : kParallelTypeThreads) { + auto pidx = NamedScalar::getParallelIndex(p); + auto pdim = NamedScalar::getParallelDim(p); + axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero)); + axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero)); + axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim)); + } + } + return *axioms_; +} + +void Fusion::assumePositive(Val* val) { + NVF_ERROR(inContainer(val)); + // Lazy init axioms, then add the assumption + axioms(); + axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); +} + +void Fusion::assumeNonNegative(Val* val) { + NVF_ERROR(inContainer(val)); + // Lazy init axioms, then add the assumption + axioms(); + axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); +} + void Fusion::registerVal(Val* val) { if (inContainer(val)) { return; diff --git a/csrc/fusion.h b/csrc/fusion.h index cb3a555e814..4fbf71c988f 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -60,6 +60,7 @@ namespace nvfuser { //! checks. 
class Fusion; +class NamedScalar; class TensorView; class SegmentCandidateFinder; @@ -70,7 +71,7 @@ class DynamicTransformConcretizationInfo; // Set the enum base to `int` so it can be safely serialized as a part of // serde::InputOutputAlias. -enum class AllocationType : int { +enum class AllocationType : int { // NOLINT(performance-enum-size) New, // Allocate a new buffer // Reuse the buffer allocated to `aliased_io`. For example, the tensor storing // BatchNorm's running mean. The output EMA is updated in place. @@ -87,7 +88,7 @@ enum class AllocationType : int { std::ostream& operator<<(std::ostream& os, AllocationType); -enum class OutputVisibility : int { +enum class OutputVisibility : int { // NOLINT(performance-enum-size) kHidden, kVisible, }; @@ -550,63 +551,31 @@ class NVF_API Fusion : public PolymorphicBase { return ir_container()->numExprs(); } - int64_t numVals(bool include_shortcuts) const noexcept { - return ir_container()->numVals(include_shortcuts); + int64_t numVals() const noexcept { + return ir_container()->numVals(); } // Shortcut values (frequently used constants) - Val* zeroVal() { - return ir_container()->zeroVal(); - } - - Val* oneVal() { - return ir_container()->oneVal(); - } - - Val* falseVal() { - return ir_container()->falseVal(); - } - - Val* trueVal() { - return ir_container()->trueVal(); - } - - NamedScalar* magicZeroVal() { - return ir_container()->magicZeroVal(); - } - - Val* zeroVal(DataType dtype) { - return ir_container()->zeroVal(dtype); - } + Val* zeroVal(); + Val* oneVal(); + Val* falseVal(); + Val* trueVal(); + NamedScalar* magicZeroVal(); + Val* zeroVal(DataType dtype); + Val* oneVal(DataType dtype); - Val* oneVal(DataType dtype) { - return ir_container()->oneVal(dtype); - } - - Val* metadataOf(Val* val) { - return ir_container()->metadataOf(val); - } + Val* metadataOf(Val* val); // Axioms (CUDA programming assumptions) - const std::vector& axioms() { - return ir_container()->axioms(); - } + const std::vector& axioms(); - 
void assumePositive(Val* val) { - ir_container()->assumePositive(val); - } - - void assumeNonNegative(Val* val) { - ir_container()->assumeNonNegative(val); - } + void assumePositive(Val* val); + void assumeNonNegative(Val* val); // Statement removal void removeStatementsCreatedAfter( int64_t num_exprs_before, - int64_t num_vals_before) { - ir_container()->removeStatementsCreatedAfter( - num_exprs_before, num_vals_before); - } + int64_t num_vals_before); protected: friend SegmentCandidateFinder; @@ -668,6 +637,16 @@ class NVF_API Fusion : public PolymorphicBase { inline static const std::string exact_mappings_key = "exact_mappings"; std::unique_ptr ir_container_; + + Val* zero_val_ = nullptr; + Val* one_val_ = nullptr; + Val* true_val_ = nullptr; + Val* false_val_ = nullptr; + NamedScalar* magic_zero_val_ = nullptr; + + std::unique_ptr> axioms_; + + std::unordered_map> metadata_; }; // Template implementations for Fusion::manage() that use IrCloner diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 3c54966c87d..c91805e5f11 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -18,9 +18,8 @@ namespace nvfuser { //! Return values in insertion order const std::deque IrContainer::deterministic_vals() const noexcept { std::deque vals_deque; - std::transform( - vals_up_.begin(), - vals_up_.end(), + std::ranges::transform( + vals_up_, std::back_inserter(vals_deque), [](const std::unique_ptr& val_up) { return val_up.get(); }); return vals_deque; @@ -29,9 +28,8 @@ const std::deque IrContainer::deterministic_vals() const noexcept { //! 
Return expression in insertion order const std::deque IrContainer::deterministic_exprs() const noexcept { std::deque exprs_deque; - std::transform( - exprs_up_.begin(), - exprs_up_.end(), + std::ranges::transform( + exprs_up_, std::back_inserter(exprs_deque), [](const std::unique_ptr& expr_up) { return expr_up.get(); }); return exprs_deque; @@ -42,9 +40,8 @@ const std::unordered_map IrContainer::deterministic_vals_map() const noexcept { std::unordered_map vals_map; int64_t count = 0; - std::transform( - vals_up_.begin(), - vals_up_.end(), + std::ranges::transform( + vals_up_, std::inserter(vals_map, vals_map.end()), [&count](const std::unique_ptr& val_up) { return std::make_pair(val_up.get(), count++); @@ -57,9 +54,8 @@ const std::unordered_map IrContainer::deterministic_exprs_map() const noexcept { std::unordered_map exprs_map; int64_t count = 0; - std::transform( - exprs_up_.begin(), - exprs_up_.end(), + std::ranges::transform( + exprs_up_, std::inserter(exprs_map, exprs_map.end()), [&count](const std::unique_ptr& expr_up) { return std::make_pair(expr_up.get(), count++); @@ -80,25 +76,15 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { std::swap(a.val_type_name_map_, b.val_type_name_map_); std::swap(a.expr_name_counter_, b.expr_name_counter_); - std::swap(a.metadata_, b.metadata_); - std::swap(a.parent_, b.parent_); - - std::swap(a.zero_val_, b.zero_val_); - std::swap(a.one_val_, b.one_val_); - std::swap(a.true_val_, b.true_val_); - std::swap(a.false_val_, b.false_val_); - std::swap(a.magic_zero_val_, b.magic_zero_val_); - std::swap(a.axioms_, b.axioms_); } IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->clear(); + IrCloner ir_cloner(to->parent()); // Copy values in deterministic order - // deterministic_vals can contain special values like one_val_, zero_val_, etc - // that are not registered in the container. 
for (auto val : from->deterministic_vals()) { if (from->vals().count(val) > 0) { to->vals_.insert(ir_cloner.clone(val)); @@ -115,15 +101,6 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->val_type_name_map_ = from->val_type_name_map_; to->expr_name_counter_ = from->expr_name_counter_; - if (from->axioms_ != nullptr) { - to->axioms_ = std::make_unique>(); - for (auto pred : *from->axioms_) { - to->axioms_->push_back(ir_cloner.clone(pred)); - } - } - - to->metadata_ = ir_cloner.clone(from->metadata_); - return ir_cloner; } @@ -137,9 +114,8 @@ void IrContainer::removeExpr(Expr* expr) { NVF_ERROR( exprs_.find(expr) != exprs_.end(), "Wanted to remove an expression but it doesn't exist in this container."); - auto expr_in_deque = std::find_if( - exprs_up_.begin(), - exprs_up_.end(), + auto expr_in_deque = std::ranges::find_if( + exprs_up_, [expr](std::unique_ptr& expr_up) { return expr_up.get() == expr; }); NVF_ERROR( @@ -153,20 +129,12 @@ void IrContainer::removeExpr(Expr* expr) { //! Completely remove val from the fusion, break all dependencies associated //! 
with it void IrContainer::removeVal(Val* val) { - // Don't remove shortcuts - if (val == true_val_.get() || val == false_val_.get() || - val == one_val_.get() || val == zero_val_.get() || - val == magic_zero_val_.get()) { - return; - } - NVF_ERROR( vals_.find(val) != vals_.end(), "Wanted to remove a value but it doesn't exist in this container."); - auto val_in_deque = std::find_if( - vals_up_.begin(), vals_up_.end(), [val](std::unique_ptr& val_up) { - return val_up.get() == val; - }); + auto val_in_deque = std::ranges::find_if( + vals_up_, + [val](std::unique_ptr& val_up) { return val_up.get() == val; }); NVF_ERROR( val_in_deque != vals_up_.end(), @@ -206,9 +174,7 @@ void IrContainer::clear() noexcept { vals_up_.clear(); exprs_.clear(); exprs_up_.clear(); - axioms_.reset(); val_type_name_map_.clear(); - metadata_.clear(); expr_name_counter_ = 0; } @@ -244,155 +210,4 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return true; } -// Shortcuts for frequently used vals -Val* IrContainer::zeroVal() { - if (!zero_val_) { - auto zero_val = - IrBuilder::createInContainer(this->parent(), 0L, DataType::Index); - NVF_ERROR(vals_up_.back().get() == zero_val); - zero_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return zero_val_.get(); -} - -Val* IrContainer::zeroVal(DataType dtype) { - if (dtype == DataType::Index) { - return zeroVal(); - } else if (isBooleanType(dtype)) { - return falseVal(); - } else { - // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 0L, dtype); - } -} - -Val* IrContainer::oneVal() { - if (!one_val_) { - auto one_val = - IrBuilder::createInContainer(this->parent(), 1L, DataType::Index); - NVF_ERROR(vals_up_.back().get() == one_val); - one_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return one_val_.get(); -} - -Val* IrContainer::oneVal(DataType dtype) { - if (dtype == DataType::Index) { - return oneVal(); - } else if 
(isBooleanType(dtype)) { - return trueVal(); - } else { - // NOTE: this does not cache values - return IrBuilder::createInContainer(this->parent(), 1L, dtype); - } -} - -Val* IrContainer::falseVal() { - if (!false_val_) { - auto false_val = IrBuilder::createInContainer( - this->parent(), false, DataType::Bool); - NVF_ERROR(vals_up_.back().get() == false_val); - false_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return false_val_.get(); -} - -Val* IrContainer::trueVal() { - if (!true_val_) { - auto true_val = - IrBuilder::createInContainer(this->parent(), true, DataType::Bool); - NVF_ERROR(vals_up_.back().get() == true_val); - true_val_ = std::unique_ptr(vals_up_.back().release()); - vals_up_.pop_back(); - } - return true_val_.get(); -} - -NamedScalar* IrContainer::magicZeroVal() { - if (!magic_zero_val_) { - auto magic_zero = - IrBuilder::create(kMagicZeroName, DataType::Index); - NVF_ERROR(vals_up_.back().get() == magic_zero); - magic_zero_val_ = std::unique_ptr( - vals_up_.back().release()->as()); - vals_up_.pop_back(); - } - return magic_zero_val_.get(); -} - -Val* IrContainer::metadataOf(Val* v) { - if (metadata_.count(v) == 0) { - auto metadata_val = - IrBuilder::createInContainer(this->parent(), metaDataTypeOf(v)); - auto metadata_expr = IrBuilder::createInContainer( - this->parent(), metadata_val, v); - metadata_[v] = std::make_pair(metadata_val, metadata_expr); - } - return metadata_.at(v).first; -} - -void IrContainer::lazyInitAxioms() { - if (!axioms_) { - axioms_ = std::make_unique>(); - axioms_->reserve(kParallelTypeThreads.size() * 3); - auto zero = zeroVal(); - for (auto p : kParallelTypeThreads) { - auto pidx = NamedScalar::getParallelIndex(p); - auto pdim = NamedScalar::getParallelDim(p); - axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero)); - axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero)); - axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim)); - } - } -} - -void 
IrContainer::assumePositive(Val* val) { - NVF_ERROR(val->container() == this->parent()); - lazyInitAxioms(); - axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); -} - -void IrContainer::assumeNonNegative(Val* val) { - NVF_ERROR(val->container() == this->parent()); - lazyInitAxioms(); - axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); -} - -void IrContainer::removeStatementsCreatedAfter( - int64_t prev_num_exprs, - int64_t prev_num_vals) { - NVF_ERROR( - exprs_up_.size() == exprs_.size(), - "exprs_up_ (size ", - exprs_up_.size(), - ") and exprs_ (size ", - exprs_.size(), - ") are out of sync."); - NVF_ERROR( - std::ssize(exprs_up_) >= prev_num_exprs, - "exprs_up_ size (", - std::ssize(exprs_up_), - ") is less than prev_num_exprs (", - prev_num_exprs, - ")."); - - // Remove expressions before values because we need to change Val::uses_. - while (std::ssize(exprs_up_) > prev_num_exprs) { - Expr* e = exprs_up_.back().get(); - for (Val* in : e->inputs()) { - in->removeUse(e); - } - exprs_.erase(e); - exprs_up_.pop_back(); - } - - while (std::ssize(vals_up_) > prev_num_vals) { - vals_.erase(vals_up_.back().get()); - vals_up_.pop_back(); - } -} - } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index e361b8743ee..899f2f26439 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -80,30 +80,10 @@ class IrContainer { return std::ssize(exprs_); } - // When include_shortcuts is true, it will count the shortcuts like true_val_. - int64_t numVals(bool include_shortcuts) const noexcept { - return include_shortcuts ? 
std::ssize(vals_) : std::ssize(vals_up_); + int64_t numVals() const noexcept { + return std::ssize(vals_up_); } - // Shortcuts for frequently used vals - NVF_API Val* zeroVal(); - NVF_API Val* oneVal(); - Val* falseVal(); - Val* trueVal(); - NamedScalar* magicZeroVal(); - NVF_API Val* zeroVal(DataType dtype); - NVF_API Val* oneVal(DataType dtype); - Val* metadataOf(Val*); - - // Axioms about CUDA programming, for example: threadIdx.x < blockDim.x - const std::vector& axioms() { - lazyInitAxioms(); - return *axioms_; - } - - void assumePositive(Val* val); - void assumeNonNegative(Val* val); - protected: static IrCloner copy(const IrContainer* from, IrContainer* to); @@ -137,20 +117,8 @@ class IrContainer { void clear() noexcept; - void lazyInitAxioms(); - friend class StatementGuard; - // A simple garbage collection mechanism to remove all Exprs and Vals that - // were created after a certain point. This is useful for analysis that - // creates new Exprs and Vals in the container and wants to clean up after - // itself. - // - // Used by StatementGuard only. - void removeStatementsCreatedAfter( - int64_t prev_num_exprs, - int64_t prev_num_vals); - // Deque of unique pointer is the memory owning data structure std::deque> vals_up_; @@ -171,22 +139,6 @@ class IrContainer { // Expression names counter StmtNameType expr_name_counter_ = 0; - // Manually store some persistent, frequently used nodes. It's very - // challenging to do this anything but manually as detecting when a container - // may or may not have one of these vals is tricky. Specifically because if - // the container doesn't own it, it's hard to understand from the outside if - // the node may have been removed then re-registered. It could also be tricky - // to know when we're using a different container as in FusionCopy_test - // demonstrates deleting then creating containers can result in the same - // pointer for the container. 
- std::unique_ptr true_val_; - std::unique_ptr false_val_; - std::unique_ptr one_val_; - std::unique_ptr zero_val_; - std::unique_ptr magic_zero_val_; - std::unique_ptr> axioms_; - std::unordered_map> metadata_; - public: Fusion* parent() const { NVF_ERROR( diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 0ad664c9d71..fa8f5f51208 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -46,7 +46,7 @@ std::vector normalizeNew2Old( const std::vector& new2old_in, int64_t ndims) { NVF_CHECK( - (int64_t)new2old_in.size() == ndims, + std::cmp_equal(new2old_in.size(), ndims), "There must be a transpose mapping for each dimension in domain"); // Canonicalize dimensions by wrapping each dim for the given ndims @@ -74,7 +74,8 @@ std::vector normalizeNew2Old( // Error out if duplicate values are found. NVF_CHECK( - (int64_t)new2old.size() == ndims && old_pos_set.size() == new2old.size(), + std::cmp_equal(new2old.size(), ndims) && + old_pos_set.size() == new2old.size(), "Duplicate entries in transformation map."); // END VALIDATION CHECKS @@ -196,7 +197,9 @@ struct SubstituteInExpr : public OptOutMutator { private: explicit SubstituteInExpr(Val* reference, Val* substitute) { - mutations_[reference] = substitute; + if (reference != substitute) { + mutations_[reference] = substitute; + } } private: @@ -532,6 +535,7 @@ class ValReplacementMutator : public OptOutMutator { expr->outputs().begin(), expr->outputs().end()); } + // NOLINTNEXTLINE(bugprone-nondeterministic-pointer-iteration-order) for (auto input : inputs) { outputs.erase(input); } @@ -1055,6 +1059,7 @@ CompareDomainResult compareDomains( return v->as()->getIterType() == IterType::Symbolic; }; std::vector ids_to_remove; + // NOLINTNEXTLINE(bugprone-nondeterministic-pointer-iteration-order) for (Val* id : frontier) { if (is_symb(id) && dom1_set.count(id)) { ids_to_remove.push_back(id); diff --git a/csrc/statement_guard.cpp b/csrc/statement_guard.cpp index 717ed11d8e2..4575bb59076 100644 --- 
a/csrc/statement_guard.cpp +++ b/csrc/statement_guard.cpp @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion) return fusion; }()), prev_num_exprs_(fusion_->numExprs()), - prev_num_vals_(fusion_->numVals(/*include_shortcuts=*/false)) {} + prev_num_vals_(fusion_->numVals()) {} StatementGuard::~StatementGuard() { fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_); diff --git a/tests/cpp/test_indexing.cpp b/tests/cpp/test_indexing.cpp index 5ca0ee83325..57019c39db8 100644 --- a/tests/cpp/test_indexing.cpp +++ b/tests/cpp/test_indexing.cpp @@ -860,9 +860,9 @@ TEST_F(IndexingTest, Reshape) { // to provide the extent of the group. However, since everything // should be deterministic, string match should also work. return std::string( - "( ( ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) / 25 ) * 25 " + "( ( ( ( ( i113 * 20 ) + ( ( i114 * 10 ) + i115 ) ) / 25 ) * 25 " ") " - "+ ( ( ( i114 * 20 ) + ( ( i115 * 10 ) + i116 ) ) % 25 ) )"); + "+ ( ( ( i113 * 20 ) + ( ( i114 * 10 ) + i115 ) ) % 25 ) )"); } default: return std::string(); diff --git a/tests/cpp/test_statement_guard.cpp b/tests/cpp/test_statement_guard.cpp index 9264a065535..704d6714de6 100644 --- a/tests/cpp/test_statement_guard.cpp +++ b/tests/cpp/test_statement_guard.cpp @@ -51,4 +51,55 @@ TEST_F(StatementGuardTest, ExecuteAfterGuard) { executor_cache.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); } +// Regression test: special vals lazily created inside a StatementGuard scope +// must not become dangling pointers after the guard rolls back. +TEST_F(StatementGuardTest, LazySpecialValsNotDangling) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeContigTensor(1); + fusion->addInput(in); + TensorView* out = set(in); + fusion->addOutput(out); + + // Force lazy creation of trueVal/falseVal/oneVal inside a StatementGuard scope. 
+ // This reproduces the bug where haveDifferentShardings calls simplifyExpr + // inside a StatementGuard, which can lazily create special vals that then + // become dangling pointers when the guard rolls back. + { + StatementGuard sg(fusion.get()); + // Directly trigger lazy creation of trueVal, falseVal, and oneVal + fusion->trueVal(); + fusion->falseVal(); + fusion->oneVal(); + } + + // After the guard, the special vals should still be valid (re-created if the + // originals were destroyed by the guard's rollback). + Val* z = fusion->zeroVal(); + Val* o = fusion->oneVal(); + Val* t = fusion->trueVal(); + Val* f = fusion->falseVal(); + EXPECT_NE(z, nullptr); + EXPECT_NE(o, nullptr); + EXPECT_NE(t, nullptr); + EXPECT_NE(f, nullptr); + EXPECT_TRUE(z->isZeroInt()); + EXPECT_TRUE(o->isOneInt()); + EXPECT_TRUE(t->isTrue()); + EXPECT_TRUE(f->isFalse()); + + // The fusion should still be executable + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor in_tensor = at::randn({8}, at::device(at::kCUDA)); + auto out_tensors = executor_cache.runFusionWithInputs({in_tensor}); + ASSERT_EQ(out_tensors.size(), 1); + testValidate( + executor_cache.fusion(), + {out_tensors[0].as()}, + {in_tensor}, + __LINE__, + __FILE__); +} + } // namespace nvfuser