From 4278925227db3f7575b1f56c4dbc0ff4511c77f4 Mon Sep 17 00:00:00 2001 From: Jonathan Jaegerman Date: Fri, 6 Mar 2026 13:46:59 -0800 Subject: [PATCH 1/2] mark task slot used in async workflow poll task --- .../io/temporal/internal/worker/AsyncWorkflowPollTask.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java index 06ad0e323..4a7f793cc 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java @@ -30,7 +30,7 @@ public class AsyncWorkflowPollTask implements AsyncPoller.PollTaskAsync, DisableNormalPolling { private static final Logger log = LoggerFactory.getLogger(AsyncWorkflowPollTask.class); - private final TrackingSlotSupplier slotSupplier; + private final TrackingSlotSupplier slotSupplier; private final WorkflowServiceStubs service; private final Scope metricsScope; private final Scope pollerMetricScope; @@ -150,6 +150,7 @@ public CompletableFuture poll(SlotPermit permit) .inc(1); return null; } + slotSupplier.markSlotUsed(new WorkflowSlotInfo(r, pollRequest), permit); pollerMetricScope .counter(MetricsType.WORKFLOW_TASK_QUEUE_POLL_SUCCEED_COUNTER) .inc(1); From 5c8fce648b9b62aa6a352bee53b844f86a462161 Mon Sep 17 00:00:00 2001 From: Jonathan Jaegerman Date: Fri, 6 Mar 2026 14:39:45 -0800 Subject: [PATCH 2/2] slot supplier test with async poll task --- .../internal/worker/SlotSupplierTest.java | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java index 9fe9dbc8b..ee2525547 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java @@ -5,6 +5,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import com.google.common.util.concurrent.Futures; import com.google.protobuf.ByteString; import com.uber.m3.tally.RootScopeBuilder; import com.uber.m3.tally.Scope; @@ -17,6 +18,8 @@ import io.temporal.serviceclient.WorkflowServiceStubs; import io.temporal.worker.tuning.*; import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.junit.runner.RunWith; @@ -124,4 +127,64 @@ public void supplierIsCalledAppropriately() { assertEquals(1, trackingSS.getUsedSlots().size()); } } + + @Test + public void asyncPollerSupplierIsCalledAppropriately() throws Exception { + WorkflowServiceStubs client = mock(WorkflowServiceStubs.class); + when(client.getServerCapabilities()) + .thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build()); + + WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub = + mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class); + when(client.futureStub()).thenReturn(futureStub); + when(futureStub.withOption(any(), any())).thenReturn(futureStub); + + SlotSupplier mockSupplier = mock(SlotSupplier.class); + Scope metricsScope = + new RootScopeBuilder() + .reporter(reporter) + .reportEvery(com.uber.m3.util.Duration.ofMillis(1)); + TrackingSlotSupplier trackingSS = + new TrackingSlotSupplier<>(mockSupplier, metricsScope); + + PollWorkflowTaskQueueResponse pollResponse = + PollWorkflowTaskQueueResponse.newBuilder() + .setTaskToken(ByteString.copyFrom("token", UTF_8)) + .setWorkflowExecution( + WorkflowExecution.newBuilder().setWorkflowId(WORKFLOW_ID).setRunId(RUN_ID).build()) + .setWorkflowType(WorkflowType.newBuilder().setName(WORKFLOW_TYPE).build()) + .build(); + + if (throwOnPoll) { + when(futureStub.pollWorkflowTaskQueue(any())) + .thenReturn(Futures.immediateFailedFuture(new RuntimeException("Poll failed"))); + } else { + when(futureStub.pollWorkflowTaskQueue(any())) + .thenReturn(Futures.immediateFuture(pollResponse)); + } + + AsyncWorkflowPollTask pollTask = + new AsyncWorkflowPollTask( + client, + "default", + TASK_QUEUE, + null, + "", + new WorkerVersioningOptions("", false, null), + trackingSS, + metricsScope, + () -> GetSystemInfoResponse.Capabilities.newBuilder().build()); + + SlotPermit permit = new SlotPermit(); + + CompletableFuture future = pollTask.poll(permit); + if (throwOnPoll) { + assertThrows(ExecutionException.class, future::get); + assertEquals(0, trackingSS.getUsedSlots().size()); + } else { + WorkflowTask task = future.get(); + assertNotNull(task); + assertEquals(1, trackingSS.getUsedSlots().size()); + } + } }