From 024a8cad4173d1c20dde863e9827376605c7f966 Mon Sep 17 00:00:00 2001 From: Sunil Ramchandra Pawar Date: Fri, 27 Feb 2026 17:48:35 +0530 Subject: [PATCH 1/2] Add query cancellation support via _tasks/_cancel API for PPL queries Signed-off-by: Sunil Ramchandra Pawar --- .../executor/OpenSearchQueryManager.java | 36 +++++++++- .../request/PPLQueryRequestFactory.java | 5 ++ .../sql/plugin/transport/SQLQueryTask.java | 44 +++++++++++++ .../transport/TransportPPLQueryAction.java | 20 ++++++ .../transport/TransportPPLQueryRequest.java | 29 ++++++++ .../plugin/transport/SQLQueryTaskTest.java | 66 +++++++++++++++++++ .../sql/ppl/domain/PPLQueryRequest.java | 5 ++ 7 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/sql/plugin/transport/SQLQueryTask.java create mode 100644 plugin/src/test/java/org/opensearch/sql/plugin/transport/SQLQueryTaskTest.java diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java index dacb7f97eab..c9793b9d2d0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.ThreadContext; +import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchTimeoutException; import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.common.setting.Settings; @@ -33,15 +34,33 @@ public class OpenSearchQueryManager implements QueryManager { public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker"; public static final String SQL_BACKGROUND_THREAD_POOL_NAME = "sql_background_io"; + public interface CancellationCallBack { + void onExecutionThreadAvailable(Thread thread); + void onExecutionComplete(); + boolean isCancelled(); + } + + public static ThreadLocal cancellationCallBackThreadLocal = new ThreadLocal<>(); + + public static void setCancellationCallback(CancellationCallBack value) { + cancellationCallBackThreadLocal.set(value); + } + + public static void clearCancellationCallback() { + cancellationCallBackThreadLocal.remove(); + } + @Override public QueryId submit(AbstractPlan queryPlan) { TimeValue timeout = settings.getSettingValue(Settings.Key.PPL_QUERY_TIMEOUT); - schedule(nodeClient, queryPlan::execute, timeout); + CancellationCallBack callBack = cancellationCallBackThreadLocal.get(); + cancellationCallBackThreadLocal.remove(); + schedule(nodeClient, queryPlan::execute, timeout, callBack); return queryPlan.getQueryId(); } - private void schedule(NodeClient client, Runnable task, TimeValue timeout) { + private void schedule(NodeClient client, Runnable task, TimeValue timeout, CancellationCallBack callBack) { ThreadPool threadPool = client.threadPool(); Runnable wrappedTask = @@ -49,6 +68,10 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout) { () -> { final Thread executionThread = Thread.currentThread(); + if (callBack != null) { + callBack.onExecutionThreadAvailable(executionThread); + } + Scheduler.ScheduledCancellable timeoutTask = threadPool.schedule( () -> { @@ -70,6 +93,10 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout) { // Special-case handling of timeout-related interruptions if (Thread.interrupted() || e.getCause() instanceof InterruptedException) { + if (callBack != null && callBack.isCancelled()) { + LOG.info("Query was cancelled"); + throw new OpenSearchException("Query was cancelled."); + } LOG.error("Query was interrupted due to timeout after {}", timeout); throw new OpenSearchTimeoutException( "Query execution timed out after " + timeout); @@ -77,6 +104,11 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout) { throw e; } + finally { + if (callBack != null) { + callBack.onExecutionComplete(); + } + } }); threadPool.schedule(wrappedTask, new TimeValue(0), SQL_WORKER_THREAD_POOL_NAME); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java b/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java index 0d07dab966a..44a032ceb9e 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java @@ -113,6 +113,11 @@ private static PPLQueryRequest parsePPLRequestFromPayload(RestRequest restReques if (pretty) { pplRequest.style(JsonResponseFormatter.Style.PRETTY); } + // set queryId + String queryId = jsonContent.optString("queryId", null); + if (queryId != null) { + pplRequest.queryId(queryId); + } return pplRequest; } catch (JSONException e) { throw new IllegalArgumentException("Failed to parse request payload", e); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/SQLQueryTask.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/SQLQueryTask.java new file mode 100644 index 00000000000..aaa1c7d1990 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/SQLQueryTask.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.transport; + + +import org.opensearch.core.tasks.TaskId; +import org.opensearch.tasks.CancellableTask; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + + +public class SQLQueryTask extends CancellableTask { + + private final AtomicReference executionThread = new AtomicReference<>(); + + public SQLQueryTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) { + super(id, type, action, description, parentTaskId, headers); + } + + @Override + public boolean shouldCancelChildrenOnCancellation() { + return false; + } + + public void setExecutionThread(Thread thread) { + executionThread.set(thread); + } + + public void clearExecutionThread() { + executionThread.set(null); + } + + @Override + public void onCancelled() { + Thread thread = executionThread.get(); + if (thread != null) { + thread.interrupt(); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java index 48bc36374a8..98278787011 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java @@ -31,6 +31,7 @@ import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.monitor.profile.QueryProfiling; +import org.opensearch.sql.opensearch.executor.OpenSearchQueryManager; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.plugin.config.OpenSearchPluginModule; import org.opensearch.sql.ppl.PPLService; @@ -109,6 +110,25 @@ protected void doExecute( return; } + if (task instanceof SQLQueryTask sqlQueryTask) { + + OpenSearchQueryManager.setCancellationCallback(new OpenSearchQueryManager.CancellationCallBack() { + @Override + public void onExecutionThreadAvailable(Thread thread) { + sqlQueryTask.setExecutionThread(thread); + } + + @Override + public void onExecutionComplete() { + sqlQueryTask.clearExecutionThread(); + } + + @Override + public boolean isCancelled() { + return sqlQueryTask.isCancelled(); + } + }); + } Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_TOTAL).increment(); Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_COUNT_TOTAL).increment(); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java index 6db2bd249ae..4615eaa580f 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java @@ -9,6 +9,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.Locale; +import java.util.Map; import java.util.Optional; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -21,9 +22,11 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.tasks.TaskId; import org.opensearch.sql.ppl.domain.PPLQueryRequest; import org.opensearch.sql.protocol.response.format.Format; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; +import org.opensearch.tasks.Task; @RequiredArgsConstructor public class TransportPPLQueryRequest extends ActionRequest { @@ -51,6 +54,11 @@ public class TransportPPLQueryRequest extends ActionRequest { @Accessors(fluent = true) private boolean profile = false; + @Setter + @Getter + @Accessors(fluent = true) + private String queryId = null; + /** Constructor of TransportPPLQueryRequest from PPLQueryRequest. */ public TransportPPLQueryRequest(PPLQueryRequest pplQueryRequest) { pplQuery = pplQueryRequest.getRequest(); @@ -61,6 +69,7 @@ public TransportPPLQueryRequest(PPLQueryRequest pplQueryRequest) { style = pplQueryRequest.style(); profile = pplQueryRequest.profile(); explainMode = pplQueryRequest.mode().getModeName(); + queryId = pplQueryRequest.queryId(); } /** Constructor of TransportPPLQueryRequest from StreamInput. */ @@ -75,6 +84,7 @@ public TransportPPLQueryRequest(StreamInput in) throws IOException { sanitize = in.readBoolean(); style = in.readEnum(JsonResponseFormatter.Style.class); profile = in.readBoolean(); + queryId = in.readOptionalString(); } /** Re-create the object from the actionRequest. */ @@ -107,6 +117,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(sanitize); out.writeEnum(style); out.writeBoolean(profile); + out.writeOptionalString(queryId); } public String getRequest() { @@ -147,12 +158,30 @@ public ActionRequestValidationException validate() { return null; } + @Override + public SQLQueryTask createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new SQLQueryTask(id, type, action, getDescription() , parentTaskId, headers); + } + + @Override + public String getDescription() + { + String prefix = (queryId != null) ? "PPL [queryId=" + queryId + "]: " : "PPL: "; + + if (pplQuery != null && pplQuery.length() > 512) { + return prefix + pplQuery.substring(0,512) + "..."; + } + + return prefix + pplQuery; + } + /** Convert to PPLQueryRequest. */ public PPLQueryRequest toPPLQueryRequest() { PPLQueryRequest pplQueryRequest = new PPLQueryRequest(pplQuery, jsonContent, path, format, explainMode, profile); pplQueryRequest.sanitize(sanitize); pplQueryRequest.style(style); + pplQueryRequest.queryId(queryId); return pplQueryRequest; } } diff --git a/plugin/src/test/java/org/opensearch/sql/plugin/transport/SQLQueryTaskTest.java b/plugin/src/test/java/org/opensearch/sql/plugin/transport/SQLQueryTaskTest.java new file mode 100644 index 00000000000..07a22bc7a4d --- /dev/null +++ b/plugin/src/test/java/org/opensearch/sql/plugin/transport/SQLQueryTaskTest.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.transport; + +import org.junit.Test; +import org.opensearch.core.tasks.TaskId; + +import java.util.Map; + +import static org.junit.Assert.*; + +public class SQLQueryTaskTest { + + @Test + public void testOnCancelledInterruptsThread() { + SQLQueryTask sqlQueryTask = new SQLQueryTask(1, "transport", "cluster:admin/opensearch/ppl", "test query", TaskId.EMPTY_TASK_ID, Map.of()); + sqlQueryTask.setExecutionThread(Thread.currentThread()); + sqlQueryTask.cancel("testing"); + assertTrue(Thread.currentThread().isInterrupted()); + Thread.interrupted(); + } + + @Test + public void testOnCancelledWithNoThread() { + SQLQueryTask sqlQueryTask = new SQLQueryTask(1, "transport", "cluster:admin/opensearch/ppl", "test query", TaskId.EMPTY_TASK_ID, Map.of()); + sqlQueryTask.cancel("testing"); + } + + @Test + public void testClearExecutionThread() { + SQLQueryTask sqlQueryTask = new SQLQueryTask(1, "transport", "cluster:admin/opensearch/ppl", "test query", TaskId.EMPTY_TASK_ID, Map.of()); + sqlQueryTask.setExecutionThread(Thread.currentThread()); + sqlQueryTask.clearExecutionThread(); + sqlQueryTask.cancel("testing"); + assertFalse(Thread.currentThread().isInterrupted()); + } + + @Test + public void testShouldCancelChildrenReturnsFalse() { + SQLQueryTask sqlQueryTask = new SQLQueryTask(1, "transport", "cluster:admin/opensearch/ppl", "test query", TaskId.EMPTY_TASK_ID, Map.of()); + assertFalse(sqlQueryTask.shouldCancelChildrenOnCancellation()); + } + + @Test + public void testCreateTaskReturnsSQLQueryTask() { + TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + SQLQueryTask task = transportPPLQueryRequest.createTask(1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of()); + assertNotNull(task); + } + + @Test + public void testWithQueryId () { + TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + transportPPLQueryRequest.queryId("test-123"); + assertEquals("PPL [queryId=test-123]: source=t a=1", transportPPLQueryRequest.getDescription()); + } + + @Test + public void testWithoutQueryId () { + TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + assertEquals("PPL: source=t a=1", transportPPLQueryRequest.getDescription()); + } +} diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/domain/PPLQueryRequest.java b/ppl/src/main/java/org/opensearch/sql/ppl/domain/PPLQueryRequest.java index 4201c9cf6ab..06c7fe1c38e 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/domain/PPLQueryRequest.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/domain/PPLQueryRequest.java @@ -52,6 +52,11 @@ public class PPLQueryRequest { @Accessors(fluent = true) private boolean profile = false; + @Setter + @Getter + @Accessors(fluent = true) + private String queryId = null; + public PPLQueryRequest(String pplQuery, JSONObject jsonContent, String path) { this(pplQuery, jsonContent, path, ""); } From b2f0dc0f41b8c8b3d47e5b6ef91586b4d61495ab Mon Sep 17 00:00:00 2001 From: Sunil Ramchandra Pawar Date: Tue, 24 Mar 2026 17:04:21 +0530 Subject: [PATCH 2/2] Refactor PPL query cancellation to cooperative model and other PR suggestions. Signed-off-by: Sunil Ramchandra Pawar --- .../executor/OpenSearchQueryManager.java | 46 +++++-------- .../scan/OpenSearchIndexEnumerator.java | 10 +++ .../request/PPLQueryRequestFactory.java | 2 +- .../sql/plugin/transport/PPLQueryTask.java | 28 ++++++++ .../sql/plugin/transport/SQLQueryTask.java | 44 ------------- .../transport/TransportPPLQueryAction.java | 20 +----- .../transport/TransportPPLQueryRequest.java | 18 ++--- .../plugin/transport/PPLQueryTaskTest.java | 65 ++++++++++++++++++ .../plugin/transport/SQLQueryTaskTest.java | 66 ------------------- 9 files changed, 130 insertions(+), 169 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/sql/plugin/transport/PPLQueryTask.java delete mode 100644 plugin/src/main/java/org/opensearch/sql/plugin/transport/SQLQueryTask.java create mode 100644 plugin/src/test/java/org/opensearch/sql/plugin/transport/PPLQueryTaskTest.java delete mode 100644 plugin/src/test/java/org/opensearch/sql/plugin/transport/SQLQueryTaskTest.java diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java index c9793b9d2d0..7aaaaa6655e 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java @@ -10,13 +10,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.ThreadContext; -import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchTimeoutException; import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.executor.QueryId; import org.opensearch.sql.executor.QueryManager; import org.opensearch.sql.executor.execution.AbstractPlan; +import org.opensearch.tasks.CancellableTask; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.node.NodeClient; @@ -34,33 +34,32 @@ public class OpenSearchQueryManager implements QueryManager { public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker"; public static final String SQL_BACKGROUND_THREAD_POOL_NAME = "sql_background_io"; - public interface CancellationCallBack { - void onExecutionThreadAvailable(Thread thread); - void onExecutionComplete(); - boolean isCancelled(); - } + private static final ThreadLocal cancellableTask = new ThreadLocal<>(); - public static ThreadLocal cancellationCallBackThreadLocal = new ThreadLocal<>(); + public static void setCancellableTask(CancellableTask task) { + cancellableTask.set(task); + } - public static void setCancellationCallback(CancellationCallBack value) { - cancellationCallBackThreadLocal.set(value); + public static CancellableTask getCancellableTask() { + return cancellableTask.get(); } - public static void clearCancellationCallback() { - cancellationCallBackThreadLocal.remove(); + public static void clearCancellableTask() { + cancellableTask.remove(); } @Override public QueryId submit(AbstractPlan queryPlan) { TimeValue timeout = settings.getSettingValue(Settings.Key.PPL_QUERY_TIMEOUT); - CancellationCallBack callBack = cancellationCallBackThreadLocal.get(); - cancellationCallBackThreadLocal.remove(); - schedule(nodeClient, queryPlan::execute, timeout, callBack); + CancellableTask cancelTask = cancellableTask.get(); + cancellableTask.remove(); + schedule(nodeClient, queryPlan::execute, timeout, cancelTask); return queryPlan.getQueryId(); } - private void schedule(NodeClient client, Runnable task, TimeValue timeout, CancellationCallBack callBack) { + private void schedule( + NodeClient client, Runnable task, TimeValue timeout, CancellableTask cancelTask) { ThreadPool threadPool = client.threadPool(); Runnable wrappedTask = @@ -68,10 +67,6 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout, Cance () -> { final Thread executionThread = Thread.currentThread(); - if (callBack != null) { - callBack.onExecutionThreadAvailable(executionThread); - } - Scheduler.ScheduledCancellable timeoutTask = threadPool.schedule( () -> { @@ -83,6 +78,8 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout, Cance timeout, ThreadPool.Names.GENERIC); + setCancellableTask(cancelTask); + try { task.run(); timeoutTask.cancel(); @@ -93,21 +90,14 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout, Cance // Special-case handling of timeout-related interruptions if (Thread.interrupted() || e.getCause() instanceof InterruptedException) { - if (callBack != null && callBack.isCancelled()) { - LOG.info("Query was cancelled"); - throw new OpenSearchException("Query was cancelled."); - } LOG.error("Query was interrupted due to timeout after {}", timeout); throw new OpenSearchTimeoutException( "Query execution timed out after " + timeout); } throw e; - } - finally { - if (callBack != null) { - callBack.onExecutionComplete(); - } + } finally { + clearCancellableTask(); } }); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java index 6af9ad1e8d8..b0315e68eab 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java @@ -11,13 +11,16 @@ import lombok.EqualsAndHashCode; import lombok.ToString; import org.apache.calcite.linq4j.Enumerator; +import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.exception.NonFallbackCalciteException; import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.executor.OpenSearchQueryManager; import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.tasks.CancellableTask; /** * Supports a simple iteration over a collection for OpenSearch index @@ -55,6 +58,8 @@ public class OpenSearchIndexEnumerator implements Enumerator { private ExprValue current = null; + private CancellableTask cancellableTask; + public OpenSearchIndexEnumerator( OpenSearchClient client, List fields, @@ -80,6 +85,7 @@ public OpenSearchIndexEnumerator( this.client = client; this.bgScanner = new BackgroundSearchScanner(client, maxResultWindow, queryBucketSize); this.bgScanner.startScanning(request); + this.cancellableTask = OpenSearchQueryManager.getCancellableTask(); } private Iterator fetchNextBatch() { @@ -112,6 +118,10 @@ public boolean moveNext() { return false; } + if (cancellableTask != null && cancellableTask.isCancelled()) { + throw new TaskCancelledException("The task is cancelled."); + } + boolean shouldCheck = (queryCount % NUMBER_OF_NEXT_CALL_TO_CHECK == 0); if (shouldCheck) { org.opensearch.sql.monitor.ResourceStatus status = this.monitor.getStatus(); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java b/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java index 44a032ceb9e..bb87bf7fa91 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java @@ -116,7 +116,7 @@ private static PPLQueryRequest parsePPLRequestFromPayload(RestRequest restReques // set queryId String queryId = jsonContent.optString("queryId", null); if (queryId != null) { - pplRequest.queryId(queryId); + pplRequest.queryId(queryId); } return pplRequest; } catch (JSONException e) { diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/PPLQueryTask.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/PPLQueryTask.java new file mode 100644 index 00000000000..2df96bdbd12 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/PPLQueryTask.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.transport; + +import java.util.Map; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.tasks.CancellableTask; + +public class PPLQueryTask extends CancellableTask { + + public PPLQueryTask( + long id, + String type, + String action, + String description, + TaskId parentTaskId, + Map headers) { + super(id, type, action, description, parentTaskId, headers); + } + + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } +} diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/SQLQueryTask.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/SQLQueryTask.java deleted file mode 100644 index aaa1c7d1990..00000000000 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/SQLQueryTask.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.plugin.transport; - - -import org.opensearch.core.tasks.TaskId; -import org.opensearch.tasks.CancellableTask; - -import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; - - -public class SQLQueryTask extends CancellableTask { - - private final AtomicReference executionThread = new AtomicReference<>(); - - public SQLQueryTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) { - super(id, type, action, description, parentTaskId, headers); - } - - @Override - public boolean shouldCancelChildrenOnCancellation() { - return false; - } - - public void setExecutionThread(Thread thread) { - executionThread.set(thread); - } - - public void clearExecutionThread() { - executionThread.set(null); - } - - @Override - public void onCancelled() { - Thread thread = executionThread.get(); - if (thread != null) { - thread.interrupt(); - } - } -} diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java index 98278787011..fd932ef2fc2 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java @@ -110,24 +110,8 @@ protected void doExecute( return; } - if (task instanceof SQLQueryTask sqlQueryTask) { - - OpenSearchQueryManager.setCancellationCallback(new OpenSearchQueryManager.CancellationCallBack() { - @Override - public void onExecutionThreadAvailable(Thread thread) { - sqlQueryTask.setExecutionThread(thread); - } - - @Override - public void onExecutionComplete() { - sqlQueryTask.clearExecutionThread(); - } - - @Override - public boolean isCancelled() { - return sqlQueryTask.isCancelled(); - } - }); + if (task instanceof PPLQueryTask pplQueryTask) { + OpenSearchQueryManager.setCancellableTask(pplQueryTask); } Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_TOTAL).increment(); Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_COUNT_TOTAL).increment(); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java index 4615eaa580f..4ba1a53d872 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java @@ -26,7 +26,6 @@ import org.opensearch.sql.ppl.domain.PPLQueryRequest; import org.opensearch.sql.protocol.response.format.Format; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; -import org.opensearch.tasks.Task; @RequiredArgsConstructor public class TransportPPLQueryRequest extends ActionRequest { @@ -159,20 +158,15 @@ public ActionRequestValidationException validate() { } @Override - public SQLQueryTask createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - return new SQLQueryTask(id, type, action, getDescription() , parentTaskId, headers); + public PPLQueryTask createTask( + long id, String type, String action, TaskId parentTaskId, Map headers) { + return new PPLQueryTask(id, type, action, getDescription(), parentTaskId, headers); } @Override - public String getDescription() - { - String prefix = (queryId != null) ? "PPL [queryId=" + queryId + "]: " : "PPL: "; - - if (pplQuery != null && pplQuery.length() > 512) { - return prefix + pplQuery.substring(0,512) + "..."; - } - - return prefix + pplQuery; + public String getDescription() { + String prefix = (queryId != null) ? "PPL [queryId=" + queryId + "]: " : "PPL: "; + return prefix + pplQuery; } /** Convert to PPLQueryRequest. */ diff --git a/plugin/src/test/java/org/opensearch/sql/plugin/transport/PPLQueryTaskTest.java b/plugin/src/test/java/org/opensearch/sql/plugin/transport/PPLQueryTaskTest.java new file mode 100644 index 00000000000..c9502ac3bbf --- /dev/null +++ b/plugin/src/test/java/org/opensearch/sql/plugin/transport/PPLQueryTaskTest.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.transport; + +import static org.junit.Assert.*; + +import java.util.Map; +import org.junit.Test; +import org.opensearch.core.tasks.TaskId; + +public class PPLQueryTaskTest { + + @Test + public void testShouldCancelChildrenReturnsTrue() { + PPLQueryTask pplQueryTask = + new PPLQueryTask( + 1, + "transport", + "cluster:admin/opensearch/ppl", + "test query", + TaskId.EMPTY_TASK_ID, + Map.of()); + assertTrue(pplQueryTask.shouldCancelChildrenOnCancellation()); + } + + @Test + public void testCreateTaskReturnsPPLQueryTask() { + TransportPPLQueryRequest transportPPLQueryRequest = + new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + PPLQueryTask task = + transportPPLQueryRequest.createTask( + 1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of()); + assertNotNull(task); + } + + @Test + public void testWithQueryId() { + TransportPPLQueryRequest transportPPLQueryRequest = + new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + transportPPLQueryRequest.queryId("test-123"); + assertEquals("PPL [queryId=test-123]: source=t a=1", transportPPLQueryRequest.getDescription()); + } + + @Test + public void testWithoutQueryId() { + TransportPPLQueryRequest transportPPLQueryRequest = + new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + assertEquals("PPL: source=t a=1", transportPPLQueryRequest.getDescription()); + } + + @Test + public void testCooperativeModel() { + TransportPPLQueryRequest transportPPLQueryRequest = + new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + PPLQueryTask task = + transportPPLQueryRequest.createTask( + 1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of()); + assertFalse(task.isCancelled()); + task.cancel("Test"); + assertTrue(task.isCancelled()); + } +} diff --git a/plugin/src/test/java/org/opensearch/sql/plugin/transport/SQLQueryTaskTest.java b/plugin/src/test/java/org/opensearch/sql/plugin/transport/SQLQueryTaskTest.java deleted file mode 100644 index 07a22bc7a4d..00000000000 --- a/plugin/src/test/java/org/opensearch/sql/plugin/transport/SQLQueryTaskTest.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.plugin.transport; - -import org.junit.Test; -import org.opensearch.core.tasks.TaskId; - -import java.util.Map; - -import static org.junit.Assert.*; - -public class SQLQueryTaskTest { - - @Test - public void testOnCancelledInterruptsThread() { - SQLQueryTask sqlQueryTask = new SQLQueryTask(1, "transport", "cluster:admin/opensearch/ppl", "test query", TaskId.EMPTY_TASK_ID, Map.of()); - sqlQueryTask.setExecutionThread(Thread.currentThread()); - sqlQueryTask.cancel("testing"); - assertTrue(Thread.currentThread().isInterrupted()); - Thread.interrupted(); - } - - @Test - public void testOnCancelledWithNoThread() { - SQLQueryTask sqlQueryTask = new SQLQueryTask(1, "transport", "cluster:admin/opensearch/ppl", "test query", TaskId.EMPTY_TASK_ID, Map.of()); - sqlQueryTask.cancel("testing"); - } - - @Test - public void testClearExecutionThread() { - SQLQueryTask sqlQueryTask = new SQLQueryTask(1, "transport", "cluster:admin/opensearch/ppl", "test query", TaskId.EMPTY_TASK_ID, Map.of()); - sqlQueryTask.setExecutionThread(Thread.currentThread()); - sqlQueryTask.clearExecutionThread(); - sqlQueryTask.cancel("testing"); - assertFalse(Thread.currentThread().isInterrupted()); - } - - @Test - public void testShouldCancelChildrenReturnsFalse() { - SQLQueryTask sqlQueryTask = new SQLQueryTask(1, "transport", "cluster:admin/opensearch/ppl", "test query", TaskId.EMPTY_TASK_ID, Map.of()); - assertFalse(sqlQueryTask.shouldCancelChildrenOnCancellation()); - } - - @Test - public void testCreateTaskReturnsSQLQueryTask() { - TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); - SQLQueryTask task = transportPPLQueryRequest.createTask(1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of()); - assertNotNull(task); - } - - @Test - public void testWithQueryId () { - TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); - transportPPLQueryRequest.queryId("test-123"); - assertEquals("PPL [queryId=test-123]: source=t a=1", transportPPLQueryRequest.getDescription()); - } - - @Test - public void testWithoutQueryId () { - TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); - assertEquals("PPL: source=t a=1", transportPPLQueryRequest.getDescription()); - } -}