Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions demos/multi_guess_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from quest import (step, queue, state, identity_queue,
create_filesystem_manager, these)
from quest.server import Server
from quest.external import MultiQueue


@step
Expand Down Expand Up @@ -35,7 +36,6 @@ async def get_secret():
@step
async def get_guesses(players: dict[str, str], message) -> dict[str, int]:
guesses = {}

status_message = []

# TODO - the following code sequence is a little verbose
Expand All @@ -47,30 +47,16 @@ async def get_guesses(players: dict[str, str], message) -> dict[str, int]:
# This pattern should be common enough we should make
# it easy and clear

async with (
# Create a guess queue for each player
these({
ident: queue('guess', ident)
for ident in players
}) as guess_queues
):
# Wait for guesses to come in.
# As they do, remove their queue so they can't guess again.
guess_gets = {q.get(): ident for ident, q in guess_queues.items()}
for guess_get in asyncio.as_completed(guess_gets):
guess = await guess_get
ident = guess_gets[guess]
# Iterate guesses one at a time
async with MultiQueue('guess', players, single_response=True) as mq:
async for ident, guess in mq:
Comment on lines +50 to +52
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love how much more concise this is from what it used to be. :)

guesses[ident] = guess

# Update the status
name = players[ident]
status_message.append(f'{name} guessed {guess}')
message.set('\n'.join(status_message))

# Remove the queue
# The user will no longer see it
guess_queues.remove(ident)

return guesses


Expand Down
64 changes: 63 additions & 1 deletion quest_test/test_external_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from quest.external import state, queue, event, wrap_as_state, wrap_as_queue
from quest.external import state, queue, event, wrap_as_state, wrap_as_queue, MultiQueue
from quest.historian import Historian
from quest.wrappers import task, step
from quest.serializer import NoopSerializer
Expand Down Expand Up @@ -305,6 +305,68 @@ async def test_step_specific_external():
assert (await workflow) == 3


@pytest.mark.asyncio
@timeout(3)
async def test_multiqueue_default():
received = []

async def player_workflow():
players = {'p1': 'user1', 'p2': 'user2'}

async with MultiQueue('chat', players) as mq:
async for ident, msg in mq:
received.append((ident, msg))
# If player sends 'bye', remove their queue after their message is recorded
if msg == 'bye':
await mq.remove(ident)
# Exit the Multiqueue when 3 messages are recorded
if len(received) == 3:
break
return received

historian = Historian('test', player_workflow, [], serializer=NoopSerializer())
workflow = historian.run()

await asyncio.sleep(0.1)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get resources for both p1 and p2, and show they both have "chats"

Then p1 says "bye", and show that p1 no longer has chats.

show that p2 has "chats" until the end of the workflow.

await historian.record_external_event('chat', 'p1', 'put', 'hello')
await historian.record_external_event('chat', 'p2', 'put', 'hi')
await historian.record_external_event('chat', 'p1', 'put', 'bye')

result = await workflow
assert result == [('p1', 'hello'), ('p2', 'hi'), ('p1', 'bye')]

# After removing p1 -> when p1 tries to send message, it should raise KeyError
# with pytest.raises(KeyError):
# await historian.record_external_event('chat', 'p1', 'put', 'should not be received')


@pytest.mark.asyncio
@timeout(3)
async def test_multiqueue_single_response():
received = {}

async def player_workflow():
players = {'p1': 'user1', 'p2': 'user2'}
async with MultiQueue('chat', players, single_response=True) as mq:
async for ident, msg in mq:
received[ident] = msg
return received

historian = Historian('test', player_workflow, [], serializer=NoopSerializer())
workflow = historian.run()

await asyncio.sleep(0.1)

await historian.record_external_event('chat', 'p1', 'put', 'hello')
await historian.record_external_event('chat', 'p2', 'put', 'hi')
# Second message from p1 - should be ignored due to single_response = True
await historian.record_external_event('chat', 'p1', 'put', 'should not be received')
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should throw an exception.


result = await workflow
assert result == {'p1': 'hello', 'p2': 'hi'}


"""

gate = asyncio.Event()
Expand Down
84 changes: 83 additions & 1 deletion src/quest/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def value(self):

class IdentityQueue:
"""Put and Get return and identity + the value"""

def __init__(self, *args, **kwargs):
self._queue = asyncio.Queue(*args, **kwargs)

Expand Down Expand Up @@ -110,6 +111,83 @@ def state(name, identity, value):
def identity_queue(name):
return InternalResource(name, None, IdentityQueue())


class MultiQueue:
def __init__(self, name: str, players: dict[str, str], single_response: bool = False):
self.queues: dict[str, InternalResource[Queue]] = {ident: queue(name, ident) for ident in players}
self.single_response = single_response
self.task_to_ident: dict[asyncio.Task, str] = {}
self.ident_to_task: dict[str, asyncio.Task] = {}

# Hold unwrapped Queue objects after __aenter__
self.active_queues: dict[str, Queue] = {}

def _add_task(self, ident: str, q: Queue):
historian = find_historian()
task = historian.start_task(
q.get,
name=f"mq-get-{ident}"
)

self.task_to_ident[task] = ident
self.ident_to_task[ident] = task

async def __aenter__(self):
# Listen on all queues -> create a task for each queue.get()
for ident, wrapper in self.queues.items():
# Unwrap queue object
queue_obj = await wrapper.__aenter__()
self.active_queues[ident] = queue_obj
self._add_task(ident, queue_obj)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
# Cancel all pending tasks - context exits
for task in self.task_to_ident:
task.cancel()
Comment on lines +144 to +147
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to __aexit__ each queue as we exit.

# Exit all queues properly
for ident, wrapper in self.queues.items():
await wrapper.__aexit__(exc_type, exc_val, exc_tb)

async def remove(self, ident: str):
# Stop listening to this identity queue
task = self.ident_to_task.pop(ident, None)

if task is not None:
self.task_to_ident.pop(task)
task.cancel()

# Call __aexit__ on the corresponding queue wrapper
wrapper = self.queues.pop(ident, None)
if wrapper:
await wrapper.__aexit__(None, None, None)

Comment on lines +161 to +164
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wrapper = self.queues.pop(ident)
await wrapper.__aexit__(None, None, None)

self.active_queues.pop(ident, None)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No None


async def __aiter__(self):
while self.task_to_ident:
# Wait until any of the current task is done
done, _ = await asyncio.wait(self.task_to_ident.keys(), return_when=asyncio.FIRST_COMPLETED)

for task in done:
ident = self.task_to_ident.pop(task)
# Stop listening to this identity
del self.ident_to_task[ident]

try:
result = await task
yield ident, result

# Start listening again
if not self.single_response:
q = self.active_queues.get(ident)
if q:
self._add_task(ident, q)

except asyncio.CancelledError:
continue


class _ResourceWrapper:
def __init__(self, name: str, identity: str | None, historian: 'Historian', resource_class):
self._name = name
Expand All @@ -129,14 +207,18 @@ async def wrapper(*args, _name=self._name, _identity=self._identity, **kwargs):

return wrapper


def wrap_as_queue(name: str, identity: str | None, historian: Historian) -> Queue:
return _ResourceWrapper(name, identity, historian, Queue)


def wrap_as_event(name: str, identity: str | None, historian: Historian) -> Event:
return _ResourceWrapper(name, identity, historian, Event)


def wrap_as_state(name: str, identity: str | None, historian: Historian) -> State:
return _ResourceWrapper(name, identity, historian, State)


def wrap_as_identity_queue(name: str, identity: str | None, historian: Historian) -> IdentityQueue:
return _ResourceWrapper(name, identity, historian, IdentityQueue)
return _ResourceWrapper(name, identity, historian, IdentityQueue)
Loading