diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java index dacb7f97eab..7aaaaa6655e 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchQueryManager.java @@ -16,6 +16,7 @@ import org.opensearch.sql.executor.QueryId; import org.opensearch.sql.executor.QueryManager; import org.opensearch.sql.executor.execution.AbstractPlan; +import org.opensearch.tasks.CancellableTask; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.client.node.NodeClient; @@ -33,15 +34,32 @@ public class OpenSearchQueryManager implements QueryManager { public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker"; public static final String SQL_BACKGROUND_THREAD_POOL_NAME = "sql_background_io"; + private static final ThreadLocal cancellableTask = new ThreadLocal<>(); + + public static void setCancellableTask(CancellableTask task) { + cancellableTask.set(task); + } + + public static CancellableTask getCancellableTask() { + return cancellableTask.get(); + } + + public static void clearCancellableTask() { + cancellableTask.remove(); + } + @Override public QueryId submit(AbstractPlan queryPlan) { TimeValue timeout = settings.getSettingValue(Settings.Key.PPL_QUERY_TIMEOUT); - schedule(nodeClient, queryPlan::execute, timeout); + CancellableTask cancelTask = cancellableTask.get(); + cancellableTask.remove(); + schedule(nodeClient, queryPlan::execute, timeout, cancelTask); return queryPlan.getQueryId(); } - private void schedule(NodeClient client, Runnable task, TimeValue timeout) { + private void schedule( + NodeClient client, Runnable task, TimeValue timeout, CancellableTask cancelTask) { ThreadPool threadPool = client.threadPool(); Runnable wrappedTask = @@ -60,6 +78,8 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout) { timeout, ThreadPool.Names.GENERIC); + setCancellableTask(cancelTask); + try { task.run(); timeoutTask.cancel(); @@ -76,6 +96,8 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout) { } throw e; + } finally { + clearCancellableTask(); } }); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java index 6af9ad1e8d8..b0315e68eab 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java @@ -11,13 +11,16 @@ import lombok.EqualsAndHashCode; import lombok.ToString; import org.apache.calcite.linq4j.Enumerator; +import org.opensearch.core.tasks.TaskCancelledException; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.exception.NonFallbackCalciteException; import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.executor.OpenSearchQueryManager; import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.tasks.CancellableTask; /** * Supports a simple iteration over a collection for OpenSearch index @@ -55,6 +58,8 @@ public class OpenSearchIndexEnumerator implements Enumerator { private ExprValue current = null; + private CancellableTask cancellableTask; + public OpenSearchIndexEnumerator( OpenSearchClient client, List fields, @@ -80,6 +85,7 @@ public OpenSearchIndexEnumerator( this.client = client; this.bgScanner = new BackgroundSearchScanner(client, maxResultWindow, queryBucketSize); this.bgScanner.startScanning(request); + this.cancellableTask = OpenSearchQueryManager.getCancellableTask(); } private Iterator fetchNextBatch() { @@ -112,6 +118,10 @@ public boolean moveNext() { return false; } + if (cancellableTask != null && cancellableTask.isCancelled()) { + throw new TaskCancelledException("The task is cancelled."); + } + boolean shouldCheck = (queryCount % NUMBER_OF_NEXT_CALL_TO_CHECK == 0); if (shouldCheck) { org.opensearch.sql.monitor.ResourceStatus status = this.monitor.getStatus(); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java b/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java index 0d07dab966a..bb87bf7fa91 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java @@ -113,6 +113,11 @@ private static PPLQueryRequest parsePPLRequestFromPayload(RestRequest restReques if (pretty) { pplRequest.style(JsonResponseFormatter.Style.PRETTY); } + // set queryId + String queryId = jsonContent.optString("queryId", null); + if (queryId != null) { + pplRequest.queryId(queryId); + } return pplRequest; } catch (JSONException e) { throw new IllegalArgumentException("Failed to parse request payload", e); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/PPLQueryTask.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/PPLQueryTask.java new file mode 100644 index 00000000000..2df96bdbd12 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/PPLQueryTask.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.transport; + +import java.util.Map; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.tasks.CancellableTask; + +public class PPLQueryTask extends CancellableTask { + + public PPLQueryTask( + long id, + String type, + String action, + String description, + TaskId parentTaskId, + Map headers) { + super(id, type, action, description, parentTaskId, headers); + } + + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } +} diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java index 48bc36374a8..fd932ef2fc2 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java @@ -31,6 +31,7 @@ import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.monitor.profile.QueryProfiling; +import org.opensearch.sql.opensearch.executor.OpenSearchQueryManager; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.sql.plugin.config.OpenSearchPluginModule; import org.opensearch.sql.ppl.PPLService; @@ -109,6 +110,9 @@ protected void doExecute( return; } + if (task instanceof PPLQueryTask pplQueryTask) { + OpenSearchQueryManager.setCancellableTask(pplQueryTask); + } Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_TOTAL).increment(); Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_COUNT_TOTAL).increment(); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java index 6db2bd249ae..4ba1a53d872 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequest.java @@ -9,6 +9,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.Locale; +import java.util.Map; import java.util.Optional; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -21,6 +22,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.tasks.TaskId; import org.opensearch.sql.ppl.domain.PPLQueryRequest; import org.opensearch.sql.protocol.response.format.Format; import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; @@ -51,6 +53,11 @@ public class TransportPPLQueryRequest extends ActionRequest { @Accessors(fluent = true) private boolean profile = false; + @Setter + @Getter + @Accessors(fluent = true) + private String queryId = null; + /** Constructor of TransportPPLQueryRequest from PPLQueryRequest. */ public TransportPPLQueryRequest(PPLQueryRequest pplQueryRequest) { pplQuery = pplQueryRequest.getRequest(); @@ -61,6 +68,7 @@ public TransportPPLQueryRequest(PPLQueryRequest pplQueryRequest) { style = pplQueryRequest.style(); profile = pplQueryRequest.profile(); explainMode = pplQueryRequest.mode().getModeName(); + queryId = pplQueryRequest.queryId(); } /** Constructor of TransportPPLQueryRequest from StreamInput. */ @@ -75,6 +83,7 @@ public TransportPPLQueryRequest(StreamInput in) throws IOException { sanitize = in.readBoolean(); style = in.readEnum(JsonResponseFormatter.Style.class); profile = in.readBoolean(); + queryId = in.readOptionalString(); } /** Re-create the object from the actionRequest. */ @@ -107,6 +116,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(sanitize); out.writeEnum(style); out.writeBoolean(profile); + out.writeOptionalString(queryId); } public String getRequest() { @@ -147,12 +157,25 @@ public ActionRequestValidationException validate() { return null; } + @Override + public PPLQueryTask createTask( + long id, String type, String action, TaskId parentTaskId, Map headers) { + return new PPLQueryTask(id, type, action, getDescription(), parentTaskId, headers); + } + + @Override + public String getDescription() { + String prefix = (queryId != null) ? "PPL [queryId=" + queryId + "]: " : "PPL: "; + return prefix + pplQuery; + } + /** Convert to PPLQueryRequest. */ public PPLQueryRequest toPPLQueryRequest() { PPLQueryRequest pplQueryRequest = new PPLQueryRequest(pplQuery, jsonContent, path, format, explainMode, profile); pplQueryRequest.sanitize(sanitize); pplQueryRequest.style(style); + pplQueryRequest.queryId(queryId); return pplQueryRequest; } } diff --git a/plugin/src/test/java/org/opensearch/sql/plugin/transport/PPLQueryTaskTest.java b/plugin/src/test/java/org/opensearch/sql/plugin/transport/PPLQueryTaskTest.java new file mode 100644 index 00000000000..c9502ac3bbf --- /dev/null +++ b/plugin/src/test/java/org/opensearch/sql/plugin/transport/PPLQueryTaskTest.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.transport; + +import static org.junit.Assert.*; + +import java.util.Map; +import org.junit.Test; +import org.opensearch.core.tasks.TaskId; + +public class PPLQueryTaskTest { + + @Test + public void testShouldCancelChildrenReturnsTrue() { + PPLQueryTask pplQueryTask = + new PPLQueryTask( + 1, + "transport", + "cluster:admin/opensearch/ppl", + "test query", + TaskId.EMPTY_TASK_ID, + Map.of()); + assertTrue(pplQueryTask.shouldCancelChildrenOnCancellation()); + } + + @Test + public void testCreateTaskReturnsPPLQueryTask() { + TransportPPLQueryRequest transportPPLQueryRequest = + new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + PPLQueryTask task = + transportPPLQueryRequest.createTask( + 1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of()); + assertNotNull(task); + } + + @Test + public void testWithQueryId() { + TransportPPLQueryRequest transportPPLQueryRequest = + new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + transportPPLQueryRequest.queryId("test-123"); + assertEquals("PPL [queryId=test-123]: source=t a=1", transportPPLQueryRequest.getDescription()); + } + + @Test + public void testWithoutQueryId() { + TransportPPLQueryRequest transportPPLQueryRequest = + new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + assertEquals("PPL: source=t a=1", transportPPLQueryRequest.getDescription()); + } + + @Test + public void testCooperativeModel() { + TransportPPLQueryRequest transportPPLQueryRequest = + new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl"); + PPLQueryTask task = + transportPPLQueryRequest.createTask( + 1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of()); + assertFalse(task.isCancelled()); + task.cancel("Test"); + assertTrue(task.isCancelled()); + } +} diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/domain/PPLQueryRequest.java b/ppl/src/main/java/org/opensearch/sql/ppl/domain/PPLQueryRequest.java index 4201c9cf6ab..06c7fe1c38e 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/domain/PPLQueryRequest.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/domain/PPLQueryRequest.java @@ -52,6 +52,11 @@ public class PPLQueryRequest { @Accessors(fluent = true) private boolean profile = false; + @Setter + @Getter + @Accessors(fluent = true) + private String queryId = null; + public PPLQueryRequest(String pplQuery, JSONObject jsonContent, String path) { this(pplQuery, jsonContent, path, ""); }