Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/common/base/src/base/barrier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ impl Barrier {
BarrierWaitResult(is_leader)
}

pub fn reduce_quorum(&self, n: usize) {
/// Reduces the quorum by `n`. Returns `true` if the quorum reached zero
/// purely through reductions (no thread called `wait()`), meaning all
/// participants were removed without ever synchronizing.
pub fn reduce_quorum(&self, n: usize) -> bool {
let locked = self.state.lock();
let mut state = locked.unwrap_or_else(PoisonError::into_inner);
state.n -= n;
Expand All @@ -95,9 +98,12 @@ impl Barrier {
.waker
.send(state.generation)
.expect("there is at least one receiver");
let all_reduced = state.arrived == 0;
state.arrived = 0;
state.generation += 1;
return all_reduced;
}
false
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ impl RuntimeFiltersDesc {
}))
}

/// Close the broadcast source channel and notify runtime filter watchers.
/// Called when all threads of a hash join are short-circuited (e.g., downstream
/// LIMIT satisfied via sequential UNION ALL) and no thread will call `globalization`.
pub fn close_broadcast(&self) {
if let Some(broadcast_id) = self.broadcast_id {
self.ctx.broadcast_source_sender(broadcast_id).close();
}
for ready in &self.runtime_filters_ready {
let _ = ready.runtime_filter_watcher.send(None);
}
}

pub async fn globalization(&self, mut packet: JoinRuntimeFilterPacket) -> Result<()> {
if let Some(broadcast_id) = self.broadcast_id {
packet = get_global_runtime_filter_packet(broadcast_id, packet, &self.ctx).await?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ impl Processor for TransformHashJoin {
std::mem::swap(&mut finished, &mut self.join);
drop(finished);

self.stage_sync_barrier.reduce_quorum(1);
if self.stage_sync_barrier.reduce_quorum(1) {
self.rf_desc.close_broadcast();
}
}

return Ok(Event::Finished);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
statement ok
drop table if exists t_build

statement ok
drop table if exists t_probe

statement ok
create table t_build(a int not null, b string not null)

statement ok
create table t_probe(a int not null, c string not null)

statement ok
insert into t_build select number, concat('build_', to_string(number)) from numbers(1000)

statement ok
insert into t_probe select number, concat('probe_', to_string(number)) from numbers(1000)

statement ok
set enforce_shuffle_join = 1

statement ok
set enable_parallel_union_all = 0

query I
select count(*) from (
select a, b from (
select t_probe.a, t_build.b from t_probe inner join t_build on t_probe.a = t_build.a
union all
select t_probe.a, t_build.b from t_probe inner join t_build on t_probe.a = t_build.a
union all
select t_probe.a, t_build.b from t_probe inner join t_build on t_probe.a = t_build.a
union all
select t_probe.a, t_build.b from t_probe inner join t_build on t_probe.a = t_build.a
) t
limit 5
)
----
5

statement ok
unset enforce_shuffle_join

statement ok
unset enable_parallel_union_all

statement ok
drop table if exists t_build

statement ok
drop table if exists t_probe
Loading