From ce18b425dd444d6484b45f5a9cbdd5242091e3e7 Mon Sep 17 00:00:00 2001
From: Maryam Shahid
Date: Mon, 30 Mar 2026 14:47:35 -0700
Subject: [PATCH] early query cancellation based on per-segment sampling

---
 .../query/ChainedExecutionQueryRunner.java    |  56 +++++-
 .../org/apache/druid/query/QueryContext.java  |   9 +
 .../org/apache/druid/query/QueryContexts.java |   1 +
 ...ExecutionQueryRunnerExtrapolationTest.java | 181 ++++++++++++++++++
 4 files changed, 246 insertions(+), 1 deletion(-)
 create mode 100644 processing/src/test/java/org/apache/druid/query/ChainedExecutionQueryRunnerExtrapolationTest.java

diff --git a/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java b/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java
index 74f2ffc634a3..a4587b1d68ca 100644
--- a/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java
+++ b/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java
@@ -28,6 +28,7 @@
 import org.apache.druid.common.guava.GuavaUtils;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.concurrent.Execs;
 import org.apache.druid.java.util.common.guava.BaseSequence;
 import org.apache.druid.java.util.common.guava.MergeIterable;
 import org.apache.druid.java.util.common.guava.Sequence;
@@ -40,6 +41,8 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 
 /**
  * A QueryRunner that combines a list of other QueryRunners and executes them in parallel on an executor.
@@ -85,6 +88,13 @@ public Sequence<T> run(final QueryPlus<T> queryPlus, final ResponseContext respo
     final QueryContext context = query.context();
     final boolean usePerSegmentTimeout = context.usePerSegmentTimeout();
     final long perSegmentTimeout = context.getPerSegmentTimeout();
+    final int samplingWindow = context.getPerSegmentSamplingWindow();
+    final long queryStartNanos = System.nanoTime();
+
+    // Shared state for extrapolation — only allocated when sampling is enabled
+    final AtomicInteger completedSegments = samplingWindow > 0 ? new AtomicInteger(0) : null;
+    final AtomicBoolean extrapolationCancelled = samplingWindow > 0 ? new AtomicBoolean(false) : null;
+
     return new BaseSequence<>(
         new BaseSequence.IteratorMaker<>()
         {
@@ -120,6 +130,10 @@ public Iterable<T> call()
                                       throw new ISE("Got a null list of results");
                                     }
 
+                                    if (completedSegments != null) {
+                                      completedSegments.incrementAndGet();
+                                    }
+
                                     return retVal;
                                   }
                                   catch (QueryInterruptedException e) {
@@ -153,8 +167,35 @@ public Iterable<T> call()
                     )
                 );
 
+            final int totalSegments = futures.size();
             ListenableFuture<List<Iterable<T>>> future = Futures.allAsList(futures);
             queryWatcher.registerQueryFuture(query, future);
+
+            // Extrapolate only when sampling is enabled, the query has at least samplingWindow segments,
+            // and a timeout is set to compare the projected wall-clock time against.
+            if (completedSegments != null && totalSegments >= samplingWindow && context.hasTimeout()) {
+              for (ListenableFuture<?> f : futures) {
+                f.addListener(
+                    () -> {
+                      if (extrapolationCancelled.get()) {
+                        return;
+                      }
+                      int completed = completedSegments.get();
+                      if (completed >= samplingWindow) {
+                        long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - queryStartNanos);
+                        long extrapolatedMs = elapsedMs * totalSegments / completed;
+                        // compareAndSet makes the check-and-mark step atomic, so only one listener triggers
+                        // the cancellation when several futures complete concurrently.
+                        if (extrapolatedMs > context.getTimeout()
+                            && extrapolationCancelled.compareAndSet(false, true)) {
+                          GuavaUtils.cancelAll(true, future, futures);
+                        }
+                      }
+                    },
+                    Execs.directExecutor()
+                );
+              }
+            }
 
             try {
               return new MergeIterable<>(
@@ -165,8 +206,21 @@ public Iterable<T> call()
               ).iterator();
             }
             catch (CancellationException | InterruptedException e) {
-              log.noStackTrace().warn(e, "Query interrupted, cancelling pending results for query [%s]", query.getId());
               GuavaUtils.cancelAll(true, future, futures);
+              if (extrapolationCancelled != null && extrapolationCancelled.get()) {
+                int completed = completedSegments.get();
+                long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - queryStartNanos);
+                throw new QueryTimeoutException(
+                    StringUtils.nonStrictFormat(
+                        "Query [%s] cancelled: extrapolated wall-clock time exceeds timeout after %d of %d segments completed in %d ms.",
+                        query.getId(),
+                        completed,
+                        totalSegments,
+                        elapsedMs
+                    )
+                );
+              }
+              log.noStackTrace().warn(e, "Query interrupted, cancelling pending results for query [%s]", query.getId());
               throw new QueryInterruptedException(e);
             }
             catch (TimeoutException | QueryTimeoutException e) {
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java b/processing/src/main/java/org/apache/druid/query/QueryContext.java
index 42ed66978ffa..c190eac07705 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContext.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java
@@ -570,6 +570,15 @@ public boolean usePerSegmentTimeout()
     return getPerSegmentTimeout() != QueryContexts.NO_TIMEOUT;
   }
 
+  /**
+   * Returns the {@link QueryContexts#PER_SEGMENT_SAMPLING_WINDOW_KEY} context value, or 0 when it is not set.
+   * {@link ChainedExecutionQueryRunner} only enables extrapolation-based early cancellation for values above 0.
+   */
+  public int getPerSegmentSamplingWindow()
+  {
+    return getInt(QueryContexts.PER_SEGMENT_SAMPLING_WINDOW_KEY, 0);
+  }
+
   public void verifyMaxScatterGatherBytes(long maxScatterGatherBytesLimit)
   {
     long curr = getLong(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, 0);
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
index 6df392d95714..7115123b238d 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
@@ -45,6 +45,7 @@ public class QueryContexts
   public static final String LANE_KEY = "lane";
   public static final String TIMEOUT_KEY = "timeout";
   public static final String PER_SEGMENT_TIMEOUT_KEY = "perSegmentTimeout";
+  public static final String PER_SEGMENT_SAMPLING_WINDOW_KEY = "perSegmentSamplingWindow";
   public static final String MAX_SCATTER_GATHER_BYTES_KEY = "maxScatterGatherBytes";
   public static final String MAX_QUEUED_BYTES_KEY = "maxQueuedBytes";
   public static final String DEFAULT_TIMEOUT_KEY = "defaultTimeout";
diff --git a/processing/src/test/java/org/apache/druid/query/ChainedExecutionQueryRunnerExtrapolationTest.java b/processing/src/test/java/org/apache/druid/query/ChainedExecutionQueryRunnerExtrapolationTest.java
new file mode 100644
index 000000000000..41fd0ac9b07d
--- /dev/null
+++ b/processing/src/test/java/org/apache/druid/query/ChainedExecutionQueryRunnerExtrapolationTest.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query;
+
+import com.google.common.base.Throwables;
+import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.query.aggregation.CountAggregatorFactory;
+import org.apache.druid.query.timeseries.TimeseriesQuery;
+import org.easymock.EasyMock;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Tests the extrapolation-based early cancellation performed by {@link ChainedExecutionQueryRunner}.
+ */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class ChainedExecutionQueryRunnerExtrapolationTest
+{
+  private QueryProcessingPool processingPool;
+
+  @Before
+  public void setup()
+  {
+    processingPool = new ForwardingQueryProcessingPool(
+        Execs.multiThreaded(2, "ExtrapolationTestExecutor-%d"),
+        Execs.scheduledSingleThreaded("ExtrapolationTestExecutor-Timeout-%d")
+    );
+  }
+
+  @After
+  public void tearDown()
+  {
+    processingPool.shutdown();
+  }
+
+  @Test(timeout = 10_000L)
+  public void testExtrapolation_disabled_whenSamplingWindowZero()
+  {
+    QueryRunner slowRunner = sleepRunner(200);
+    QueryRunner fastRunner = (queryPlus, responseContext) -> Sequences.of(1);
+
+    ChainedExecutionQueryRunner runner = makeRunner(slowRunner, fastRunner);
+    TimeseriesQuery query = makeQuery(Map.of(
+        QueryContexts.TIMEOUT_KEY, 10_000L
+    ));
+
+    List results = runner.run(QueryPlus.wrap(query)).toList();
+    Assert.assertNotNull(results);
+    Assert.assertEquals(2, results.size());
+  }
+
+  @Test(timeout = 10_000L)
+  public void testExtrapolation_cancelsQuery_whenProjectedTimeExceedsTimeout()
+  {
+    // 5 slow runners (300ms each on 2 threads), sampling window=2, timeout=500ms
+    // After 2 complete: elapsed ~300ms, extrapolated wall-clock = 300 * 5/2 = 750ms > 500ms → cancel
+    QueryRunner slowRunner = sleepRunner(300);
+
+    ChainedExecutionQueryRunner runner = makeRunner(
+        slowRunner,
+        slowRunner,
+        slowRunner,
+        slowRunner,
+        slowRunner
+    );
+    TimeseriesQuery query = makeQuery(Map.of(
+        QueryContexts.TIMEOUT_KEY, 500L,
+        QueryContexts.PER_SEGMENT_SAMPLING_WINDOW_KEY, 2
+    ));
+
+    Exception thrown = null;
+    try {
+      runner.run(QueryPlus.wrap(query)).toList();
+    }
+    catch (Exception e) {
+      thrown = e;
+    }
+
+    Assert.assertNotNull("Expected exception from extrapolation", thrown);
+    // Check type and message on the same object, so a wrapping exception cannot hide the extrapolation message.
+    final Throwable rootCause = Throwables.getRootCause(thrown);
+    Assert.assertTrue("Should be QueryTimeoutException", rootCause instanceof QueryTimeoutException);
+    Assert.assertTrue("Message should mention extrapolation", rootCause.getMessage().contains("extrapolated"));
+  }
+
+  @Test(timeout = 10_000L)
+  public void testExtrapolation_doesNotCancel_whenProjectedTimeWithinTimeout()
+  {
+    QueryRunner fastRunner = (queryPlus, responseContext) -> Sequences.of(1);
+
+    ChainedExecutionQueryRunner runner = makeRunner(
+        fastRunner,
+        fastRunner,
+        fastRunner
+    );
+    TimeseriesQuery query = makeQuery(Map.of(
+        QueryContexts.TIMEOUT_KEY, 30_000L,
+        QueryContexts.PER_SEGMENT_SAMPLING_WINDOW_KEY, 2
+    ));
+
+    List results = runner.run(QueryPlus.wrap(query)).toList();
+    Assert.assertNotNull(results);
+    Assert.assertEquals(3, results.size());
+  }
+
+  @Test(timeout = 10_000L)
+  public void testExtrapolation_skipped_whenFewerSegmentsThanSamplingWindow()
+  {
+    QueryRunner slowRunner = sleepRunner(200);
+
+    ChainedExecutionQueryRunner runner = makeRunner(slowRunner, slowRunner);
+    TimeseriesQuery query = makeQuery(Map.of(
+        QueryContexts.TIMEOUT_KEY, 10_000L,
+        QueryContexts.PER_SEGMENT_SAMPLING_WINDOW_KEY, 5
+    ));
+
+    List results = runner.run(QueryPlus.wrap(query)).toList();
+    Assert.assertNotNull(results);
+    Assert.assertEquals(2, results.size());
+  }
+
+  private QueryRunner sleepRunner(long sleepMs)
+  {
+    return (queryPlus, responseContext) -> {
+      try {
+        Thread.sleep(sleepMs);
+        return Sequences.of(1);
+      }
+      catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        throw new RuntimeException(e);
+      }
+    };
+  }
+
+  private ChainedExecutionQueryRunner makeRunner(QueryRunner... runners)
+  {
+    QueryWatcher watcher = EasyMock.createNiceMock(QueryWatcher.class);
+    EasyMock.replay(watcher);
+    return new ChainedExecutionQueryRunner<>(
+        processingPool,
+        watcher,
+        Arrays.asList(runners)
+    );
+  }
+
+  private TimeseriesQuery makeQuery(Map context)
+  {
+    return Druids.newTimeseriesQueryBuilder()
+                 .dataSource("test")
+                 .intervals("2014/2015")
+                 .aggregators(List.of(new CountAggregatorFactory("count")))
+                 .context(context)
+                 .queryId("test")
+                 .build();
+  }
+}