From 5446fb5b6b086ed7249b9ade1b6bd4c9683266b3 Mon Sep 17 00:00:00 2001 From: jhcipar Date: Mon, 16 Mar 2026 11:06:38 -0400 Subject: [PATCH 1/4] feat: add request id to worker logs --- src/handler.py | 37 ++++++++++++++++++++-- src/log_streamer.py | 3 +- src/logger.py | 65 ++++++++++++++++++++++++++++++++------- tests/unit/test_logger.py | 49 +++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 14 deletions(-) create mode 100644 tests/unit/test_logger.py diff --git a/src/handler.py b/src/handler.py index a1245e7..506de7d 100644 --- a/src/handler.py +++ b/src/handler.py @@ -2,11 +2,13 @@ import logging import os import sys +import uuid from pathlib import Path +from collections.abc import Awaitable from typing import Any, Dict, Optional from constants import MAX_IMPORT_RECOVERY_ATTEMPTS -from logger import setup_logging +from logger import setup_logging, set_request_id, reset_request_id from unpack_volume import maybe_unpack from version import format_version_banner @@ -22,6 +24,23 @@ # Log after unpack so bundled runpod_flash is on sys.path logger.info(format_version_banner()) +def _extract_request_id(event: Dict[str, Any]) -> str: + """Extract RunPod job id from event, with safe fallback.""" + event_id = event.get("id") + if isinstance(event_id, str) and event_id.strip(): + return event_id + + job_id = event.get("job_id") + if isinstance(job_id, str) and job_id.strip(): + return job_id + + job = event.get("job") + if isinstance(job, dict): + nested_job_id = job.get("id") + if isinstance(nested_job_id, str) and nested_job_id.strip(): + return nested_job_id + + return str(uuid.uuid4()) def _is_deployed_mode() -> bool: """True when running as a Flash-deployed endpoint (not Live Serverless).""" @@ -215,7 +234,18 @@ def _load_generated_handler() -> Optional[Any]: # Deployed mode: generated handler is mandatory, failures are fatal. # Live Serverless mode: FunctionRequest handler is the only path. if _is_deployed_mode(): - handler = _load_generated_handler() + _generated = _load_generated_handler() + + async def handler(event: Dict[str, Any]) -> Dict[str, Any]: + request_id_token = set_request_id(_extract_request_id(event)) + try: + result = _generated(event) + if isinstance(result, Awaitable): + return await result + return result + finally: + reset_request_id(request_id_token) + else: from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse from remote_executor import RemoteExecutor @@ -223,6 +253,7 @@ def _load_generated_handler() -> Optional[Any]: async def handler(event: Dict[str, Any]) -> Dict[str, Any]: """RunPod serverless handler for Live Serverless (FunctionRequest protocol).""" output: FunctionResponse + request_id_token = set_request_id(_extract_request_id(event)) try: executor = RemoteExecutor() @@ -234,6 +265,8 @@ async def handler(event: Dict[str, Any]) -> Dict[str, Any]: success=False, error=f"Error in handler: {str(error)}", ) + finally: + reset_request_id(request_id_token) return output.model_dump() # type: ignore[no-any-return] diff --git a/src/log_streamer.py b/src/log_streamer.py index 1ec61b8..80c1fe3 100644 --- a/src/log_streamer.py +++ b/src/log_streamer.py @@ -11,7 +11,7 @@ from collections import deque from typing import Optional, Deque, Callable -from logger import get_log_format +from logger import ensure_request_id_filter, get_log_format class LogStreamer: @@ -58,6 +58,7 @@ def start_streaming( # Use same format as main logging formatter = logging.Formatter(get_log_format(level)) self._handler.setFormatter(formatter) + ensure_request_id_filter(self._handler) # Add to root logger root_logger = logging.getLogger() diff --git a/src/logger.py b/src/logger.py index 9f97000..e939f3b 100644 --- a/src/logger.py +++ b/src/logger.py @@ -7,9 +7,47 @@ import logging import os import sys +from contextvars import ContextVar, Token from typing import Union, Optional +_REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="-") + + +class RequestIdFilter(logging.Filter): + """Inject request_id from context into each log record.""" + + def filter(self, record: logging.LogRecord) -> bool: + record.request_id = _REQUEST_ID.get() + return True + + +_REQUEST_ID_FILTER = RequestIdFilter() + + +def set_request_id(request_id: Optional[str]) -> Token[str]: + """Set request id in log context and return reset token.""" + if request_id: + normalized = request_id.strip() or "-" + else: + normalized = "-" + return _REQUEST_ID.set(normalized) + + +def reset_request_id(token: Token[str]) -> None: + """Reset request id context with token from set_request_id.""" + _REQUEST_ID.reset(token) + + +def get_request_id() -> str: + return _REQUEST_ID.get() + + +def ensure_request_id_filter(handler: logging.Handler) -> None: + if not any(isinstance(existing, RequestIdFilter) for existing in handler.filters): + handler.addFilter(_REQUEST_ID_FILTER) + + def get_log_level() -> int: """Get log level from environment variable, defaulting to INFO.""" log_level = os.environ.get("LOG_LEVEL", "INFO").upper() @@ -19,9 +57,12 @@ def get_log_level() -> int: def get_log_format(level: int) -> str: """Get appropriate log format based on level, matching runpod-flash style.""" if level == logging.DEBUG: - return "%(asctime)s | %(levelname)-5s | %(name)s | %(filename)s:%(lineno)d | %(message)s" + return ( + "%(asctime)s | %(levelname)-5s | %(request_id)s | " + "%(name)s | %(filename)s:%(lineno)d | %(message)s" + ) else: - return "%(asctime)s | %(levelname)-5s | %(message)s" + return "%(asctime)s | %(levelname)-5s | %(request_id)s | %(message)s" def setup_logging( @@ -38,25 +79,27 @@ def setup_logging( stream: Output stream for logs fmt: Custom format string (auto-selected based on level if None) """ - # Determine log level if level is None: - level = get_log_level() + resolved_level = get_log_level() elif isinstance(level, str): - level = getattr(logging, level.upper(), logging.INFO) + resolved_level = getattr(logging, level.upper(), logging.INFO) + else: + resolved_level = level - # Determine format based on requested level if fmt is None: - fmt = get_log_format(level) + fmt = get_log_format(resolved_level) - # Configure root logger root_logger = logging.getLogger() - root_logger.setLevel(level) + root_logger.setLevel(resolved_level) if not root_logger.hasHandlers(): handler = logging.StreamHandler(stream) handler.setFormatter(logging.Formatter(fmt)) + ensure_request_id_filter(handler) root_logger.addHandler(handler) - # When DEBUG is requested, silence the noisy module - if level == logging.DEBUG: + for handler in root_logger.handlers: + ensure_request_id_filter(handler) + + if resolved_level == logging.DEBUG: logging.getLogger("filelock").setLevel(logging.INFO) diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py new file mode 100644 index 0000000..05fa2fc --- /dev/null +++ b/tests/unit/test_logger.py @@ -0,0 +1,49 @@ +import logging + +from logger import ( + RequestIdFilter, + ensure_request_id_filter, + get_log_format, + reset_request_id, + set_request_id, +) + + +def test_log_format_includes_request_id_for_info(): + fmt = get_log_format(logging.INFO) + assert "%(request_id)s" in fmt + + +def test_log_format_includes_request_id_for_debug(): + fmt = get_log_format(logging.DEBUG) + assert "%(request_id)s" in fmt + + +def test_request_id_filter_injects_context_value(): + log_record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname=__file__, + lineno=1, + msg="hello", + args=(), + exc_info=None, + ) + token = set_request_id("job-abc") + + try: + request_id_filter = RequestIdFilter() + assert request_id_filter.filter(log_record) is True + assert log_record.request_id == "job-abc" + finally: + reset_request_id(token) + + +def test_ensure_request_id_filter_attaches_only_once(): + handler = logging.StreamHandler() + + ensure_request_id_filter(handler) + ensure_request_id_filter(handler) + + request_id_filters = [f for f in handler.filters if isinstance(f, RequestIdFilter)] + assert len(request_id_filters) == 1 From 85db2ddd5eb063eb70b7feb0918a698e253a23a4 Mon Sep 17 00:00:00 2001 From: jhcipar Date: Mon, 16 Mar 2026 11:17:58 -0400 Subject: [PATCH 2/4] chore: add formatter if root logger already exists --- src/logger.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/logger.py b/src/logger.py index e939f3b..9ab4f89 100644 --- a/src/logger.py +++ b/src/logger.py @@ -101,5 +101,14 @@ def setup_logging( for handler in root_logger.handlers: ensure_request_id_filter(handler) + current_formatter = handler.formatter + if current_formatter is None: + handler.setFormatter(logging.Formatter(fmt)) + continue + + current_format = getattr(current_formatter, "_fmt", "") + if "%(request_id)s" not in current_format: + handler.setFormatter(logging.Formatter(fmt)) + if resolved_level == logging.DEBUG: logging.getLogger("filelock").setLevel(logging.INFO) From 01c2b335a96e82a7690d520dc66333454e5abb63 Mon Sep 17 00:00:00 2001 From: jhcipar Date: Mon, 16 Mar 2026 11:22:28 -0400 Subject: [PATCH 3/4] chore: linting --- src/handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/handler.py b/src/handler.py index 506de7d..60cd92f 100644 --- a/src/handler.py +++ b/src/handler.py @@ -24,6 +24,7 @@ # Log after unpack so bundled runpod_flash is on sys.path logger.info(format_version_banner()) + def _extract_request_id(event: Dict[str, Any]) -> str: """Extract RunPod job id from event, with safe fallback.""" event_id = event.get("id") @@ -42,6 +43,7 @@ def _extract_request_id(event: Dict[str, Any]) -> str: return str(uuid.uuid4()) + def _is_deployed_mode() -> bool: """True when running as a Flash-deployed endpoint (not Live Serverless).""" return bool(os.getenv("FLASH_RESOURCE_NAME")) From 16d1759028d59e2e2b0274b40d2a4515fddc7223 Mon Sep 17 00:00:00 2001 From: jhcipar Date: Mon, 16 Mar 2026 13:12:23 -0400 Subject: [PATCH 4/4] chore: pr feedback --- src/handler.py | 10 ++++++++-- tests/unit/test_handler.py | 21 ++++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/handler.py b/src/handler.py index 60cd92f..098d896 100644 --- a/src/handler.py +++ b/src/handler.py @@ -8,7 +8,7 @@ from typing import Any, Dict, Optional from constants import MAX_IMPORT_RECOVERY_ATTEMPTS -from logger import setup_logging, set_request_id, reset_request_id +from logger import reset_request_id, set_request_id, setup_logging from unpack_volume import maybe_unpack from version import format_version_banner @@ -237,11 +237,17 @@ def _load_generated_handler() -> Optional[Any]: # Live Serverless mode: FunctionRequest handler is the only path. if _is_deployed_mode(): _generated = _load_generated_handler() + if _generated is None: + raise RuntimeError( + "FLASH_RESOURCE_NAME is set but no generated handler could be loaded. " + "Ensure the deployed artifact includes handler_.py and redeploy with 'flash deploy'." + ) + generated_handler = _generated async def handler(event: Dict[str, Any]) -> Dict[str, Any]: request_id_token = set_request_id(_extract_request_id(event)) try: - result = _generated(event) + result = generated_handler(event) if isinstance(result, Awaitable): return await result return result diff --git a/tests/unit/test_handler.py b/tests/unit/test_handler.py index 4562d55..355de32 100644 --- a/tests/unit/test_handler.py +++ b/tests/unit/test_handler.py @@ -13,7 +13,7 @@ os.environ.pop("FLASH_RESOURCE_NAME", None) sys.modules.pop("handler", None) -from handler import handler, _load_generated_handler # noqa: E402 +from handler import handler, _extract_request_id, _load_generated_handler # noqa: E402 from runpod_flash.protos.remote_execution import FunctionResponse # noqa: E402 @@ -154,6 +154,25 @@ async def test_handler_class_execution(self): assert "instance_info" in result +class TestExtractRequestId: + """Test cases for _extract_request_id event parsing.""" + + def test_prefers_top_level_id(self): + event = {"id": "job-top", "job_id": "job-fallback", "job": {"id": "job-nested"}} + + assert _extract_request_id(event) == "job-top" + + def test_uses_job_id_when_id_missing(self): + event = {"job_id": "job-secondary", "job": {"id": "job-nested"}} + + assert _extract_request_id(event) == "job-secondary" + + def test_uses_nested_job_id_when_top_level_ids_missing(self): + event = {"job": {"id": "job-nested"}} + + assert _extract_request_id(event) == "job-nested" + + class TestLoadGeneratedHandler: """Test cases for _load_generated_handler delegation logic."""