Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import logging
import os
import sys
import uuid
from pathlib import Path
from collections.abc import Awaitable
from typing import Any, Dict, Optional

from constants import MAX_IMPORT_RECOVERY_ATTEMPTS
from logger import setup_logging
from logger import reset_request_id, set_request_id, setup_logging
from unpack_volume import maybe_unpack
from version import format_version_banner

Expand All @@ -23,6 +25,25 @@
logger.info(format_version_banner())


def _extract_request_id(event: Dict[str, Any]) -> str:
"""Extract RunPod job id from event, with safe fallback."""
event_id = event.get("id")
if isinstance(event_id, str) and event_id.strip():
return event_id

job_id = event.get("job_id")
if isinstance(job_id, str) and job_id.strip():
return job_id

job = event.get("job")
if isinstance(job, dict):
nested_job_id = job.get("id")
if isinstance(nested_job_id, str) and nested_job_id.strip():
return nested_job_id

return str(uuid.uuid4())


def _is_deployed_mode() -> bool:
"""True when running as a Flash-deployed endpoint (not Live Serverless)."""
return bool(os.getenv("FLASH_RESOURCE_NAME"))
Expand Down Expand Up @@ -215,14 +236,32 @@ def _load_generated_handler() -> Optional[Any]:
# Deployed mode: generated handler is mandatory, failures are fatal.
# Live Serverless mode: FunctionRequest handler is the only path.
if _is_deployed_mode():
handler = _load_generated_handler()
_generated = _load_generated_handler()
if _generated is None:
raise RuntimeError(
"FLASH_RESOURCE_NAME is set but no generated handler could be loaded. "
"Ensure the deployed artifact includes handler_<resource_name>.py and redeploy with 'flash deploy'."
)
generated_handler = _generated

async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
request_id_token = set_request_id(_extract_request_id(event))
try:
result = generated_handler(event)
if isinstance(result, Awaitable):
return await result
return result
finally:
reset_request_id(request_id_token)

else:
from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse
from remote_executor import RemoteExecutor

async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
"""RunPod serverless handler for Live Serverless (FunctionRequest protocol)."""
output: FunctionResponse
request_id_token = set_request_id(_extract_request_id(event))

try:
executor = RemoteExecutor()
Expand All @@ -234,6 +273,8 @@ async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
success=False,
error=f"Error in handler: {str(error)}",
)
finally:
reset_request_id(request_id_token)

return output.model_dump() # type: ignore[no-any-return]

Expand Down
3 changes: 2 additions & 1 deletion src/log_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections import deque
from typing import Optional, Deque, Callable

from logger import get_log_format
from logger import ensure_request_id_filter, get_log_format


class LogStreamer:
Expand Down Expand Up @@ -58,6 +58,7 @@ def start_streaming(
# Use same format as main logging
formatter = logging.Formatter(get_log_format(level))
self._handler.setFormatter(formatter)
ensure_request_id_filter(self._handler)

# Add to root logger
root_logger = logging.getLogger()
Expand Down
74 changes: 63 additions & 11 deletions src/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,47 @@
import logging
import os
import sys
from contextvars import ContextVar, Token
from typing import Union, Optional


_REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="-")


class RequestIdFilter(logging.Filter):
"""Inject request_id from context into each log record."""

def filter(self, record: logging.LogRecord) -> bool:
record.request_id = _REQUEST_ID.get()
return True


_REQUEST_ID_FILTER = RequestIdFilter()


def set_request_id(request_id: Optional[str]) -> Token[str]:
"""Set request id in log context and return reset token."""
if request_id:
normalized = request_id.strip() or "-"
else:
normalized = "-"
return _REQUEST_ID.set(normalized)


def reset_request_id(token: Token[str]) -> None:
"""Reset request id context with token from set_request_id."""
_REQUEST_ID.reset(token)


def get_request_id() -> str:
return _REQUEST_ID.get()


def ensure_request_id_filter(handler: logging.Handler) -> None:
if not any(isinstance(existing, RequestIdFilter) for existing in handler.filters):
handler.addFilter(_REQUEST_ID_FILTER)


def get_log_level() -> int:
"""Get log level from environment variable, defaulting to INFO."""
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
Expand All @@ -19,9 +57,12 @@ def get_log_level() -> int:
def get_log_format(level: int) -> str:
"""Get appropriate log format based on level, matching runpod-flash style."""
if level == logging.DEBUG:
return "%(asctime)s | %(levelname)-5s | %(name)s | %(filename)s:%(lineno)d | %(message)s"
return (
"%(asctime)s | %(levelname)-5s | %(request_id)s | "
"%(name)s | %(filename)s:%(lineno)d | %(message)s"
)
else:
return "%(asctime)s | %(levelname)-5s | %(message)s"
return "%(asctime)s | %(levelname)-5s | %(request_id)s | %(message)s"


def setup_logging(
Expand All @@ -38,25 +79,36 @@ def setup_logging(
stream: Output stream for logs
fmt: Custom format string (auto-selected based on level if None)
"""
# Determine log level
if level is None:
level = get_log_level()
resolved_level = get_log_level()
elif isinstance(level, str):
level = getattr(logging, level.upper(), logging.INFO)
resolved_level = getattr(logging, level.upper(), logging.INFO)
else:
resolved_level = level

# Determine format based on requested level
if fmt is None:
fmt = get_log_format(level)
fmt = get_log_format(resolved_level)

# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.setLevel(resolved_level)

if not root_logger.hasHandlers():
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter(fmt))
ensure_request_id_filter(handler)
root_logger.addHandler(handler)

# When DEBUG is requested, silence the noisy module
if level == logging.DEBUG:
for handler in root_logger.handlers:
ensure_request_id_filter(handler)

current_formatter = handler.formatter
if current_formatter is None:
handler.setFormatter(logging.Formatter(fmt))
continue

current_format = getattr(current_formatter, "_fmt", "")
if "%(request_id)s" not in current_format:
handler.setFormatter(logging.Formatter(fmt))

if resolved_level == logging.DEBUG:
logging.getLogger("filelock").setLevel(logging.INFO)
21 changes: 20 additions & 1 deletion tests/unit/test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
os.environ.pop("FLASH_RESOURCE_NAME", None)
sys.modules.pop("handler", None)

from handler import handler, _load_generated_handler # noqa: E402
from handler import handler, _extract_request_id, _load_generated_handler # noqa: E402
from runpod_flash.protos.remote_execution import FunctionResponse # noqa: E402


Expand Down Expand Up @@ -154,6 +154,25 @@ async def test_handler_class_execution(self):
assert "instance_info" in result


class TestExtractRequestId:
"""Test cases for _extract_request_id event parsing."""

def test_prefers_top_level_id(self):
event = {"id": "job-top", "job_id": "job-fallback", "job": {"id": "job-nested"}}

assert _extract_request_id(event) == "job-top"

def test_uses_job_id_when_id_missing(self):
event = {"job_id": "job-secondary", "job": {"id": "job-nested"}}

assert _extract_request_id(event) == "job-secondary"

def test_uses_nested_job_id_when_top_level_ids_missing(self):
event = {"job": {"id": "job-nested"}}

assert _extract_request_id(event) == "job-nested"


class TestLoadGeneratedHandler:
"""Test cases for _load_generated_handler delegation logic."""

Expand Down
49 changes: 49 additions & 0 deletions tests/unit/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging

from logger import (
RequestIdFilter,
ensure_request_id_filter,
get_log_format,
reset_request_id,
set_request_id,
)


def test_log_format_includes_request_id_for_info():
fmt = get_log_format(logging.INFO)
assert "%(request_id)s" in fmt


def test_log_format_includes_request_id_for_debug():
fmt = get_log_format(logging.DEBUG)
assert "%(request_id)s" in fmt


def test_request_id_filter_injects_context_value():
log_record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname=__file__,
lineno=1,
msg="hello",
args=(),
exc_info=None,
)
token = set_request_id("job-abc")

try:
request_id_filter = RequestIdFilter()
assert request_id_filter.filter(log_record) is True
assert log_record.request_id == "job-abc"
finally:
reset_request_id(token)


def test_ensure_request_id_filter_attaches_only_once():
handler = logging.StreamHandler()

ensure_request_id_filter(handler)
ensure_request_id_filter(handler)

request_id_filters = [f for f in handler.filters if isinstance(f, RequestIdFilter)]
assert len(request_id_filters) == 1
Loading