diff --git a/demos/multi_guess_server.py b/demos/multi_guess_server.py index 9c504aa..29dac27 100644 --- a/demos/multi_guess_server.py +++ b/demos/multi_guess_server.py @@ -5,6 +5,7 @@ from quest import (step, queue, state, identity_queue, create_filesystem_manager, these) from quest.server import Server +from quest.external import MultiQueue @step @@ -35,7 +36,6 @@ async def get_secret(): @step async def get_guesses(players: dict[str, str], message) -> dict[str, int]: guesses = {} - status_message = [] # TODO - the following code sequence is a little verbose @@ -47,19 +47,9 @@ async def get_guesses(players: dict[str, str], message) -> dict[str, int]: # This pattern should be common enough we should make # it easy and clear - async with ( - # Create a guess queue for each player - these({ - ident: queue('guess', ident) - for ident in players - }) as guess_queues - ): - # Wait for guesses to come in. - # As they do, remove their queue so they can't guess again. - guess_gets = {q.get(): ident for ident, q in guess_queues.items()} - for guess_get in asyncio.as_completed(guess_gets): - guess = await guess_get - ident = guess_gets[guess] + # Iterate guesses one at a time + async with MultiQueue('guess', players, single_response=True) as mq: + async for ident, guess in mq: guesses[ident] = guess # Update the status @@ -67,10 +57,6 @@ async def get_guesses(players: dict[str, str], message) -> dict[str, int]: status_message.append(f'{name} guessed {guess}') message.set('\n'.join(status_message)) - # Remove the queue - # The user will no longer see it - guess_queues.remove(ident) - return guesses diff --git a/quest_test/test_external_actions.py b/quest_test/test_external_actions.py index 59247cd..89fb2bb 100644 --- a/quest_test/test_external_actions.py +++ b/quest_test/test_external_actions.py @@ -2,7 +2,7 @@ import pytest -from quest.external import state, queue, event, wrap_as_state, wrap_as_queue +from quest.external import state, queue, event, wrap_as_state, wrap_as_queue, MultiQueue from quest.historian import Historian from quest.wrappers import task, step from quest.serializer import NoopSerializer @@ -305,6 +305,68 @@ async def test_step_specific_external(): assert (await workflow) == 3 +@pytest.mark.asyncio +@timeout(3) +async def test_multiqueue_default(): + received = [] + + async def player_workflow(): + players = {'p1': 'user1', 'p2': 'user2'} + + async with MultiQueue('chat', players) as mq: + async for ident, msg in mq: + received.append((ident, msg)) + # If player sends 'bye', remove their queue after their message is recorded + if msg == 'bye': + await mq.remove(ident) + # Exit the Multiqueue when 3 messages are recorded + if len(received) == 3: + break + return received + + historian = Historian('test', player_workflow, [], serializer=NoopSerializer()) + workflow = historian.run() + + await asyncio.sleep(0.1) + + await historian.record_external_event('chat', 'p1', 'put', 'hello') + await historian.record_external_event('chat', 'p2', 'put', 'hi') + await historian.record_external_event('chat', 'p1', 'put', 'bye') + + result = await workflow + assert result == [('p1', 'hello'), ('p2', 'hi'), ('p1', 'bye')] + + # After removing p1 -> when p1 tries to send message, it should raise KeyError + # with pytest.raises(KeyError): + # await historian.record_external_event('chat', 'p1', 'put', 'should not be received') + + +@pytest.mark.asyncio +@timeout(3) +async def test_multiqueue_single_response(): + received = {} + + async def player_workflow(): + players = {'p1': 'user1', 'p2': 'user2'} + async with MultiQueue('chat', players, single_response=True) as mq: + async for ident, msg in mq: + received[ident] = msg + return received + + historian = Historian('test', player_workflow, [], serializer=NoopSerializer()) + workflow = historian.run() + + await asyncio.sleep(0.1) + + await historian.record_external_event('chat', 'p1', 'put', 'hello') + await historian.record_external_event('chat', 'p2', 'put', 'hi') + # Second message from p1 - should be ignored due to single_response = True + await historian.record_external_event('chat', 'p1', 'put', 'should not be received') + + result = await workflow + assert result == {'p1': 'hello', 'p2': 'hi'} + + """ gate = asyncio.Event() diff --git a/src/quest/external.py b/src/quest/external.py index 9e4ea73..8a7e963 100644 --- a/src/quest/external.py +++ b/src/quest/external.py @@ -63,6 +63,7 @@ def value(self): class IdentityQueue: """Put and Get return and identity + the value""" + def __init__(self, *args, **kwargs): self._queue = asyncio.Queue(*args, **kwargs) @@ -110,6 +111,83 @@ def state(name, identity, value): def identity_queue(name): return InternalResource(name, None, IdentityQueue()) + +class MultiQueue: + def __init__(self, name: str, players: dict[str, str], single_response: bool = False): + self.queues: dict[str, InternalResource[Queue]] = {ident: queue(name, ident) for ident in players} + self.single_response = single_response + self.task_to_ident: dict[asyncio.Task, str] = {} + self.ident_to_task: dict[str, asyncio.Task] = {} + + # Hold unwrapped Queue objects after __aenter__ + self.active_queues: dict[str, Queue] = {} + + def _add_task(self, ident: str, q: Queue): + historian = find_historian() + task = historian.start_task( + q.get, + name=f"mq-get-{ident}" + ) + + self.task_to_ident[task] = ident + self.ident_to_task[ident] = task + + async def __aenter__(self): + # Listen on all queues -> create a task for each queue.get() + for ident, wrapper in self.queues.items(): + # Unwrap queue object + queue_obj = await wrapper.__aenter__() + self.active_queues[ident] = queue_obj + self._add_task(ident, queue_obj) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Cancel all pending tasks - context exits + for task in self.task_to_ident: + task.cancel() + # Exit all queues properly + for ident, wrapper in self.queues.items(): + await wrapper.__aexit__(exc_type, exc_val, exc_tb) + + async def remove(self, ident: str): + # Stop listening to this identity queue + task = self.ident_to_task.pop(ident, None) + + if task is not None: + self.task_to_ident.pop(task) + task.cancel() + + # Call __aexit__ on the corresponding queue wrapper + wrapper = self.queues.pop(ident, None) + if wrapper: + await wrapper.__aexit__(None, None, None) + + self.active_queues.pop(ident, None) + + async def __aiter__(self): + while self.task_to_ident: + # Wait until any of the current task is done + done, _ = await asyncio.wait(self.task_to_ident.keys(), return_when=asyncio.FIRST_COMPLETED) + + for task in done: + ident = self.task_to_ident.pop(task) + # Stop listening to this identity + del self.ident_to_task[ident] + + try: + result = await task + yield ident, result + + # Start listening again + if not self.single_response: + q = self.active_queues.get(ident) + if q: + self._add_task(ident, q) + + except asyncio.CancelledError: + continue + + class _ResourceWrapper: def __init__(self, name: str, identity: str | None, historian: 'Historian', resource_class): self._name = name @@ -129,14 +207,18 @@ async def wrapper(*args, _name=self._name, _identity=self._identity, **kwargs): return wrapper + def wrap_as_queue(name: str, identity: str | None, historian: Historian) -> Queue: return _ResourceWrapper(name, identity, historian, Queue) + def wrap_as_event(name: str, identity: str | None, historian: Historian) -> Event: return _ResourceWrapper(name, identity, historian, Event) + def wrap_as_state(name: str, identity: str | None, historian: Historian) -> State: return _ResourceWrapper(name, identity, historian, State) + def wrap_as_identity_queue(name: str, identity: str | None, historian: Historian) -> IdentityQueue: - return _ResourceWrapper(name, identity, historian, IdentityQueue) \ No newline at end of file + return _ResourceWrapper(name, identity, historian, IdentityQueue)