Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ struct Fusion::ContainerMutator {
c->vals_.insert(val);
c->per_fusion_vals_[self].insert(val);
val->setName(IrContainerPasskey(), self->getValName(val->vtype()));

// Seed owning_fusions_ with the registering Fusion (original creator).
if (val->owning_fusions_.empty()) {
val->owning_fusions_.push_back(self);
}
}

static void registerExpr(Fusion* self, Expr* expr) {
Expand Down Expand Up @@ -212,7 +217,8 @@ struct Fusion::ContainerMutator {
Expr* e = c->exprs_up_.back().get();
NVF_ERROR(
c->per_fusion_exprs_[self].count(e) > 0,
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
"removeStatementsCreatedAfter: tail expr belongs to another "
"Fusion");
for (Val* out : e->outputs()) {
out->setDefinition(nullptr);
}
Expand Down Expand Up @@ -457,10 +463,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
}

// Wire up definitions and uses on cloned vals in deterministic order
// to ensure exprs are inserted into exprs_up_ deterministically
// to ensure exprs are inserted into exprs_up_ deterministically.
// Skip reused vals (shared scalars) — their definition/uses belong to
// the source Fusion and must not be overwritten.
for (auto val : from->deterministic_vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
auto* cloned = ir_cloner.clone(val);
if (cloned == val) {
continue; // reused (shared scalar) — don't rewire
}
cloned->setDefinition(ir_cloner.clone(val->definition_));
cloned->setUses(ir_cloner.clone(val->uses_));
}

// Sync per-Fusion name counters from source to dest.
Expand Down
1 change: 1 addition & 0 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ class NVF_API Fusion : public PolymorphicBase {
friend SegmentCandidateFinder;
friend SegmentedFusion;
friend class TranslateApplicableWelford;
friend class IrCloner;
friend Val;

//! Constructor that shares an existing container. Creates an empty Fusion
Expand Down
4 changes: 2 additions & 2 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1801,8 +1801,8 @@ std::pair<IrCloner, std::unique_ptr<Fusion>> SegmentedFusion::makeFusion(
SegmentedGroup* sg) const {
// TODO Optimize cloning step by only copying values and expressions between
// the fusion segment's inputs and outputs.
auto fusion_segment = std::unique_ptr<Fusion>(
new Fusion(completeFusion()->ir_container_ptr()));
auto fusion_segment =
std::unique_ptr<Fusion>(new Fusion(completeFusion()->ir_container_ptr()));

IrCloner complete_to_segment_map =
Fusion::copy(completeFusion(), fusion_segment.get());
Expand Down
19 changes: 19 additions & 0 deletions csrc/ir/base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,25 @@ kir::Kernel* Statement::kernel() const {

NVFUSER_DEFINE_CLONE(Val)

bool Val::isOwnedBy(const Fusion* f) const {
return std::find(owning_fusions_.begin(), owning_fusions_.end(), f) !=
owning_fusions_.end();
}

void Val::addOwningFusion(Fusion* f) {
if (!isOwnedBy(f)) {
owning_fusions_.push_back(f);
}
}

bool Val::removeOwningFusion(Fusion* f) {
auto it = std::find(owning_fusions_.begin(), owning_fusions_.end(), f);
if (it != owning_fusions_.end()) {
owning_fusions_.erase(it);
}
return owning_fusions_.empty();
}

void Val::addDependency(Val* dependency) {
NVF_ERROR(dependency != nullptr);

Expand Down
21 changes: 21 additions & 0 deletions csrc/ir/base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,24 @@ class NVF_API Val : public Statement {
definition_ = expr;
}

// Multi-owner tracking for Phase 3 scalar sharing.
// owning_fusions_[0] = original creator (by convention, set at registration).
// Grows to 2+ only for shared scalars (IrCloner reuse path).
bool isShared() const {
return owning_fusions_.size() > 1;
}

bool isOwnedBy(const Fusion* f) const;

void addOwningFusion(Fusion* f);

// Remove an owning Fusion. Returns true if this was the last owner.
bool removeOwningFusion(Fusion* f);

const std::vector<Fusion*>& owningFusions() const {
return owning_fusions_;
}

NVFUSER_DECLARE_CLONE

protected:
Expand Down Expand Up @@ -454,6 +472,9 @@ class NVF_API Val : public Statement {
// welford operations.
DataType dtype_;

// Tracks all Fusions that own this Val. Seeded at registration.
std::vector<Fusion*> owning_fusions_;

// Following is managed by Fusion and can change.
bool is_fusion_input_ = false;
bool is_fusion_output_ = false;
Expand Down
47 changes: 37 additions & 10 deletions csrc/ir/cloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,48 @@ Statement* IrCloner::clone(const Statement* statement) {
return nullptr;
}

// Have we already cloned this node?
// Step 1: Cache check — already cloned or reused
const auto it = clones_map_.find(statement);
if (it != clones_map_.end()) {
return it->second;
} else {
auto new_node = handle(statement);

// The base cloning constructor (Statement) should have
// registered the new node. Failure to do so indicates
// that something went horribly wrong.
NVF_ERROR(new_node != nullptr);
NVF_ERROR(clones_map_[statement] == new_node);
}

return new_node;
// Step 2: Scalar reuse — share instead of clone when src and dest Fusions
// use the same IrContainer. Only symbolic leaf scalars qualify (no
// definition, no concrete value). Constants and special vals are excluded
// by the !hasValue() check.
if (statement->isVal()) {
const Val* val = statement->as<Val>();
if (val->isScalar() && val->definition() == nullptr &&
!val->value().hasValue()) {
Fusion* src_fusion = val->container();
if (src_fusion != nullptr && src_fusion != ir_container_ &&
src_fusion->ir_container() == ir_container_->ir_container()) {
Val* reused = const_cast<Val*>(val);

// (a) Cache so downstream Expr clones resolve this input
clones_map_[statement] = reused;

// (b) Register with dest Fusion's per-Fusion tracking
auto* c = ir_container_->ir_container();
{
std::unique_lock lock(c->mutex_);
c->per_fusion_vals_[ir_container_].insert(reused);
}

// (c) Track ownership for lifetime management
reused->addOwningFusion(ir_container_);

return reused;
}
}
}

// Step 3: Full clone (unchanged)
auto new_node = handle(statement);
NVF_ERROR(new_node != nullptr);
NVF_ERROR(clones_map_[statement] == new_node);
return new_node;
}

void IrCloner::registerClone(const Statement* src, Statement* clone) {
Expand Down
10 changes: 8 additions & 2 deletions csrc/ir/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ int64_t IrContainer::numVals() const noexcept {
void IrContainer::addFusion(Fusion* fusion) {
std::unique_lock lock(mutex_);
sharing_fusions_.insert(fusion);
per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs during concurrent val/expr registration
per_fusion_vals_[fusion]; // Pre-insert key so no outer-map rehash occurs
// during concurrent val/expr registration
per_fusion_exprs_[fusion];
}

Expand Down Expand Up @@ -225,8 +226,13 @@ void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) {
const auto& owned = vals_it->second;
std::erase_if(vals_up_, [&](const std::unique_ptr<Val>& v) {
if (owned.count(v.get()) > 0) {
// Multi-owner guard: only free if this is the last owning Fusion.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
if (!v->removeOwningFusion(const_cast<Fusion*>(fusion))) {
return false; // other Fusions still own this Val — keep alive
}
vals_.erase(v.get());
return true;
return true; // last owner gone → Val freed
}
return false;
});
Expand Down
1 change: 1 addition & 0 deletions csrc/ir/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class IrContainer {
protected:
// Let Fusion access IrContainer internals (mutex_, fields, Impl helpers)
friend class Fusion;
friend class IrCloner;

mutable std::shared_mutex mutex_;

Expand Down
13 changes: 11 additions & 2 deletions csrc/iter_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,13 @@ void IterVisitor::traverseBetween(
if (to.empty()) {
return;
}
Fusion* fusion = to.front()->fusion();
// Use the active FusionGuard rather than deriving the Fusion from a val.
// This avoids calling fusion() on shared scalars whose ir_container_
// points to the original creator Fusion, not the current traversal target.
Fusion* fusion = FusionGuard::getCurFusion();
if (fusion == nullptr) {
fusion = to.front()->fusion();
}
FusionGuard fg(fusion);

std::unordered_set<Statement*> visited;
Expand Down Expand Up @@ -468,7 +474,10 @@ void BackwardVisitor::traverseTo(
if (from.empty()) {
return;
}
Fusion* fusion = from.front()->fusion();
Fusion* fusion = FusionGuard::getCurFusion();
if (fusion == nullptr) {
fusion = from.front()->fusion();
}
FusionGuard fg(fusion);

// Reset members
Expand Down
Loading