diff --git a/src/orcapod/core/executors/capture_wrapper.py b/src/orcapod/core/executors/capture_wrapper.py index 9cf9665b..af3b94c3 100644 --- a/src/orcapod/core/executors/capture_wrapper.py +++ b/src/orcapod/core/executors/capture_wrapper.py @@ -12,9 +12,15 @@ from typing import Any -def make_capture_wrapper() -> Callable[..., Any]: +def make_capture_wrapper(name: str | None = None) -> Callable[..., Any]: """Return a capture wrapper suitable for remote execution. + Args: + name: If provided, the wrapper's ``__name__`` and ``__qualname__`` + are set to this value so that remote frameworks (e.g. Ray) report + the original function name in dashboards and metrics rather than + the generic ``_capture``. + On success the wrapper returns a 4-tuple ``(raw_result, stdout_log, stderr_log, python_logs)``. @@ -137,4 +143,8 @@ def emit(self, record: logging.LogRecord) -> None: return raw_result, cap_stdout, cap_stderr, python_logs + if name is not None: + _capture.__name__ = name + _capture.__qualname__ = name + return _capture diff --git a/src/orcapod/core/executors/ray.py b/src/orcapod/core/executors/ray.py index cf82253a..98da9e50 100644 --- a/src/orcapod/core/executors/ray.py +++ b/src/orcapod/core/executors/ray.py @@ -180,7 +180,7 @@ def execute_callable( import ray self._ensure_ray_initialized() - remote_fn = self._get_remote_fn(ray) + remote_fn = self._get_remote_fn(ray, fn.__name__) ref = remote_fn.options(name=fn.__name__).remote(fn, kwargs) try: @@ -206,7 +206,7 @@ async def async_execute_callable( import ray self._ensure_ray_initialized() - remote_fn = self._get_remote_fn(ray) + remote_fn = self._get_remote_fn(ray, fn.__name__) ref = remote_fn.options(name=fn.__name__).remote(fn, kwargs) try: @@ -225,26 +225,26 @@ async def async_execute_callable( self._record_success(stdout_log, stderr_log, python_logs, logger) return raw - def _get_remote_fn(self, ray: Any) -> Any: + def _get_remote_fn(self, ray: Any, fn_name: str) -> Any: """Return a cached Ray remote wrapper for the capture closure. - The capture wrapper's bytecode is identical on every invocation, so - it only needs to be remotized once per distinct set of remote options. - Caching avoids the non-trivial overhead of ``ray.remote()`` on every - packet. + The capture wrapper is created with ``fn_name`` so that Ray's + metrics and dashboard report the original function name instead + of the generic ``_capture``. Wrappers are cached per distinct + ``(remote_opts, fn_name)`` pair. A ``threading.Lock`` guards population so that concurrent calls (``supports_concurrent_execution = True``) never redundantly call ``ray.remote()`` for the same option set. """ opts = self._build_remote_opts() - cache_key = self._normalize_opts(opts) + cache_key = (self._normalize_opts(opts), fn_name) if cache_key not in self._remote_fn_cache: with self._remote_fn_cache_lock: # Double-checked: another thread may have filled the slot # while we waited for the lock. if cache_key not in self._remote_fn_cache: - wrapper = make_capture_wrapper() + wrapper = make_capture_wrapper(name=fn_name) self._remote_fn_cache[cache_key] = ray.remote(**opts)(wrapper) return self._remote_fn_cache[cache_key] diff --git a/tests/test_core/test_regression_fixes.py b/tests/test_core/test_regression_fixes.py index 6415385f..2b644689 100644 --- a/tests/test_core/test_regression_fixes.py +++ b/tests/test_core/test_regression_fixes.py @@ -640,6 +640,63 @@ def resolve_later(): ) assert result == 42 + def test_get_remote_fn_caches_per_function_name(self): + """_get_remote_fn must return distinct remote wrappers for different + function names so Ray metrics report the correct name.""" + from unittest.mock import MagicMock, call, patch + + from orcapod.core.executors.ray import RayExecutor + + mock_ray = MagicMock() + mock_ray.is_initialized.return_value = True + # ray.remote(**opts)(wrapper) returns a mock remote fn + mock_ray.remote.return_value = lambda wrapper: MagicMock(name=f"remote_{wrapper.__name__}") + + with patch.dict("sys.modules", {"ray": mock_ray}): + executor = RayExecutor.__new__(RayExecutor) + executor._remote_opts = {} + executor._remote_fn_cache = {} + executor._remote_fn_cache_lock = __import__("threading").Lock() + + fn_a = executor._get_remote_fn(mock_ray, "transform_a") + fn_b = executor._get_remote_fn(mock_ray, "transform_b") + fn_a_again = executor._get_remote_fn(mock_ray, "transform_a") + + # Different names → different remote fns + assert fn_a is not fn_b + # Same name → cached (same object) + assert fn_a is fn_a_again + + def test_get_remote_fn_sets_wrapper_name(self): + """The capture wrapper created by _get_remote_fn should carry the + original function name so Ray uses it in metrics labels.""" + from unittest.mock import MagicMock, patch + + from orcapod.core.executors.ray import RayExecutor + + captured_wrappers = [] + + def fake_remote(**opts): + def decorator(wrapper): + captured_wrappers.append(wrapper) + return MagicMock() + return decorator + + mock_ray = MagicMock() + mock_ray.is_initialized.return_value = True + mock_ray.remote = fake_remote + + with patch.dict("sys.modules", {"ray": mock_ray}): + executor = RayExecutor.__new__(RayExecutor) + executor._remote_opts = {} + executor._remote_fn_cache = {} + executor._remote_fn_cache_lock = __import__("threading").Lock() + + executor._get_remote_fn(mock_ray, "compute_features") + + assert len(captured_wrappers) == 1 + assert captured_wrappers[0].__name__ == "compute_features" + # =========================================================================== # 7. PacketFunctionExecutorProtocol type safety diff --git a/tests/test_pipeline/test_logging_capture.py b/tests/test_pipeline/test_logging_capture.py index 7ec2c855..57d39093 100644 --- a/tests/test_pipeline/test_logging_capture.py +++ b/tests/test_pipeline/test_logging_capture.py @@ -342,3 +342,33 @@ def fn(): os.write(1, b"") finally: os.close(original_stdout_fd) + + def test_wrapper_name_defaults_to_capture(self): + """Without a name argument the wrapper keeps its default name.""" + from orcapod.core.executors.capture_wrapper import make_capture_wrapper + + wrapper = make_capture_wrapper() + assert wrapper.__name__ == "_capture" + assert wrapper.__qualname__.endswith("_capture") + + def test_wrapper_name_set_by_argument(self): + """Passing name= overwrites __name__ and __qualname__.""" + from orcapod.core.executors.capture_wrapper import make_capture_wrapper + + wrapper = make_capture_wrapper(name="my_transform") + assert wrapper.__name__ == "my_transform" + assert wrapper.__qualname__ == "my_transform" + + def test_named_wrapper_still_captures(self): + """A renamed wrapper must still capture output and return results.""" + from orcapod.core.executors.capture_wrapper import make_capture_wrapper + + wrapper = make_capture_wrapper(name="add_one") + + def add_one(x): + print("hello") + return x + 1 + + raw, stdout, stderr, python_logs = wrapper(add_one, {"x": 5}) + assert raw == 6 + assert "hello" in stdout