Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.sql.executor.QueryId;
import org.opensearch.sql.executor.QueryManager;
import org.opensearch.sql.executor.execution.AbstractPlan;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.threadpool.Scheduler;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.client.node.NodeClient;
Expand All @@ -33,15 +34,32 @@ public class OpenSearchQueryManager implements QueryManager {
public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker";
public static final String SQL_BACKGROUND_THREAD_POOL_NAME = "sql_background_io";

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: could be private static since the accessor methods already exist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These consts are directly referenced in a few places iirc, I originally refactored to this from a lot of places duplicating a "sql-worker" string.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I wasn't referring to the thread pool name constants — those should stay public static. I meant the cancellationCallBackThreadLocal field added in this PR (line 45). It's declared public static but has setCancellationCallback/clearCancellationCallback accessor methods, so the field itself could be private static to prevent direct access.

private static final ThreadLocal<CancellableTask> cancellableTask = new ThreadLocal<>();

public static void setCancellableTask(CancellableTask task) {
cancellableTask.set(task);
}

public static CancellableTask getCancellableTask() {
return cancellableTask.get();
}

public static void clearCancellableTask() {
cancellableTask.remove();
}

@Override
public QueryId submit(AbstractPlan queryPlan) {
TimeValue timeout = settings.getSettingValue(Settings.Key.PPL_QUERY_TIMEOUT);
schedule(nodeClient, queryPlan::execute, timeout);
CancellableTask cancelTask = cancellableTask.get();
cancellableTask.remove();
schedule(nodeClient, queryPlan::execute, timeout, cancelTask);

return queryPlan.getQueryId();
}

private void schedule(NodeClient client, Runnable task, TimeValue timeout) {
private void schedule(
NodeClient client, Runnable task, TimeValue timeout, CancellableTask cancelTask) {
ThreadPool threadPool = client.threadPool();

Runnable wrappedTask =
Expand All @@ -60,6 +78,8 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout) {
timeout,
ThreadPool.Names.GENERIC);

setCancellableTask(cancelTask);

try {
task.run();
timeoutTask.cancel();
Expand All @@ -76,6 +96,8 @@ private void schedule(NodeClient client, Runnable task, TimeValue timeout) {
}

throw e;
} finally {
clearCancellableTask();
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import lombok.EqualsAndHashCode;
import lombok.ToString;
import org.apache.calcite.linq4j.Enumerator;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.exception.NonFallbackCalciteException;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.monitor.ResourceMonitor;
import org.opensearch.sql.opensearch.client.OpenSearchClient;
import org.opensearch.sql.opensearch.executor.OpenSearchQueryManager;
import org.opensearch.sql.opensearch.request.OpenSearchRequest;
import org.opensearch.tasks.CancellableTask;

/**
* Supports a simple iteration over a collection for OpenSearch index
Expand Down Expand Up @@ -55,6 +58,8 @@ public class OpenSearchIndexEnumerator implements Enumerator<Object> {

private ExprValue current = null;

private CancellableTask cancellableTask;

public OpenSearchIndexEnumerator(
OpenSearchClient client,
List<String> fields,
Expand All @@ -80,6 +85,7 @@ public OpenSearchIndexEnumerator(
this.client = client;
this.bgScanner = new BackgroundSearchScanner(client, maxResultWindow, queryBucketSize);
this.bgScanner.startScanning(request);
this.cancellableTask = OpenSearchQueryManager.getCancellableTask();
}

private Iterator<ExprValue> fetchNextBatch() {
Expand Down Expand Up @@ -112,6 +118,10 @@ public boolean moveNext() {
return false;
}

if (cancellableTask != null && cancellableTask.isCancelled()) {
throw new TaskCancelledException("The task is cancelled.");
}

boolean shouldCheck = (queryCount % NUMBER_OF_NEXT_CALL_TO_CHECK == 0);
if (shouldCheck) {
org.opensearch.sql.monitor.ResourceStatus status = this.monitor.getStatus();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ private static PPLQueryRequest parsePPLRequestFromPayload(RestRequest restReques
if (pretty) {
pplRequest.style(JsonResponseFormatter.Style.PRETTY);
}
// set queryId
String queryId = jsonContent.optString("queryId", null);
if (queryId != null) {
pplRequest.queryId(queryId);
}
return pplRequest;
} catch (JSONException e) {
throw new IllegalArgumentException("Failed to parse request payload", e);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.plugin.transport;

import java.util.Map;
import org.opensearch.core.tasks.TaskId;
import org.opensearch.tasks.CancellableTask;

public class PPLQueryTask extends CancellableTask {

public PPLQueryTask(
long id,
String type,
String action,
String description,
TaskId parentTaskId,
Map<String, String> headers) {
super(id, type, action, description, parentTaskId, headers);
}

@Override
public boolean shouldCancelChildrenOnCancellation() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.monitor.profile.QueryProfiling;
import org.opensearch.sql.opensearch.executor.OpenSearchQueryManager;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.plugin.config.OpenSearchPluginModule;
import org.opensearch.sql.ppl.PPLService;
Expand Down Expand Up @@ -109,6 +110,9 @@ protected void doExecute(
return;
}

if (task instanceof PPLQueryTask pplQueryTask) {
OpenSearchQueryManager.setCancellableTask(pplQueryTask);
}
Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_TOTAL).increment();
Metrics.getInstance().getNumericalMetric(MetricName.PPL_REQ_COUNT_TOTAL).increment();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
Expand All @@ -21,6 +22,7 @@
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.tasks.TaskId;
import org.opensearch.sql.ppl.domain.PPLQueryRequest;
import org.opensearch.sql.protocol.response.format.Format;
import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
Expand Down Expand Up @@ -51,6 +53,11 @@ public class TransportPPLQueryRequest extends ActionRequest {
@Accessors(fluent = true)
private boolean profile = false;

@Setter
@Getter
@Accessors(fluent = true)
private String queryId = null;

/** Constructor of TransportPPLQueryRequest from PPLQueryRequest. */
public TransportPPLQueryRequest(PPLQueryRequest pplQueryRequest) {
pplQuery = pplQueryRequest.getRequest();
Expand All @@ -61,6 +68,7 @@ public TransportPPLQueryRequest(PPLQueryRequest pplQueryRequest) {
style = pplQueryRequest.style();
profile = pplQueryRequest.profile();
explainMode = pplQueryRequest.mode().getModeName();
queryId = pplQueryRequest.queryId();
}

/** Constructor of TransportPPLQueryRequest from StreamInput. */
Expand All @@ -75,6 +83,7 @@ public TransportPPLQueryRequest(StreamInput in) throws IOException {
sanitize = in.readBoolean();
style = in.readEnum(JsonResponseFormatter.Style.class);
profile = in.readBoolean();
queryId = in.readOptionalString();
}

/** Re-create the object from the actionRequest. */
Expand Down Expand Up @@ -107,6 +116,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(sanitize);
out.writeEnum(style);
out.writeBoolean(profile);
out.writeOptionalString(queryId);
}

public String getRequest() {
Expand Down Expand Up @@ -147,12 +157,25 @@ public ActionRequestValidationException validate() {
return null;
}

@Override
public PPLQueryTask createTask(
long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new PPLQueryTask(id, type, action, getDescription(), parentTaskId, headers);
}

@Override
public String getDescription() {
String prefix = (queryId != null) ? "PPL [queryId=" + queryId + "]: " : "PPL: ";
return prefix + pplQuery;
}

/** Convert to PPLQueryRequest. */
public PPLQueryRequest toPPLQueryRequest() {
PPLQueryRequest pplQueryRequest =
new PPLQueryRequest(pplQuery, jsonContent, path, format, explainMode, profile);
pplQueryRequest.sanitize(sanitize);
pplQueryRequest.style(style);
pplQueryRequest.queryId(queryId);
return pplQueryRequest;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.plugin.transport;

import static org.junit.Assert.*;

import java.util.Map;
import org.junit.Test;
import org.opensearch.core.tasks.TaskId;

public class PPLQueryTaskTest {

@Test
public void testShouldCancelChildrenReturnsTrue() {
PPLQueryTask pplQueryTask =
new PPLQueryTask(
1,
"transport",
"cluster:admin/opensearch/ppl",
"test query",
TaskId.EMPTY_TASK_ID,
Map.of());
assertTrue(pplQueryTask.shouldCancelChildrenOnCancellation());
}

@Test
public void testCreateTaskReturnsPPLQueryTask() {
TransportPPLQueryRequest transportPPLQueryRequest =
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
PPLQueryTask task =
transportPPLQueryRequest.createTask(
1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of());
assertNotNull(task);
}

@Test
public void testWithQueryId() {
TransportPPLQueryRequest transportPPLQueryRequest =
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
transportPPLQueryRequest.queryId("test-123");
assertEquals("PPL [queryId=test-123]: source=t a=1", transportPPLQueryRequest.getDescription());
}

@Test
public void testWithoutQueryId() {
TransportPPLQueryRequest transportPPLQueryRequest =
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
assertEquals("PPL: source=t a=1", transportPPLQueryRequest.getDescription());
}

@Test
public void testCooperativeModel() {
TransportPPLQueryRequest transportPPLQueryRequest =
new TransportPPLQueryRequest("source=t a=1", null, "/_plugins/_ppl");
PPLQueryTask task =
transportPPLQueryRequest.createTask(
1, "transport", "cluster:admin/opensearch/ppl", TaskId.EMPTY_TASK_ID, Map.of());
assertFalse(task.isCancelled());
task.cancel("Test");
assertTrue(task.isCancelled());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ public class PPLQueryRequest {
@Accessors(fluent = true)
private boolean profile = false;

@Setter
@Getter
@Accessors(fluent = true)
private String queryId = null;

public PPLQueryRequest(String pplQuery, JSONObject jsonContent, String path) {
this(pplQuery, jsonContent, path, "");
}
Expand Down
Loading