diff --git a/patchwork/cli.py b/patchwork/cli.py index 4955c76..8847ca2 100644 --- a/patchwork/cli.py +++ b/patchwork/cli.py @@ -1,10 +1,16 @@ +import argparse import asyncio +import json +import logging +from collections.abc import AsyncIterable from dotenv import load_dotenv +from pydantic_ai import FunctionToolCallEvent, RunContext from rich.console import Console from patchwork.agent import agent from patchwork.deps import PatchworkDeps +from patchwork.logging_config import setup_logging from patchwork.midi import MidiConnection from patchwork.patch_library import PatchLibrary from patchwork.synth_definitions import load_synth_definitions @@ -12,7 +18,28 @@ console = Console() -async def main(): +def _make_event_handler(verbose: bool, logger: logging.Logger): + """Return an event_stream_handler that logs tool calls.""" + + async def handle_events(ctx: RunContext[PatchworkDeps], events: AsyncIterable) -> None: + async for event in events: + if isinstance(event, FunctionToolCallEvent): + tool_name = event.part.tool_name + logger.info("tool call: %s", tool_name) + + if verbose: + try: + args = event.part.args_as_dict() + except Exception: + args = event.part.args + logger.debug("tool args: %s %s", tool_name, json.dumps(args, default=str)) + console.print(f"[dim]⚙ {tool_name}[/dim]") + + return handle_events + + +async def main(verbose: bool = False): + logger = setup_logging(verbose=verbose) midi = MidiConnection() synths = load_synth_definitions() @@ -27,6 +54,8 @@ async def main(): console.print("[dim]no synth definitions found in synths/[/dim]\n") message_history = [] + event_handler = _make_event_handler(verbose, logger) + try: while True: try: @@ -44,7 +73,10 @@ async def main(): try: async with agent.run_stream( - user_input, message_history=message_history, deps=deps + user_input, + message_history=message_history, + deps=deps, + event_stream_handler=event_handler, ) as result: async for chunk in result.stream_text(delta=True): console.print(chunk, end="", markup=False, highlight=False) @@ -52,6 +84,7 @@ async def main(): message_history = result.all_messages() except Exception as e: + logger.exception("Error during agent run") console.print(f"\n[bold red]error:[/bold red] {e}") finally: midi.close() @@ -59,7 +92,12 @@ async def main(): def main_cli(): load_dotenv() - asyncio.run(main()) + + parser = argparse.ArgumentParser(description="patchwork — synth research agent") + parser.add_argument("-v", "--verbose", action="store_true", help="enable verbose logging") + args = parser.parse_args() + + asyncio.run(main(verbose=args.verbose)) if __name__ == "__main__": diff --git a/patchwork/logging_config.py b/patchwork/logging_config.py new file mode 100644 index 0000000..0510676 --- /dev/null +++ b/patchwork/logging_config.py @@ -0,0 +1,25 @@ +import logging +import sys + + +def setup_logging(verbose: bool = False) -> logging.Logger: + """Configure the patchwork logger. Idempotent — safe to call multiple times. + + Always updates the handler level to match the current ``verbose`` setting, + but never adds duplicate handlers. + """ + logger = logging.getLogger("patchwork") + level = logging.DEBUG if verbose else logging.INFO + + if logger.handlers: + logger.handlers[0].setLevel(level) + return logger + + handler = logging.StreamHandler(sys.stderr) + handler.setLevel(level) + handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")) + + logger.setLevel(logging.DEBUG) + logger.addHandler(handler) + + return logger diff --git a/tests/test_logging_config.py b/tests/test_logging_config.py new file mode 100644 index 0000000..08d1931 --- /dev/null +++ b/tests/test_logging_config.py @@ -0,0 +1,44 @@ +import logging + +from patchwork.logging_config import setup_logging + + +class TestSetupLogging: + def setup_method(self): + """Remove any existing patchwork logger handlers before each test.""" + logger = logging.getLogger("patchwork") + logger.handlers.clear() + + def test_returns_logger(self): + logger = setup_logging() + assert isinstance(logger, logging.Logger) + assert logger.name == "patchwork" + + def test_adds_stderr_handler(self): + logger = setup_logging() + assert len(logger.handlers) == 1 + assert isinstance(logger.handlers[0], logging.StreamHandler) + + def test_default_handler_level_is_info(self): + logger = setup_logging(verbose=False) + assert logger.handlers[0].level == logging.INFO + + def test_verbose_handler_level_is_debug(self): + logger = setup_logging(verbose=True) + assert logger.handlers[0].level == logging.DEBUG + + def test_idempotent_no_duplicate_handlers(self): + logger = setup_logging() + setup_logging() + assert len(logger.handlers) == 1 + + def test_logger_level_is_debug(self): + logger = setup_logging() + assert logger.level == logging.DEBUG + + def test_subsequent_call_updates_handler_level(self): + logger = setup_logging(verbose=False) + assert logger.handlers[0].level == logging.INFO + setup_logging(verbose=True) + assert logger.handlers[0].level == logging.DEBUG + assert len(logger.handlers) == 1 diff --git a/tests/test_tool_logging.py b/tests/test_tool_logging.py new file mode 100644 index 0000000..9db451e --- /dev/null +++ b/tests/test_tool_logging.py @@ -0,0 +1,98 @@ +import logging +from unittest.mock import MagicMock + +import pytest + +from patchwork.cli import _make_event_handler + + +def _make_mock_tool_call_event(tool_name: str, args: dict | None = None): + """Create a mock FunctionToolCallEvent.""" + from pydantic_ai import FunctionToolCallEvent + from pydantic_ai.messages import ToolCallPart + + part = ToolCallPart(tool_name=tool_name, args=args or {}) + return FunctionToolCallEvent(part=part) + + +async def _to_async_iterable(items): + for item in items: + yield item + + +@pytest.fixture +def logger(): + log = logging.getLogger("patchwork.test_tool_logging") + log.handlers.clear() + log.setLevel(logging.DEBUG) + handler = logging.StreamHandler() + handler.setLevel(logging.DEBUG) + log.addHandler(handler) + return log + + +class TestToolCallEventHandler: + @pytest.mark.asyncio + async def test_logs_tool_name(self, logger, caplog): + handler = _make_event_handler(verbose=False, logger=logger) + event = _make_mock_tool_call_event("send_cc") + ctx = MagicMock() + + with caplog.at_level(logging.INFO, logger=logger.name): + await handler(ctx, _to_async_iterable([event])) + + assert any("send_cc" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_verbose_logs_args(self, logger, caplog): + handler = _make_event_handler(verbose=True, logger=logger) + args = {"synth_id": "minitaur", "param": "cutoff", "value": 64} + event = _make_mock_tool_call_event("send_cc", args) + ctx = MagicMock() + + with caplog.at_level(logging.DEBUG, logger=logger.name): + await handler(ctx, _to_async_iterable([event])) + + debug_msgs = [r.message for r in caplog.records if r.levelno == logging.DEBUG] + assert any("minitaur" in msg for msg in debug_msgs) + + @pytest.mark.asyncio + async def test_ignores_non_tool_events(self, logger, caplog): + handler = _make_event_handler(verbose=False, logger=logger) + ctx = MagicMock() + + # A non-FunctionToolCallEvent object + other_event = MagicMock() + other_event.event_kind = "part_start" + + with caplog.at_level(logging.INFO, logger=logger.name): + await handler(ctx, _to_async_iterable([other_event])) + + tool_records = [r for r in caplog.records if "tool call" in r.message] + assert len(tool_records) == 0 + + @pytest.mark.asyncio + async def test_handles_multiple_events(self, logger, caplog): + handler = _make_event_handler(verbose=False, logger=logger) + events = [ + _make_mock_tool_call_event("list_synths"), + _make_mock_tool_call_event("send_cc"), + ] + ctx = MagicMock() + + with caplog.at_level(logging.INFO, logger=logger.name): + await handler(ctx, _to_async_iterable(events)) + + tool_records = [r for r in caplog.records if "tool call" in r.message] + assert len(tool_records) == 2 + + @pytest.mark.asyncio + async def test_non_verbose_does_not_print_tool_indicator(self, logger, capsys): + handler = _make_event_handler(verbose=False, logger=logger) + event = _make_mock_tool_call_event("send_cc") + ctx = MagicMock() + + await handler(ctx, _to_async_iterable([event])) + + captured = capsys.readouterr() + assert "⚙" not in captured.out