diff --git a/quest_test/test_external_actions.py b/quest_test/test_external_actions.py index 89fb2bb..93c3710 100644 --- a/quest_test/test_external_actions.py +++ b/quest_test/test_external_actions.py @@ -334,7 +334,8 @@ async def player_workflow(): await historian.record_external_event('chat', 'p1', 'put', 'bye') result = await workflow - assert result == [('p1', 'hello'), ('p2', 'hi'), ('p1', 'bye')] + assert result[-1] == ('p1', 'bye') + assert set(result[:2]) == {('p1', 'hello'), ('p2', 'hi')} # After removing p1 -> when p1 tries to send message, it should raise KeyError # with pytest.raises(KeyError): diff --git a/src/quest/__init__.py b/src/quest/__init__.py index 33d35b3..e56b0b7 100644 --- a/src/quest/__init__.py +++ b/src/quest/__init__.py @@ -3,7 +3,8 @@ from .context import these from .external import state, queue, identity_queue, event -from .historian import Historian, suspendable +from .historian import Historian +from .historian_context import suspendable from .history import History from .manager import WorkflowManager, WorkflowFactory from .manager_wrappers import alias diff --git a/src/quest/external.py b/src/quest/external.py index 8a7e963..ec83892 100644 --- a/src/quest/external.py +++ b/src/quest/external.py @@ -2,7 +2,9 @@ import uuid from typing import TypeVar, Generic -from .historian import Historian, find_historian, SUSPENDED, wrap_methods_as_historian_events +from .historian import Historian +from .historian_context import find_historian, SUSPENDED +from .historian_resources import wrap_methods_as_historian_events # diff --git a/src/quest/historian.py b/src/quest/historian.py index ee5f061..5db5de4 100644 --- a/src/quest/historian.py +++ b/src/quest/historian.py @@ -1,16 +1,20 @@ import asyncio -import inspect import traceback from asyncio import Task -from contextvars import ContextVar -from datetime import datetime from functools import wraps -from typing import Callable, TypeVar +from typing import Callable from .history import History 
+from .historian_context import SUSPENDED, historian_context +from .historian_evolution import EvolutionRuntime, EvolutionRuntimeContext +from .historian_helpers import ( + get_function_name, + _get_current_timestamp, + _get_id, +) +from .historian_resources import ResourceRuntime, ResourceRuntimeContext from .quest_types import ConfigurationRecord, VersionRecord, StepStartRecord, StepEndRecord, \ - ResourceAccessEvent, ResourceEntry, ResourceLifecycleEvent, TaskEvent -from .resources import ResourceStreamManager + ResourceAccessEvent, TaskEvent from .serializer import StepSerializer from .utils import quest_logger, task_name_getter from .utils import ( @@ -77,74 +81,10 @@ # Each task and step is a separate branch # I need to look for resources that are open in each branch and match the relevant events -SUSPENDED = '__WORKFLOW_SUSPENDED__' - - -def suspendable(func): - """ - Makes a __aexit__ or __exit__ method suspendable - With this decorator, the exit method will not be called - when the workflow is suspending. - It will only be called when the with context exits for other reasons. 
- """ - if inspect.iscoroutinefunction(func): - @wraps(func) - async def new_func(self, exc_type, exc_val, exc_tb): - if exc_type is asyncio.CancelledError and exc_val.args[0] == SUSPENDED: - return - await func(self, exc_type, exc_val, exc_tb) - else: - @wraps(func) - def new_func(self, exc_type, exc_val, exc_tb): - if exc_type is asyncio.CancelledError and exc_val.args[0] == SUSPENDED: - return - func(self, exc_type, exc_val, exc_tb) - - return new_func - - -T = TypeVar('T') - - class _Wrapper: pass -def wrap_methods_as_historian_events(resource: T, name: str, identity: str | None, historian: 'Historian', - internal=True) -> T: - wrapper = _Wrapper() - - historian_action = historian.handle_internal_event if internal else historian.record_external_event - - for field in dir(resource): - if field.startswith('_'): - continue - - if callable(method := getattr(resource, field)): - # Use default-value kwargs to force value binding instead of late binding - @wraps(method) - async def record(*args, _name=name, _identity=identity, _field=field, **kwargs): - return await historian_action(_name, _identity, _field, *args, **kwargs) - - setattr(wrapper, field, record) - - return wrapper - - -def _get_type_name(obj): - return obj.__class__.__module__ + '.' + obj.__class__.__name__ - - -def _get_id(item): - if isinstance(item, dict): - return tuple((k, _get_id(v)) for k, v in item.items()) - - if isinstance(item, list): - return tuple(_get_id(v) for v in item) - - return item - - def _prune(step_id: str, history: "History"): """ Remove substep work @@ -181,39 +121,6 @@ def _prune(step_id: str, history: "History"): for record in to_delete: history.remove(record) - - -def _get_current_timestamp() -> str: - return datetime.utcnow().isoformat() - - -def _get_qualified_version(module_name, function_name, version_name: str) -> str: - """ - A version is defined by the module and name of the function that is versioned. 
- If you move a function to a new module (or rename the module), it has become a new function. - If you change the function you are calling in a replay (i.e. its name has changed), - you may not get the expected results. - - You've been warned. - """ - return '.'.join([module_name, function_name, version_name]) - - -# Resource names should be unique to the workflow and identity -def _create_resource_id(name: str, identity: str | None) -> str: - return f'{name}|{identity}' if identity is not None else name - - -historian_context = ContextVar('historian') - - -def get_function_name(func): - if hasattr(func, '__name__'): # regular functions - return func.__name__ - else: # Callable classes - return func.__class__.__name__ - - def _get_exception_class(exception_type: str): module_name, class_name = exception_type.rsplit('.', 1) module = __import__(module_name, fromlist=[class_name]) @@ -226,8 +133,6 @@ def __init__(self, workflow_id: str, workflow: Callable, history: History, seria # TODO - change nomenclature (away from workflow)? Maybe just use workflow.__name__? 
self.workflow_id = workflow_id self.workflow = workflow - self._configurations: list[tuple[Callable, list, dict]] = [] - # This indicates whether the workflow function has completed # Suspending the workflow does not affect this value self._workflow_completed = False @@ -245,29 +150,6 @@ def __init__(self, workflow_id: str, workflow: Callable, history: History, seria # See also get_resources() and _run_with_exception_handling() self._fatal_exception = asyncio.Future() - # Keep track of configuration position during the replay - self._configuration_pos = 0 - - # Keep track of the versions of the workflow function - self._versions = {} - - # Keep track of the discovered, unprocessed versions - # While the code runs, it finds functions that are versioned - # However, we only process the versions during live play - # (not replay), so we need to save these until the replay is complete - # Then we add events to the history to record the new versions. - # As the workflow proceeds, it will now use the latest versions observed. - self._discovered_versions = {} - - # These are the resources available to the outside world. - # This could include values that can be accessed, - # queues to push to, etc. 
- # See also external.py - self._resources: dict[str, ResourceEntry] = {} - - # This is the resource stream manager that handles calls to stream the historian's resources - self._resource_stream_manager = ResourceStreamManager() - # We keep track of all open tasks so we can properly suspend them self._open_tasks: list[Task] = [] @@ -310,15 +192,32 @@ def __init__(self, workflow_id: str, workflow: Callable, history: History, seria # noinspection PyTypeChecker self._last_record_gate: asyncio.Future = None + self._resource_runtime = ResourceRuntime(ResourceRuntimeContext( + history=self._history, + replay_started=self._replay_started, + fatal_exception=self._fatal_exception, + get_next_record=self._next_record, + replay_complete=self._replay_complete, + get_task_name=self._get_task_name, + make_unique_id=self._get_unique_id, + )) + self._evolution_runtime = EvolutionRuntime(EvolutionRuntimeContext( + history=self._history, + get_external_task_name=self._get_external_task_name, + get_task_name=self._get_task_name, + get_next_record=self._next_record, + replay_has_completed=lambda: self._last_record_gate is not None and self._last_record_gate.done(), + existing_history=lambda: self._existing_history, + record_gates=lambda: self._record_gates, + )) + def _reset_replay(self): quest_logger.debug('Resetting replay') - self._configuration_pos = 0 - - self._versions = {} + self._evolution_runtime.reset() self._existing_history = list(self._history) - self._resources = {} + self._resource_runtime.reset() # The workflow ID is used as the task name for the root task self._prefix = { @@ -358,7 +257,7 @@ async def _replay_complete(self): quest_logger.debug(f'{self.workflow_id} -- Replay Complete --') # TODO - log this only once? 
- self._process_discovered_versions() + self._evolution_runtime.process_discovered_versions() def _get_external_task_name(self): return f'{self.workflow_id}.external' @@ -446,13 +345,13 @@ async def _external_handler(self): async for next_record in self._task_replay_records(self._get_external_task_name()): with next_record as record: if record['type'] == 'external': - await self._replay_external_event(record) + await self._resource_runtime.replay_external_event(record) elif record['type'] == 'set_version': - self._replay_version(record) + self._evolution_runtime.replay_version(record) elif record['type'] == 'configuration': - await self._run_configuration(record) + await self._evolution_runtime.run_configuration(record) quest_logger.debug(f'External event handler {self._get_task_name()} completed') except Exception: @@ -471,88 +370,26 @@ async def _next_record(self): return None async def _run_configuration(self, config_record: ConfigurationRecord): - config_function, args, kwargs = self._configurations[self._configuration_pos] - quest_logger.debug(f'Running configuration: {get_function_name(config_function)}(*{args}, **{kwargs})') - - assert config_record['function_name'] == get_function_name(config_function), str(config_record) - assert config_record['args'] == args, str(config_record) - assert config_record['kwargs'] == kwargs, str(config_record) - - await config_function(*args, **kwargs) - self._configuration_pos += 1 + await self._evolution_runtime.run_configuration(config_record) def get_version(self, module_name, function_name, version_name=GLOBAL_VERSION): - version = self._versions.get(_get_qualified_version(module_name, function_name, version_name), None) - quest_logger.debug( - f'{self._get_task_name()} get_version({module_name}, {function_name}, {version_name} returned "{version}"') - return version + return self._evolution_runtime.get_version(module_name, function_name, version_name) def _discover_versions(self, function, versions: dict[str, str]): - 
self._discovered_versions.update({ - _get_qualified_version(function.__module__, function.__qualname__, version_name): version - for version_name, version in versions.items() - }) - - # If the replay has already finished... - if self._last_record_gate is not None and self._last_record_gate.done(): - self._process_discovered_versions() + self._evolution_runtime.discover_versions(function, versions) def _process_discovered_versions(self): - for version_name, version in self._discovered_versions.items(): - # These records are replayed by the external handler - self._record_version_event(version_name, version) - self._discovered_versions = {} + self._evolution_runtime.process_discovered_versions() def _record_version_event(self, version_name, version): - if self._versions.get(version_name, None) == version: - return # Version not changed - - quest_logger.debug(f'Version record: {version_name} = {version}') - self._versions[version_name] = version - - self._history.append(VersionRecord( - type='set_version', - timestamp=_get_current_timestamp(), - step_id=version_name, - task_id=self._get_external_task_name(), # the external task owns these - version=version - )) + self._evolution_runtime.record_version_event(version_name, version) def _replay_version(self, record: VersionRecord): - quest_logger.debug(f'{self._get_task_name()} setting version {record["step_id"]} = "{record["version"]}"') - self._versions[record['step_id']] = record['version'] + self._evolution_runtime.replay_version(record) # TODO - keep or discard? 
async def _after_version(self, module_name, func_name, version_name, version): - version_name = _get_qualified_version(module_name, func_name, version_name) - quest_logger.debug(f'{self._get_task_name()} is waiting for version {version_name}=={version}') - - found = False - for record in self._existing_history: - if record['type'] == 'version' \ - and record['version_name'] == version_name \ - and record['version'] == version: - found = True - await self._record_gates[_get_id(record)] - - if not found: - quest_logger.error(f'{self._get_task_name()} did not find version {version_name}=={version}') - raise Exception(f'{self._get_task_name()} did not find version {version_name}=={version}') - - if (next_record := await self._next_record()) is not None: - with next_record as record: - assert record['type'] == 'after_version', str(record) - assert record['version_name'] == version_name, str(record) - assert record['version'] == version, str(record) - - else: - self._history.append(VersionRecord( - type='after_version', - timestamp=_get_current_timestamp(), - step_id='version', - task_id=self._get_task_name(), - version=version - )) + await self._evolution_runtime.after_version(module_name, func_name, version_name, version) async def handle_step(self, func_name, func: Callable, *args, **kwargs): step_id = self._get_unique_id(func_name) @@ -645,177 +482,19 @@ async def handle_step(self, func_name, func: Callable, *args, **kwargs): self._prefix[self._get_task_name()].pop(-1) async def record_external_event(self, name, identity, action, *args, **kwargs): - """ - When an external event occurs, this method is called. - """ - resource_id = _create_resource_id(name, identity) - step_id = self._get_unique_id(resource_id + '.' 
+ action) - - quest_logger.debug(f'External event {step_id} with {args} and {kwargs}') - - resource = self._resources[resource_id]['resource'] - - function = getattr(resource, action) - if inspect.iscoroutinefunction(function): - result = await function(*args, **kwargs) - else: - result = function(*args, **kwargs) - - self._history.append(ResourceAccessEvent( - type='external', - timestamp=_get_current_timestamp(), - step_id=step_id, - task_id=self._get_task_name(), - resource_id=resource_id, - action=action, - args=list(args), - kwargs=kwargs, - result=result - )) - - return result + return await self._resource_runtime.record_external_event(name, identity, action, *args, **kwargs) async def _replay_external_event(self, record: ResourceAccessEvent): - """ - When an external event is replayed, this method is called - """ - assert record['type'] == 'external', str(record) - - result = getattr( - self._resources[record['resource_id']]['resource'], - record['action'] - )(*record['args'], **record['kwargs']) - - if inspect.iscoroutine(result): - result = await result - - assert result == record['result'] + await self._resource_runtime.replay_external_event(record) async def handle_internal_event(self, name, identity, action, *args, **kwargs): - """ - Internal events are always played - If the event is replayed, the details are asserted - If the event is new, it is recorded - """ - resource_id = _create_resource_id(name, identity) - step_id = self._get_unique_id(resource_id + '.' 
+ action) - - resource = self._resources[resource_id]['resource'] - function = getattr(resource, action) - - if (next_record := await self._next_record()) is None: - self._history.append(ResourceAccessEvent( - type='internal_start', - timestamp=_get_current_timestamp(), - step_id=step_id, - task_id=self._get_task_name(), - resource_id=resource_id, - action=action, - args=list(args), - kwargs=kwargs, - result=None - )) - else: - with next_record as record: - assert 'internal_start' == record['type'], str(record) - assert resource_id == record['resource_id'], str(record) - assert action == record['action'], str(record) - assert list(args) == list(record['args']), str(record) - assert kwargs == record['kwargs'], str(record) - - quest_logger.debug(f'Calling {step_id} with {args} and {kwargs}') - if inspect.iscoroutinefunction(function): - result = await function(*args, **kwargs) - else: - result = function(*args, **kwargs) - - if (next_record := await self._next_record()) is None: - self._history.append(ResourceAccessEvent( - type='internal_end', - timestamp=_get_current_timestamp(), - step_id=step_id, - task_id=self._get_task_name(), - resource_id=resource_id, - action=action, - args=list(args), - kwargs=kwargs, - result=result - )) - await self._update_resource_stream(identity) - - else: - with next_record as record: - assert 'internal_end' == record['type'], f'internal != {record["type"]}' - assert resource_id == record['resource_id'] - assert action == record['action'] - assert list(args) == list(record['args']) - assert kwargs == record['kwargs'] - assert result == record['result'] - - return result + return await self._resource_runtime.handle_internal_event(name, identity, action, *args, **kwargs) async def register_resource(self, name, identity, resource): - resource_id = _create_resource_id(name, identity) - # TODO - support the ability to limit the exposed API on the resource - - if resource_id in self._resources: - raise Exception(f'A resource for {identity} 
named {name} already exists in this workflow') - # TODO - custom exception - - step_id = self._get_unique_id(resource_id + '.' + '__init__') - quest_logger.debug(f'Creating {resource_id}') - - self._resources[resource_id] = ResourceEntry( - name=name, - identity=identity, - type=_get_type_name(resource), - resource=resource - ) - - if (next_record := await self._next_record()) is None: - self._history.append(ResourceLifecycleEvent( - type='create_resource', - timestamp=_get_current_timestamp(), - step_id=step_id, - task_id=self._get_task_name(), - resource_id=resource_id, - resource_type=_get_type_name(resource) - )) - await self._update_resource_stream(identity) - - else: - with next_record as record: - assert record['type'] == 'create_resource' - assert record['resource_id'] == resource_id - - return resource_id + return await self._resource_runtime.register_resource(name, identity, resource) async def delete_resource(self, name, identity, suspending=False): - resource_id = _create_resource_id(name, identity) - if resource_id not in self._resources: - raise Exception(f'No resource for {identity} named {name} found') - # TODO - custom exception - - step_id = self._get_unique_id(resource_id + '.' 
+ '__del__') - quest_logger.debug(f'Removing {resource_id}') - resource_entry = self._resources.pop(resource_id) - - if not suspending: - if (next_record := await self._next_record()) is None: - self._history.append(ResourceLifecycleEvent( - type='delete_resource', - timestamp=_get_current_timestamp(), - step_id=step_id, - task_id=self._get_task_name(), - resource_id=resource_id, - resource_type=resource_entry['type'] - )) - await self._update_resource_stream(identity) - - else: - with next_record as record: - assert record['type'] == 'delete_resource' - assert record['resource_id'] == resource_id + await self._resource_runtime.delete_resource(name, identity, suspending=suspending) def start_task(self, func, *args, name=None, task_factory=asyncio.create_task, **kwargs): historian_context.set(self) @@ -874,7 +553,7 @@ async def _run_with_args(self, *args, **kwargs): kwargs = await self.handle_step('kwargs', lambda: kwargs) result = await self.handle_step(get_function_name(self.workflow), self.workflow, *args, **kwargs) self._workflow_completed = True - self._resource_stream_manager.notify_of_workflow_stop() + self._resource_runtime.notify_of_workflow_stop() return result async def _run_with_exception_handling(self, *args, **kwargs): @@ -929,49 +608,15 @@ def run(self, *args, **kwargs): return task def configure(self, config_function, *args, **kwargs): - """ - Configuration happens when the application is run, before the run function is called. - Here we inject a configuration event into the history, - which is processed by the external task - """ - if not callable(config_function): - raise Exception(f'First argument to configure must be a callable. 
Received {config_function}.') - - self._configurations.append((config_function, list(args), kwargs)) + self._evolution_runtime.configure(config_function, *args, **kwargs) def _add_new_configurations(self): - config_records = [ - record - for record in self._history - if record['type'] == 'configuration' - ] - - # We should have a configuration to replay for each record in the past - assert len(config_records) <= len(self._configurations) - - for record, (config_function, args, kwargs) in zip(config_records, self._configurations): - assert record['function_name'] == get_function_name(config_function) - assert record['args'] == args - assert record['kwargs'] == kwargs - - # Add new configuration records - for config_function, args, kwargs in self._configurations[len(config_records):]: - quest_logger.debug(f'Adding new configuration: {get_function_name(config_function)}(*{args}, **{kwargs}') - - self._history.append(ConfigurationRecord( - type='configuration', - timestamp=_get_current_timestamp(), - step_id='configuration', - task_id=self._get_external_task_name(), # the external task owns these - function_name=get_function_name(config_function), - args=args, - kwargs=kwargs - )) + self._evolution_runtime.add_new_configurations() def signal_suspend(self): quest_logger.debug(f'-- Suspending {self.workflow_id} --') - self._resource_stream_manager.notify_of_workflow_stop() + self._resource_runtime.notify_of_workflow_stop() # Cancelling these in reverse order is important # If a parent thread cancels, it will cancel a child. @@ -996,46 +641,18 @@ async def suspend(self): pass async def get_resources(self, identity): - # Wait until the replay is done. - # This ensures that all pre-existing resources have been rebuilt. 
- await self._replay_started.wait() - await self._replay_complete() - - # If the application has failed, let the caller know - if self._fatal_exception.done(): - await self._fatal_exception - - resources: dict[(str, str), str] = {} # dict[(name, identity), type] - for entry in self._resources.values(): - # Always return public resources and private resources for the specified identity - if entry['identity'] is None or entry['identity'] == identity: - resources[(entry['name'], entry['identity'])] = entry['type'] - - return resources + return await self._resource_runtime.get_resources(identity) def get_resource_stream(self, identity): - return self._resource_stream_manager.get_resource_stream( - identity, - lambda: self.get_resources(identity), - ) + return self._resource_runtime.get_resource_stream(identity) async def _update_resource_stream(self, identity): - await self._resource_stream_manager.update(identity) - - -class HistorianNotFoundException(Exception): - pass - + await self._resource_runtime.update_resource_stream(identity) -def find_historian() -> Historian: - if (workflow := historian_context.get()) is not None: - return workflow + @property + def _resources(self): + return self._resource_runtime._resources - outer_frame = inspect.currentframe() - is_workflow = False - while not is_workflow: - outer_frame = outer_frame.f_back - if outer_frame is None: - raise HistorianNotFoundException("Historian object not found in event stack") - is_workflow = isinstance(outer_frame.f_locals.get('self'), Historian) - return outer_frame.f_locals.get('self') + @property + def _resource_stream_manager(self): + return self._resource_runtime._resource_stream_manager diff --git a/src/quest/historian_context.py b/src/quest/historian_context.py new file mode 100644 index 0000000..2efabe9 --- /dev/null +++ b/src/quest/historian_context.py @@ -0,0 +1,52 @@ +import asyncio +import inspect +from contextvars import ContextVar +from functools import wraps + +SUSPENDED = 
'__WORKFLOW_SUSPENDED__' + +historian_context = ContextVar('historian', default=None) + + +class HistorianNotFoundException(Exception): + pass + + +def suspendable(func): + """ + Makes a __aexit__ or __exit__ method suspendable + With this decorator, the exit method will not be called + when the workflow is suspending. + It will only be called when the with context exits for other reasons. + """ + if inspect.iscoroutinefunction(func): + @wraps(func) + async def new_func(self, exc_type, exc_val, exc_tb): + if exc_type is asyncio.CancelledError and exc_val.args[0] == SUSPENDED: + return + await func(self, exc_type, exc_val, exc_tb) + else: + @wraps(func) + def new_func(self, exc_type, exc_val, exc_tb): + if exc_type is asyncio.CancelledError and exc_val.args[0] == SUSPENDED: + return + func(self, exc_type, exc_val, exc_tb) + + return new_func + + +def find_historian(): + workflow = historian_context.get() + if workflow is not None: + return workflow + + from .historian import Historian + + outer_frame = inspect.currentframe() + is_workflow = False + while not is_workflow: + outer_frame = outer_frame.f_back + if outer_frame is None: + raise HistorianNotFoundException("Historian object not found in event stack") + is_workflow = isinstance(outer_frame.f_locals.get('self'), Historian) + return outer_frame.f_locals.get('self') diff --git a/src/quest/historian_evolution.py b/src/quest/historian_evolution.py new file mode 100644 index 0000000..630d209 --- /dev/null +++ b/src/quest/historian_evolution.py @@ -0,0 +1,145 @@ +from dataclasses import dataclass +from typing import Any, Callable + +from .historian_helpers import get_function_name, _get_current_timestamp, _get_id, _get_qualified_version +from .quest_types import ConfigurationRecord, VersionRecord +from .utils import quest_logger + +GLOBAL_VERSION = "_global_version" + + +@dataclass(slots=True) +class EvolutionRuntimeContext: + history: list + get_external_task_name: Callable[[], str] + get_task_name: Callable[[], 
str] + get_next_record: Callable[[], Any] + replay_has_completed: Callable[[], bool] + existing_history: Callable[[], list] + record_gates: Callable[[], dict] + + +class EvolutionRuntime: + def __init__(self, context: EvolutionRuntimeContext): + self._context = context + self._configurations: list[tuple[Callable, list, dict]] = [] + self._configuration_pos = 0 + self._versions = {} + self._discovered_versions = {} + + def reset(self): + self._configuration_pos = 0 + self._versions = {} + + def configure(self, config_function, *args, **kwargs): + if not callable(config_function): + raise Exception(f'First argument to configure must be a callable. Received {config_function}.') + + self._configurations.append((config_function, list(args), kwargs)) + + def add_new_configurations(self): + config_records = [ + record + for record in self._context.history + if record['type'] == 'configuration' + ] + + assert len(config_records) <= len(self._configurations) + + for record, (config_function, args, kwargs) in zip(config_records, self._configurations): + assert record['function_name'] == get_function_name(config_function) + assert record['args'] == args + assert record['kwargs'] == kwargs + + for config_function, args, kwargs in self._configurations[len(config_records):]: + quest_logger.debug(f'Adding new configuration: {get_function_name(config_function)}(*{args}, **{kwargs}') + + self._context.history.append(ConfigurationRecord( + type='configuration', + timestamp=_get_current_timestamp(), + step_id='configuration', + task_id=self._context.get_external_task_name(), + function_name=get_function_name(config_function), + args=args, + kwargs=kwargs + )) + + async def run_configuration(self, config_record: ConfigurationRecord): + config_function, args, kwargs = self._configurations[self._configuration_pos] + quest_logger.debug(f'Running configuration: {get_function_name(config_function)}(*{args}, **{kwargs})') + + assert config_record['function_name'] == 
get_function_name(config_function), str(config_record) + assert config_record['args'] == args, str(config_record) + assert config_record['kwargs'] == kwargs, str(config_record) + + await config_function(*args, **kwargs) + self._configuration_pos += 1 + + def get_version(self, module_name, function_name, version_name=GLOBAL_VERSION): + version = self._versions.get(_get_qualified_version(module_name, function_name, version_name), None) + quest_logger.debug( + f'{self._context.get_task_name()} get_version({module_name}, {function_name}, {version_name} returned "{version}"') + return version + + def discover_versions(self, function, versions: dict[str, str]): + self._discovered_versions.update({ + _get_qualified_version(function.__module__, function.__qualname__, version_name): version + for version_name, version in versions.items() + }) + + if self._context.replay_has_completed(): + self.process_discovered_versions() + + def process_discovered_versions(self): + for version_name, version in self._discovered_versions.items(): + self.record_version_event(version_name, version) + self._discovered_versions = {} + + def record_version_event(self, version_name, version): + if self._versions.get(version_name, None) == version: + return + + quest_logger.debug(f'Version record: {version_name} = {version}') + self._versions[version_name] = version + + self._context.history.append(VersionRecord( + type='set_version', + timestamp=_get_current_timestamp(), + step_id=version_name, + task_id=self._context.get_external_task_name(), + version=version + )) + + def replay_version(self, record: VersionRecord): + quest_logger.debug(f'{self._context.get_task_name()} setting version {record["step_id"]} = "{record["version"]}"') + self._versions[record['step_id']] = record['version'] + + async def after_version(self, module_name, func_name, version_name, version): + version_name = _get_qualified_version(module_name, func_name, version_name) + 
quest_logger.debug(f'{self._context.get_task_name()} is waiting for version {version_name}=={version}') + + found = False + for record in self._context.existing_history(): + if record['type'] == 'version' \ + and record['version_name'] == version_name \ + and record['version'] == version: + found = True + await self._context.record_gates()[_get_id(record)] + + if not found: + quest_logger.error(f'{self._context.get_task_name()} did not find version {version_name}=={version}') + raise Exception(f'{self._context.get_task_name()} did not find version {version_name}=={version}') + + if (next_record := await self._context.get_next_record()) is not None: + with next_record as record: + assert record['type'] == 'after_version', str(record) + assert record['version_name'] == version_name, str(record) + assert record['version'] == version, str(record) + else: + self._context.history.append(VersionRecord( + type='after_version', + timestamp=_get_current_timestamp(), + step_id='version', + task_id=self._context.get_task_name(), + version=version + )) diff --git a/src/quest/historian_helpers.py b/src/quest/historian_helpers.py new file mode 100644 index 0000000..14a758b --- /dev/null +++ b/src/quest/historian_helpers.py @@ -0,0 +1,41 @@ +from datetime import datetime + + +def get_function_name(func): + if hasattr(func, '__name__'): + return func.__name__ + return func.__class__.__name__ + + +def _get_id(item): + if isinstance(item, dict): + return tuple((k, _get_id(v)) for k, v in item.items()) + + if isinstance(item, list): + return tuple(_get_id(v) for v in item) + + return item + + +def _get_current_timestamp() -> str: + return datetime.utcnow().isoformat() + + +def _create_resource_id(name: str, identity: str | None) -> str: + return f'{name}|{identity}' if identity is not None else name + + +def _get_type_name(obj): + return obj.__class__.__module__ + '.' 
def _get_qualified_version(module_name, function_name, version_name: str) -> str:
    """
    Build the dotted key that identifies one named version of one function.

    The key is made of the function's module, the function's name and the
    version name, so the identity of a version is tied to where the function
    lives. Moving the function to another module (or renaming the module)
    yields a different key, i.e. a different function as far as versioning
    is concerned. Renaming a function that a replay relies on may therefore
    produce unexpected results.
    """
    qualified_parts = (module_name, function_name, version_name)
    return '.'.join(qualified_parts)
async def replay_external_event(self, record: ResourceAccessEvent):
    """
    Re-apply a recorded external event to the registered resource.

    The resource is looked up by ``record['resource_id']``, the method named
    by ``record['action']`` is called with the recorded args and kwargs, and
    the outcome is compared with the recorded result.

    :param record: a history record whose ``type`` is ``'external'``
    :raises KeyError: if no resource is registered under the recorded id
    :raises AssertionError: if the record is not an external event, or the
        replayed call returns something other than the recorded result
    """
    assert record['type'] == 'external', str(record)

    resource = self._resources[record['resource_id']]['resource']
    action = getattr(resource, record['action'])
    result = action(*record['args'], **record['kwargs'])

    # The action may be sync or async; only await when a coroutine came back.
    if inspect.iscoroutine(result):
        result = await result

    # Include the record in the failure so a diverging replay can be diagnosed,
    # matching the type assertion at the top of this method.
    assert result == record['result'], \
        f'Replaying {record} produced {result!r}'
async def register_resource(self, name, identity, resource):
    """
    Add a resource to this workflow and record (or replay) its creation.

    The resource is stored under the id built from ``name`` and ``identity``.
    When there is no next history record, a 'create_resource' event is
    appended to the history and the resource stream for ``identity`` is
    updated. Otherwise the next record is checked against this creation.

    :return: the resource id under which the resource was stored
    :raises Exception: if a resource with the same id is already registered
    """
    resource_id = _create_resource_id(name, identity)
    if resource_id in self._resources:
        raise Exception(f'A resource for {identity} named {name} already exists in this workflow')

    step_id = self._context.make_unique_id(resource_id + '.__init__')
    quest_logger.debug(f'Creating {resource_id}')

    # The same type name is used for the registry entry and the history event.
    type_name = _get_type_name(resource)
    self._resources[resource_id] = ResourceEntry(
        name=name,
        identity=identity,
        type=type_name,
        resource=resource,
    )

    next_record = await self._context.get_next_record()
    if next_record is not None:
        # Replaying: the recorded event must describe this very creation.
        with next_record as record:
            assert record['type'] == 'create_resource'
            assert record['resource_id'] == resource_id
        return resource_id

    # Live run: record the creation and let stream listeners know.
    creation_event = ResourceLifecycleEvent(
        type='create_resource',
        timestamp=_get_current_timestamp(),
        step_id=step_id,
        task_id=self._context.get_task_name(),
        resource_id=resource_id,
        resource_type=type_name,
    )
    self._context.history.append(creation_event)
    await self.update_resource_stream(identity)
    return resource_id
+ '__del__') + quest_logger.debug(f'Removing {resource_id}') + resource_entry = self._resources.pop(resource_id) + + if not suspending: + if (next_record := await self._context.get_next_record()) is None: + self._context.history.append(ResourceLifecycleEvent( + type='delete_resource', + timestamp=_get_current_timestamp(), + step_id=step_id, + task_id=self._context.get_task_name(), + resource_id=resource_id, + resource_type=resource_entry['type'] + )) + await self.update_resource_stream(identity) + else: + with next_record as record: + assert record['type'] == 'delete_resource' + assert record['resource_id'] == resource_id + + async def get_resources(self, identity): + await self._context.replay_started.wait() + await self._context.replay_complete() + + if self._context.fatal_exception.done(): + await self._context.fatal_exception + + resources: dict[(str, str), str] = {} + for entry in self._resources.values(): + if entry['identity'] is None or entry['identity'] == identity: + resources[(entry['name'], entry['identity'])] = entry['type'] + + return resources + + def get_resource_stream(self, identity): + return self._resource_stream_manager.get_resource_stream( + identity, + lambda: self.get_resources(identity), + ) + + async def update_resource_stream(self, identity): + await self._resource_stream_manager.update(identity) + + def notify_of_workflow_stop(self): + self._resource_stream_manager.notify_of_workflow_stop() diff --git a/src/quest/manager.py b/src/quest/manager.py index 06c7ba4..c7327aa 100644 --- a/src/quest/manager.py +++ b/src/quest/manager.py @@ -6,7 +6,8 @@ from typing import Protocol, Callable, TypeVar, Any, TypedDict from .external import State, IdentityQueue, Queue, Event -from .historian import Historian, _Wrapper, SUSPENDED +from .historian import Historian, _Wrapper +from .historian_context import SUSPENDED from .history import History from .persistence import BlobStorage from .serializer import StepSerializer @@ -133,9 +134,13 @@ def 
def _schedule_store_result(self, workflow_id: str, task: asyncio.Task, delete_on_finish: bool):
    """
    Schedule ``_store_result`` for a finished workflow task in the background.

    Storing the result is asynchronous, so it is wrapped in a new task on the
    running event loop. When no loop is running ``asyncio.create_task``
    raises ``RuntimeError``; storing is then skipped (best effort).

    :param workflow_id: id of the workflow whose task finished
    :param task: the finished workflow task
    :param delete_on_finish: forwarded unchanged to ``_store_result``
    """
    store_coro = self._store_result(workflow_id, task, delete_on_finish)
    try:
        asyncio.create_task(store_coro)
    except RuntimeError:
        # No running loop: the coroutine object already exists, so close it
        # explicitly. Dropping it un-awaited would emit a
        # "coroutine ... was never awaited" RuntimeWarning.
        store_coro.close()
self._workflow_tasks[workflow_id] - del self._workflow_data[workflow_id] + self._workflows.pop(workflow_id, None) + self._workflow_tasks.pop(workflow_id, None) + self._workflow_data.pop(workflow_id, None) def start_workflow(self, workflow_type: str, workflow_id: str, *workflow_args, delete_on_finish: bool = True, **workflow_kwargs): diff --git a/src/quest/manager_wrappers.py b/src/quest/manager_wrappers.py index 697a799..0fc9e02 100644 --- a/src/quest/manager_wrappers.py +++ b/src/quest/manager_wrappers.py @@ -1,5 +1,5 @@ from .manager import find_workflow_manager -from .historian import find_historian +from .historian_context import find_historian class Alias: diff --git a/src/quest/versioning.py b/src/quest/versioning.py index f37f4bb..71ebe33 100644 --- a/src/quest/versioning.py +++ b/src/quest/versioning.py @@ -1,7 +1,8 @@ import inspect from functools import wraps -from .historian import GLOBAL_VERSION, QUEST_VERSIONS, find_historian +from .historian import GLOBAL_VERSION, QUEST_VERSIONS +from .historian_context import find_historian DEFAULT_VERSION = '' diff --git a/src/quest/wrappers.py b/src/quest/wrappers.py index 80ea45a..c61fe1a 100644 --- a/src/quest/wrappers.py +++ b/src/quest/wrappers.py @@ -3,7 +3,7 @@ from functools import wraps from typing import Callable, Coroutine, TypeVar -from .historian import find_historian +from .historian_context import find_historian def _get_func_name(func) -> str: