diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index b95644d843f..9cfa55c10fa 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -137,6 +137,11 @@ struct Fusion::ContainerMutator { c->vals_.insert(val); c->per_fusion_vals_[self].insert(val); val->setName(IrContainerPasskey(), self->getValName(val->vtype())); + + // Record self as first owner; a Val that already has owners is untouched. + if (val->owning_fusions_.empty()) { + val->owning_fusions_.push_back(self); + } } static void registerExpr(Fusion* self, Expr* expr) { @@ -212,7 +217,8 @@ struct Fusion::ContainerMutator { Expr* e = c->exprs_up_.back().get(); NVF_ERROR( c->per_fusion_exprs_[self].count(e) > 0, - "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); + "removeStatementsCreatedAfter: tail expr belongs to another " + "Fusion"); for (Val* out : e->outputs()) { out->setDefinition(nullptr); } @@ -457,10 +463,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) { } // Wire up definitions and uses on cloned vals in deterministic order - // to ensure exprs are inserted into exprs_up_ deterministically + // to ensure exprs are inserted into exprs_up_ deterministically. + // Skip vals that clone() returned unchanged (reused shared scalars): their + // definition/uses belong to the source Fusion and must not be overwritten. for (auto val : from->deterministic_vals()) { - ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_)); - ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_)); + auto* cloned = ir_cloner.clone(val); + if (cloned == val) { + continue; // same pointer as the source val — nothing to rewire + } + cloned->setDefinition(ir_cloner.clone(val->definition_)); + cloned->setUses(ir_cloner.clone(val->uses_)); } // Sync per-Fusion name counters from source to dest. 
diff --git a/csrc/fusion.h b/csrc/fusion.h index c4436e11747..d6d2a3aa8f1 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -591,6 +591,7 @@ class NVF_API Fusion : public PolymorphicBase { friend SegmentCandidateFinder; friend SegmentedFusion; friend class TranslateApplicableWelford; + friend class IrCloner; friend Val; //! Constructor that shares an existing container. Creates an empty Fusion diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 0736b560078..fca03e6d1fd 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1801,8 +1801,8 @@ std::pair> SegmentedFusion::makeFusion( SegmentedGroup* sg) const { // TODO Optimize cloning step by only copying values and expressions between // the fusion segment's inputs and outputs. - auto fusion_segment = std::unique_ptr( - new Fusion(completeFusion()->ir_container_ptr())); + auto fusion_segment = + std::unique_ptr(new Fusion(completeFusion()->ir_container_ptr())); IrCloner complete_to_segment_map = Fusion::copy(completeFusion(), fusion_segment.get()); diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index 9ea67cfea86..df2959275c0 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -90,6 +90,25 @@ kir::Kernel* Statement::kernel() const { NVFUSER_DEFINE_CLONE(Val) +bool Val::isOwnedBy(const Fusion* f) const { + return std::find(owning_fusions_.begin(), owning_fusions_.end(), f) != + owning_fusions_.end(); +} + +void Val::addOwningFusion(Fusion* f) { + if (!isOwnedBy(f)) { + owning_fusions_.push_back(f); + } +} + +bool Val::removeOwningFusion(Fusion* f) { + auto it = std::find(owning_fusions_.begin(), owning_fusions_.end(), f); + if (it != owning_fusions_.end()) { + owning_fusions_.erase(it); + } + return owning_fusions_.empty(); +} + void Val::addDependency(Val* dependency) { NVF_ERROR(dependency != nullptr); diff --git a/csrc/ir/base_nodes.h b/csrc/ir/base_nodes.h index c7d359dbae1..823ee61f563 100644 --- a/csrc/ir/base_nodes.h +++ 
b/csrc/ir/base_nodes.h @@ -411,6 +411,24 @@ class NVF_API Val : public Statement { definition_ = expr; } + // Multi-owner tracking for Phase 3 scalar sharing. + // owning_fusions_[0] is the original creator until that owner is removed. + // Grows to 2+ only for shared scalars (IrCloner reuse path). + bool isShared() const { + return owning_fusions_.size() > 1; + } + + bool isOwnedBy(const Fusion* f) const; + + void addOwningFusion(Fusion* f); + + // Remove an owning Fusion. Returns true when no owners remain afterwards. + bool removeOwningFusion(Fusion* f); + + const std::vector& owningFusions() const { + return owning_fusions_; + } + NVFUSER_DECLARE_CLONE protected: @@ -454,6 +472,9 @@ class NVF_API Val : public Statement { // welford operations. DataType dtype_; + // Tracks all Fusions that own this Val. Seeded at registration. + std::vector owning_fusions_; + // Following is managed by Fusion and can change. bool is_fusion_input_ = false; bool is_fusion_output_ = false; diff --git a/csrc/ir/cloner.cpp b/csrc/ir/cloner.cpp index c71c04f082c..f2ca85dfad9 100644 --- a/csrc/ir/cloner.cpp +++ b/csrc/ir/cloner.cpp @@ -24,21 +24,48 @@ Statement* IrCloner::clone(const Statement* statement) { return nullptr; } - // Have we already cloned this node? + // Step 1: Cache check — already cloned or reused const auto it = clones_map_.find(statement); if (it != clones_map_.end()) { return it->second; - } else { - auto new_node = handle(statement); - - // The base cloning constructor (Statement) should have - // registered the new node. Failure to do so indicates - // that something went horribly wrong. - NVF_ERROR(new_node != nullptr); - NVF_ERROR(clones_map_[statement] == new_node); + } - return new_node; + // Step 2: Scalar reuse — share instead of clone when src and dest Fusions + // use the same IrContainer. Only symbolic leaf scalars qualify (no + // definition, no concrete value). Constants and special vals are excluded + // by the !hasValue() check. 
+ if (statement->isVal()) { + const Val* val = statement->as(); + if (val->isScalar() && val->definition() == nullptr && + !val->value().hasValue()) { + Fusion* src_fusion = val->container(); + if (src_fusion != nullptr && src_fusion != ir_container_ && + src_fusion->ir_container() == ir_container_->ir_container()) { + Val* reused = const_cast(val); + + // (a) Cache so downstream Expr clones resolve this input + clones_map_[statement] = reused; + + // (b) Register with dest Fusion's per-Fusion tracking + auto* c = ir_container_->ir_container(); + { + std::unique_lock lock(c->mutex_); + c->per_fusion_vals_[ir_container_].insert(reused); + + // (c) Track ownership under the same lock as the registration + reused->addOwningFusion(ir_container_); + } + + return reused; + } + } } + + // Step 3: Full clone. handle() must have registered the node in clones_map_. + auto new_node = handle(statement); + NVF_ERROR(new_node != nullptr); + NVF_ERROR(clones_map_[statement] == new_node); + return new_node; } void IrCloner::registerClone(const Statement* src, Statement* clone) { diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index d33a8af4eef..b40999b412a 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -153,7 +153,8 @@ int64_t IrContainer::numVals() const noexcept { void IrContainer::addFusion(Fusion* fusion) { std::unique_lock lock(mutex_); sharing_fusions_.insert(fusion); - per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs during concurrent val/expr registration + per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs + // during concurrent val/expr registration per_fusion_exprs_[fusion]; } @@ -225,8 +226,13 @@ void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) { const auto& owned = vals_it->second; std::erase_if(vals_up_, [&](const std::unique_ptr& v) { if (owned.count(v.get()) > 0) { + // Multi-owner guard: only free if this is the last owning Fusion. 
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + if (!v->removeOwningFusion(const_cast(fusion))) { + return false; // other Fusions still own this Val — keep alive + } vals_.erase(v.get()); - return true; + return true; // last owner gone → Val freed } return false; }); diff --git a/csrc/ir/container.h b/csrc/ir/container.h index a9555ae3305..c8824f8c954 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -81,6 +81,7 @@ class IrContainer { protected: // Let Fusion access IrContainer internals (mutex_, fields, Impl helpers) friend class Fusion; + friend class IrCloner; mutable std::shared_mutex mutex_; diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 22484f1b859..cb630ad6122 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -157,7 +157,13 @@ void IterVisitor::traverseBetween( if (to.empty()) { return; } - Fusion* fusion = to.front()->fusion(); + // Use the active FusionGuard rather than deriving the Fusion from a val. + // This avoids calling fusion() on shared scalars whose ir_container_ + // points to the original creator Fusion, not the current traversal target. + Fusion* fusion = FusionGuard::getCurFusion(); + if (fusion == nullptr) { + fusion = to.front()->fusion(); + } FusionGuard fg(fusion); std::unordered_set visited; @@ -468,7 +474,10 @@ void BackwardVisitor::traverseTo( if (from.empty()) { return; } - Fusion* fusion = from.front()->fusion(); + Fusion* fusion = FusionGuard::getCurFusion(); + if (fusion == nullptr) { + fusion = from.front()->fusion(); + } FusionGuard fg(fusion); // Reset members