From f2658114dda64c3a37266564488c87cfc41ad0ce Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Thu, 26 Mar 2026 13:08:05 +0100 Subject: [PATCH 1/2] Added e2e tests for tool choices in responses endpoint --- src/app/endpoints/responses.py | 4 +- src/app/main.py | 2 +- src/utils/responses.py | 110 +++- tests/e2e/features/responses.feature | 146 ++++++ .../e2e/features/steps/llm_query_response.py | 88 ++++ tests/unit/utils/test_responses.py | 491 +++++++++++++++++- 6 files changed, 816 insertions(+), 25 deletions(-) diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 1f8bb2be4..e5fc455ed 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -234,7 +234,7 @@ async def responses_endpoint_handler( request.headers, ) - # Build RAG context from Inline RAG sources + #Build RAG context from Inline RAG sources inline_rag_context = await build_rag_context( client, moderation_result.decision, @@ -242,6 +242,7 @@ async def responses_endpoint_handler( vector_store_ids, responses_request.solr, ) + if moderation_result.decision == "passed": responses_request.input = append_inline_rag_context_to_responses_input( responses_request.input, inline_rag_context.context_text @@ -662,6 +663,7 @@ async def handle_non_streaming_response( ) else: try: + print("API Params: ", api_params.model_dump(exclude_none=True)) api_response = cast( OpenAIResponseObject, await client.responses.create( diff --git a/src/app/main.py b/src/app/main.py index d3e0b0c18..76eefc964 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -229,5 +229,5 @@ async def send_wrapper(message: Message) -> None: # RestApiMetricsMiddleware (registered last) is outermost. This ensures metrics # always observe a status code — including 500s synthesised by the exception # middleware — rather than seeing a raw exception with no response. 
-app.add_middleware(GlobalExceptionMiddleware) +#app.add_middleware(GlobalExceptionMiddleware) app.add_middleware(RestApiMetricsMiddleware) diff --git a/src/utils/responses.py b/src/utils/responses.py index 3e68a26db..7c0357c68 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -26,6 +26,9 @@ from llama_stack_api.openai_responses import ( OpenAIResponseInputToolChoice as ToolChoice, ) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolChoiceAllowedTools as AllowedTools, +) from llama_stack_api.openai_responses import ( OpenAIResponseInputToolChoiceMode as ToolChoiceMode, ) @@ -417,6 +420,55 @@ def extract_vector_store_ids_from_tools( return vector_store_ids +def _tool_matches_allowed_entry(tool: InputTool, entry: dict[str, str]) -> bool: + """Return True if the tool satisfies every key in the allowlist entry. + + ``OpenAIResponseInputToolChoiceAllowedTools.tools`` entries use string keys + and values (e.g. ``type``, ``server_label``, ``name``); each must match the + corresponding attribute on the tool. + + Parameters: + tool: A configured input tool. + entry: One allowlist entry from ``allowed_tools.tools``. + + Returns: + True if all entry keys match the tool. + """ + for key, value in entry.items(): + if not hasattr(tool, key): + return False + attr = getattr(tool, key) + if attr is None: + return False + if attr != value and str(attr) != value: + return False + return True + + +def filter_tools_by_allowed_entries( + tools: list[InputTool], + allowed_entries: list[dict[str, str]], +) -> list[InputTool]: + """Keep tools that match at least one allowlist entry. + + If ``allowed_entries`` is empty, no tools are kept (strict allowlist). + + Parameters: + tools: Tools to filter (typically after translation / preparation). + allowed_entries: Entries from ``OpenAIResponseInputToolChoiceAllowedTools.tools``. + + Returns: + A sublist of ``tools`` matching the allowlist. 
+ """ + if not allowed_entries: + return [] + return [ + t + for t in tools + if any(_tool_matches_allowed_entry(t, e) for e in allowed_entries) + ] + + def resolve_vector_store_ids( vector_store_ids: list[str], byok_rags: list[ByokRag] ) -> list[str]: @@ -1332,10 +1384,19 @@ async def resolve_tool_choice( ) -> tuple[Optional[list[InputTool]], Optional[ToolChoice], Optional[list[str]]]: """Resolve tools and tool_choice for the Responses API. - If the request includes tools, uses them as-is and derives vector_store_ids - from tool configs; otherwise loads tools via prepare_tools (using all - configured vector stores) and honors tool_choice "none" via the no_tools - flag. When no tools end up configured, tool_choice is cleared to None. + If ``tool_choice`` is ``none``, always returns ``(None, None, None)`` — no + tools are sent to Llama Stack, even when the request included explicit + ``tools`` (e.g. file_search). + + If ``tool_choice`` is ``allowed_tools``, it is rewritten for downstream + services: tools are filtered to those matching the allowlist entries, and + ``tool_choice`` becomes ``auto`` or ``required`` per the allowlist ``mode``. + + If the request includes tools and tool_choice is not ``none``, uses them + (after allowlist filtering) and derives vector_store_ids from the prepared + tools; otherwise loads tools via prepare_tools (using all configured vector + stores), then applies allowlist filtering when present. When no tools end + up configured, tool_choice is cleared to None. Args: tools: Tools from the request, or None to use LCORE-configured tools. @@ -1349,35 +1410,46 @@ async def resolve_tool_choice( prepared_tools is the list of tools to use, or None if none configured; prepared_tool_choice is the resolved tool choice, or None when there are no tools; vector_store_ids is extracted from tools (in user-facing format) - when provided, otherwise None. + when provided, otherwise None (also None when tool_choice is ``none``). 
""" + if isinstance(tool_choice, ToolChoiceMode) and tool_choice == ToolChoiceMode.none: + return None, None, None + + allowed_filters: Optional[list[dict[str, str]]] = None + if isinstance(tool_choice, AllowedTools): + allowed_filters = tool_choice.tools + tool_choice = ToolChoiceMode(tool_choice.mode) + prepared_tools: Optional[list[InputTool]] = None - client = AsyncLlamaStackClientHolder().get_client() if tools: # explicitly specified in request - # Per-request override of vector stores (user-facing rag_ids) - vector_store_ids = extract_vector_store_ids_from_tools(tools) - # Translate user-facing rag_ids to llama-stack vector_store_ids in each file_search tool byok_rags = configuration.configuration.byok_rag prepared_tools = translate_tools_vector_store_ids(tools, byok_rags) + if allowed_filters is not None: + prepared_tools = filter_tools_by_allowed_entries( + prepared_tools, allowed_filters + ) + if not prepared_tools: + return None, None, None + vector_store_ids_list = extract_vector_store_ids_from_tools(prepared_tools) + vector_store_ids = vector_store_ids_list if vector_store_ids_list else None prepared_tool_choice = tool_choice or ToolChoiceMode.auto else: - # Vector stores were not overwritten in request, use all configured vector stores vector_store_ids = None - # Get all tools configured in LCORE (returns None or non-empty list) - no_tools = ( - isinstance(tool_choice, ToolChoiceMode) - and tool_choice == ToolChoiceMode.none - ) - # Vector stores are prepared in llama-stack format + client = AsyncLlamaStackClientHolder().get_client() prepared_tools = await prepare_tools( client=client, - vector_store_ids=vector_store_ids, # allow all configured vector stores - no_tools=no_tools, + vector_store_ids=vector_store_ids, + no_tools=False, token=token, mcp_headers=mcp_headers, request_headers=request_headers, ) - # If there are no tools, tool_choice cannot be set at all - LLS implicit behavior + if allowed_filters is not None and prepared_tools: + 
prepared_tools = filter_tools_by_allowed_entries( + prepared_tools, allowed_filters + ) + if not prepared_tools: + prepared_tools = None prepared_tool_choice = tool_choice if prepared_tools else None return prepared_tools, prepared_tool_choice, vector_store_ids diff --git a/tests/e2e/features/responses.feature b/tests/e2e/features/responses.feature index 463fe2b9c..da23dce42 100644 --- a/tests/e2e/features/responses.feature +++ b/tests/e2e/features/responses.feature @@ -427,3 +427,149 @@ Feature: Responses endpoint API tests """ Then The status code of the response is 503 And The body of the response contains Unable to connect to Llama Stack + + +Scenario: Check if responses endpoint with tool_choice none answers knowledge question without file search usage + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. You MUST use the file_search tool to answer. 
Answer in lowercase.", + "tool_choice": "none" + } + """ + Then The status code of the response is 200 + And The responses output should not include any tool invocation item types + And The token metrics should have increased + + Scenario: Check if responses endpoint with tool_choice auto answers a knowledge question using file search + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. You MUST use the file_search tool to answer. Answer in lowercase.", + "tool_choice": "auto" + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The responses output_text should contain following fragments + | Fragments in LLM response | + | great work | + And The token metrics should have increased + + Scenario: Check if responses endpoint with tool_choice required still invokes document search for a basic question + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "Hello World!", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "tool_choice": "required" + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The token metrics should have increased + + Scenario: Check if responses endpoint with file search as the chosen tool answers using file search + Given The system is in default state + 
And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. You MUST use the file_search tool to answer. Answer in lowercase.", + "tool_choice": {"type": "file_search"} + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The responses output_text should contain following fragments + | Fragments in LLM response | + | great work | + And The token metrics should have increased + + Scenario: Check if responses endpoint with allowed tools in automatic mode answers knowledge question using file search + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. You MUST use the file_search tool to answer. 
Answer in lowercase.", + "tool_choice": { + "type": "allowed_tools", + "mode": "auto", + "tools": [{"type": "file_search"}] + } + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The responses output_text should contain following fragments + | Fragments in LLM response | + | great work | + And The token metrics should have increased + + Scenario: Check if responses endpoint with allowed tools in required mode invokes file search for a basic question + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "Hello world!", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "tool_choice": { + "type": "allowed_tools", + "mode": "required", + "tools": [{"type": "file_search"}] + } + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The token metrics should have increased + + Scenario: Allowed tools auto mode with only MCP in allowlist does not use file search for knowledge question + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. 
Answer in lowercase.", + "tool_choice": { + "type": "allowed_tools", + "mode": "auto", + "tools": [{"type": "mcp"}] + } + } + """ + Then The status code of the response is 200 + And The responses output should not include an item with type "file_search_call" + And The token metrics should have increased \ No newline at end of file diff --git a/tests/e2e/features/steps/llm_query_response.py b/tests/e2e/features/steps/llm_query_response.py index 01f90a148..7ccd7f245 100644 --- a/tests/e2e/features/steps/llm_query_response.py +++ b/tests/e2e/features/steps/llm_query_response.py @@ -2,6 +2,7 @@ import json import os +from typing import Any, cast import requests from behave import step, then # pyright: ignore[reportAttributeAccessIssue] @@ -12,6 +13,73 @@ # Longer timeout for Prow/OpenShift with CPU-based vLLM DEFAULT_LLM_TIMEOUT = 180 if os.getenv("RUNNING_PROW") else 60 +# Responses API ``output`` item types that indicate tool listing or invocation. +_RESPONSE_TOOL_OUTPUT_ITEM_TYPES = frozenset( + { + "file_search_call", + "mcp_call", + "mcp_list_tools", + "function_call", + "web_search_call", + } +) + + +def _collect_output_item_types(response_body: dict[str, Any]) -> list[str]: + """Collect ``type`` from each top-level ``output`` item in a Responses API JSON body.""" + output = cast(list[dict[str, Any]], response_body["output"]) + return [item["type"] for item in output] + + +@then("The responses output should not include any tool invocation item types") +def responses_output_should_not_include_tool_items(context: Context) -> None: + """Assert no tool-related items appear in the Responses JSON ``output`` array.""" + assert context.response is not None, "Request needs to be performed first" + response_json = cast(dict[str, Any], context.response.json()) + types_found = _collect_output_item_types(response_json) + bad = [t for t in types_found if t in _RESPONSE_TOOL_OUTPUT_ITEM_TYPES] + assert not bad, ( + "Expected no tool-related output items, but found types " + 
f"{bad!r} among all output types {types_found!r}" + ) + + +@then('The responses output should include an item with type "{item_type}"') +def responses_output_should_include_item_type(context: Context, item_type: str) -> None: + """Assert at least one ``output`` item has the given ``type``.""" + assert context.response is not None, "Request needs to be performed first" + response_json = cast(dict[str, Any], context.response.json()) + types_found = _collect_output_item_types(response_json) + assert item_type in types_found, ( + f"Expected output item type {item_type!r} not found; " + f"had types {types_found!r}" + ) + + +@then('The responses output should not include an item with type "{item_type}"') +def responses_output_should_not_include_item_type(context: Context, item_type: str) -> None: + """Assert no ``output`` item has the given ``type``.""" + assert context.response is not None, "Request needs to be performed first" + response_json = cast(dict[str, Any], context.response.json()) + types_found = _collect_output_item_types(response_json) + assert item_type not in types_found, ( + f"Expected output item type {item_type!r} to be absent; " + f"but found types {types_found!r}" + ) + + +@then("The responses output should include an item with one of these types") +def responses_output_should_include_one_of_types(context: Context) -> None: + """Assert at least one output item type matches a row in the scenario table.""" + assert context.response is not None, "Request needs to be performed first" + assert context.table is not None, "Table with column 'item type' is required" + allowed = [row["item type"].strip() for row in context.table] + response_json = cast(dict[str, Any], context.response.json()) + types_found = _collect_output_item_types(response_json) + assert any( + a in types_found for a in allowed + ), f"Expected at least one of {allowed!r} in output types {types_found!r}" + @step("I wait for the response to be completed") def 
wait_for_complete_response(context: Context) -> None: @@ -161,6 +229,26 @@ def check_referenced_documents_present(context: Context) -> None: assert ( len(response_json["referenced_documents"]) > 0 ), "referenced_documents is empty — no documents were referenced" +@then("The responses output_text should contain following fragments") +def check_fragments_in_responses_output_text(context: Context) -> None: + """Check that fragments from the scenario table appear in JSON ``output_text``. + + Used for POST ``/v1/responses`` (query endpoint uses the ``response`` field). + """ + assert context.response is not None, "Request needs to be performed first" + response_json = context.response.json() + assert ( + "output_text" in response_json + ), f"Expected 'output_text' in JSON body, got keys: {list(response_json.keys())}" + output_text = response_json["output_text"] + + assert context.table is not None, "Fragments are not specified in table" + + for fragment in context.table: + expected = fragment["Fragments in LLM response"] + assert ( + expected in output_text + ), f"Fragment '{expected}' not found in output_text: '{output_text}'" @then("The response should contain following fragments") diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index 442dd123e..59fe8d4c3 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -4,16 +4,31 @@ import json from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, cast import pytest from fastapi import HTTPException +from llama_stack_api.openai_responses import ( + OpenAIResponseInputTool as InputTool, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolChoiceAllowedTools, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolChoiceMode as ToolChoiceMode, +) from llama_stack_api.openai_responses import ( OpenAIResponseInputToolFileSearch as InputToolFileSearch, ) +from 
llama_stack_api.openai_responses import ( + OpenAIResponseInputToolFunction as InputToolFunction, +) from llama_stack_api.openai_responses import ( OpenAIResponseInputToolMCP as InputToolMCP, ) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolWebSearch as InputToolWebSearch, +) from llama_stack_api.openai_responses import ( OpenAIResponseMCPApprovalRequest as MCPApprovalRequest, ) @@ -54,6 +69,7 @@ extract_text_from_response_items, extract_token_usage, extract_vector_store_ids_from_tools, + filter_tools_by_allowed_entries, get_mcp_tools, get_rag_tools, get_topic_summary, @@ -62,6 +78,7 @@ parse_referenced_documents, prepare_responses_params, prepare_tools, + resolve_tool_choice, resolve_vector_store_ids, ) @@ -861,6 +878,472 @@ async def test_get_topic_summary_api_error(self, mocker: MockerFixture) -> None: await get_topic_summary("test question", mock_client, "model1") +class TestResolveToolChoice: + """Tests for resolve_tool_choice (ToolChoiceMode, AllowedTools, explicit/implicit tools).""" + + @staticmethod + def _passthrough_translate(mocker: MockerFixture) -> None: + mocker.patch( + "utils.responses.translate_tools_vector_store_ids", + side_effect=lambda t, _: t, + ) + + @pytest.mark.asyncio + async def test_tool_choice_none_clears_explicit_tools( + self, mocker: MockerFixture + ) -> None: + """ToolChoiceMode.none yields no tools even when the request listed tools.""" + mock_get_client = mocker.patch( + "utils.responses.AsyncLlamaStackClientHolder.get_client" + ) + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ) + prepared_tools, prepared_choice, vector_store_ids = await resolve_tool_choice( + tools, + ToolChoiceMode.none, + "token", + ) + assert prepared_tools is None + assert prepared_choice is None + assert vector_store_ids is None + mock_get_client.assert_not_called() + + @pytest.mark.asyncio + async def test_tool_choice_none_without_explicit_tools_skips_client( + self, mocker: MockerFixture + 
) -> None: + """ToolChoiceMode.none returns early without calling prepare_tools path.""" + mock_get_client = mocker.patch( + "utils.responses.AsyncLlamaStackClientHolder.get_client" + ) + mock_prepare = mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + ) + out = await resolve_tool_choice(None, ToolChoiceMode.none, "token") + assert out == (None, None, None) + mock_get_client.assert_not_called() + mock_prepare.assert_not_called() + + @pytest.mark.asyncio + async def test_explicit_tools_tool_choice_auto(self, mocker: MockerFixture) -> None: + """Explicit tools with ToolChoiceMode.auto pass through after BYOK translation.""" + self._passthrough_translate(mocker) + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ) + prepared, choice, vs_ids = await resolve_tool_choice( + tools, + ToolChoiceMode.auto, + "token", + ) + assert prepared is not None and prepared[0].type == "file_search" + assert choice == ToolChoiceMode.auto + assert vs_ids == ["vs1"] + + @pytest.mark.asyncio + async def test_explicit_tools_tool_choice_required( + self, mocker: MockerFixture + ) -> None: + """Explicit tools with ToolChoiceMode.required are preserved.""" + self._passthrough_translate(mocker) + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ) + prepared, choice, _vs = await resolve_tool_choice( + tools, + ToolChoiceMode.required, + "token", + ) + assert prepared is not None + assert choice == ToolChoiceMode.required + + @pytest.mark.asyncio + async def test_explicit_tools_omitted_tool_choice_defaults_to_auto( + self, mocker: MockerFixture + ) -> None: + """When tool_choice is None and tools are explicit, default is auto.""" + self._passthrough_translate(mocker) + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ) + prepared, choice, _vs = await resolve_tool_choice(tools, None, "token") + assert prepared is not None + assert choice == 
ToolChoiceMode.auto + + @pytest.mark.asyncio + async def test_allowed_tools_required_filters_explicit_to_file_search( + self, mocker: MockerFixture + ) -> None: + """AllowedTools mode=required filters to allowlist; vector_store_ids from result.""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[{"type": "file_search"}], + ) + tools = cast( + list[InputTool], + [ + InputToolFileSearch(vector_store_ids=["vs1"]), + InputToolMCP(server_label="s1", server_url="http://example.com"), + ], + ) + prepared, choice, vs_ids = await resolve_tool_choice(tools, allowed, "token") + assert prepared is not None + assert len(prepared) == 1 + assert prepared[0].type == "file_search" + assert choice == ToolChoiceMode.required + assert vs_ids == ["vs1"] + + @pytest.mark.asyncio + async def test_allowed_tools_auto_explicit_same_filter( + self, mocker: MockerFixture + ) -> None: + """AllowedTools mode=auto maps to ToolChoiceMode.auto after filtering.""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "mcp", "server_label": "keep"}], + ) + tools = cast( + list[InputTool], + [ + InputToolMCP(server_label="keep", server_url="http://a"), + InputToolMCP(server_label="drop", server_url="http://b"), + ], + ) + prepared, choice, vs_ids = await resolve_tool_choice(tools, allowed, "token") + assert prepared is not None and len(prepared) == 1 + assert prepared[0].type == "mcp" + assert getattr(prepared[0], "server_label") == "keep" + assert choice == ToolChoiceMode.auto + assert vs_ids is None + + @pytest.mark.asyncio + async def test_allowed_tools_no_match_returns_none_tuple( + self, mocker: MockerFixture + ) -> None: + """When allowlist excludes all tools, return (None, None, None).""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "mcp", "server_label": "other"}], + ) + 
tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ) + assert await resolve_tool_choice(tools, allowed, "token") == (None, None, None) + + @pytest.mark.asyncio + async def test_allowed_tools_multiple_allowlist_entries_or_semantics( + self, mocker: MockerFixture + ) -> None: + """Multiple allowlist rows match as OR: keep tools matching any entry.""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[ + {"type": "file_search"}, + {"type": "mcp", "server_label": "b"}, + ], + ) + tools = cast( + list[InputTool], + [ + InputToolFileSearch(vector_store_ids=["vs"]), + InputToolMCP(server_label="a", server_url="http://a"), + InputToolMCP(server_label="b", server_url="http://b"), + ], + ) + prepared, _choice, _vs = await resolve_tool_choice(tools, allowed, "token") + assert prepared is not None + assert len(prepared) == 2 + types = {t.type for t in prepared} + assert types == {"file_search", "mcp"} + + @pytest.mark.asyncio + async def test_allowed_tools_function_tool_filtered_by_type_and_name( + self, mocker: MockerFixture + ) -> None: + """Allowlist can target function tools by type and name.""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[{"type": "function", "name": "keep_fn"}], + ) + tools = cast( + list[InputTool], + [ + InputToolFunction(name="keep_fn", parameters={}), + InputToolFunction(name="drop_fn", parameters={}), + ], + ) + prepared, choice, vs_ids = await resolve_tool_choice(tools, allowed, "token") + assert prepared is not None and len(prepared) == 1 + assert getattr(prepared[0], "name") == "keep_fn" + assert choice == ToolChoiceMode.required + assert vs_ids is None + + @pytest.mark.asyncio + async def test_allowed_tools_web_search_must_match_type_literal( + self, mocker: MockerFixture + ) -> None: + """Web search variants only match if allowlist type matches tool.type exactly.""" + 
self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "web_search"}], + ) + tools = cast( + list[InputTool], + [ + InputToolWebSearch(type="web_search"), + InputToolWebSearch(type="web_search_preview"), + ], + ) + prepared, choice, _vs = await resolve_tool_choice(tools, allowed, "token") + assert prepared is not None and len(prepared) == 1 + assert prepared[0].type == "web_search" + assert choice is not None and choice == ToolChoiceMode.auto + + @pytest.mark.asyncio + async def test_implicit_tool_choice_auto_calls_prepare_tools( + self, mocker: MockerFixture + ) -> None: + """No explicit tools: ToolChoiceMode.auto uses prepare_tools output.""" + fs = InputToolFileSearch(vector_store_ids=["vs1"]) + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=[fs], + ) + mock_get_client = mocker.patch( + "utils.responses.AsyncLlamaStackClientHolder.get_client" + ) + prepared, choice, vs_ids = await resolve_tool_choice( + None, + ToolChoiceMode.auto, + "token", + ) + assert prepared == [fs] + assert choice == ToolChoiceMode.auto + assert vs_ids is None + mock_get_client.assert_called_once() + + @pytest.mark.asyncio + async def test_implicit_tool_choice_required_calls_prepare_tools( + self, mocker: MockerFixture + ) -> None: + """No explicit tools: ToolChoiceMode.required passes through when tools exist.""" + fs = InputToolFileSearch(vector_store_ids=["vs1"]) + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=[fs], + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + prepared, choice, _vs = await resolve_tool_choice( + None, + ToolChoiceMode.required, + "token", + ) + assert prepared == [fs] + assert choice == ToolChoiceMode.required + + @pytest.mark.asyncio + async def test_implicit_prepare_tools_returns_none_clears_tool_choice( + self, mocker: MockerFixture + ) -> None: + """When 
prepare_tools returns None, tool_choice is cleared.""" + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=None, + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + prepared, choice, vs_ids = await resolve_tool_choice( + None, + ToolChoiceMode.auto, + "token", + ) + assert prepared is None + assert choice is None + assert vs_ids is None + + @pytest.mark.asyncio + async def test_allowed_tools_applies_after_prepare_tools( + self, mocker: MockerFixture + ) -> None: + """No explicit tools: AllowedTools filters prepare_tools output.""" + fs = InputToolFileSearch(vector_store_ids=["vs1"]) + mcp = InputToolMCP(server_label="s1", server_url="http://x") + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=[fs, mcp], + ) + mock_get_client = mocker.patch( + "utils.responses.AsyncLlamaStackClientHolder.get_client" + ) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "mcp", "server_label": "s1"}], + ) + prepared, choice, vs_ids = await resolve_tool_choice( + None, + allowed, + "token", + ) + assert prepared is not None + assert len(prepared) == 1 + assert prepared[0].type == "mcp" + assert choice == ToolChoiceMode.auto + assert vs_ids is None + mock_get_client.assert_called_once() + + @pytest.mark.asyncio + async def test_allowed_tools_implicit_required_mode_after_prepare( + self, mocker: MockerFixture + ) -> None: + """AllowedTools with mode=required after implicit prepare_tools.""" + mcp = InputToolMCP(server_label="s1", server_url="http://x") + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=[mcp], + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[{"type": "mcp"}], + ) + _prepared, choice, _vs = await resolve_tool_choice(None, allowed, "token") + assert choice == 
ToolChoiceMode.required + + +class TestFilterToolsByAllowedEntries: + """Tests for filter_tools_by_allowed_entries (per-type matching).""" + + def test_file_search_type_only_keeps_all_file_search(self) -> None: + """One entry ``{type: file_search}`` keeps every file_search tool.""" + tools = cast( + list[InputTool], + [ + InputToolFileSearch(vector_store_ids=["a"]), + InputToolFileSearch(vector_store_ids=["b"]), + ], + ) + out = filter_tools_by_allowed_entries(tools, [{"type": "file_search"}]) + assert len(out) == 2 + + def test_empty_allowlist_keeps_nothing(self) -> None: + """Empty allowlist removes every tool.""" + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["a"])], + ) + assert filter_tools_by_allowed_entries(tools, []) == [] + + def test_mcp_type_only_matches_all_mcp_tools(self) -> None: + """``{type: mcp}`` keeps every MCP tool regardless of server_label.""" + tools = cast( + list[InputTool], + [ + InputToolMCP(server_label="a", server_url="http://a"), + InputToolMCP(server_label="b", server_url="http://b"), + ], + ) + out = filter_tools_by_allowed_entries(tools, [{"type": "mcp"}]) + assert len(out) == 2 + + def test_mcp_type_and_server_label_specific(self) -> None: + """Restrict to one MCP server using type + server_label.""" + tools = cast( + list[InputTool], + [ + InputToolMCP(server_label="keep", server_url="http://a"), + InputToolMCP(server_label="other", server_url="http://b"), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "keep"}], + ) + assert len(out) == 1 + assert getattr(out[0], "server_label") == "keep" + + def test_function_type_and_name(self) -> None: + """Function tools match on type and name.""" + tools = cast( + list[InputTool], + [ + InputToolFunction(name="fn_a", parameters={}), + InputToolFunction(name="fn_b", parameters={}), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "function", "name": "fn_b"}], + ) + assert len(out) == 1 + assert 
getattr(out[0], "name") == "fn_b" + + def test_web_search_type_literal_must_match(self) -> None: + """web_search vs web_search_preview require distinct allowlist entries.""" + tools = cast( + list[InputTool], + [ + InputToolWebSearch(type="web_search"), + InputToolWebSearch(type="web_search_preview"), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "web_search_preview"}], + ) + assert len(out) == 1 + assert out[0].type == "web_search_preview" + + def test_multiple_allowlist_entries_or_semantics(self) -> None: + """A tool is kept if it matches any allowlist entry.""" + tools = cast( + list[InputTool], + [ + InputToolFileSearch(vector_store_ids=["x"]), + InputToolMCP(server_label="m", server_url="http://m"), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [ + {"type": "function", "name": "nope"}, + {"type": "mcp", "server_label": "m"}, + ], + ) + assert len(out) == 1 + assert out[0].type == "mcp" + + def test_no_entry_matches_returns_empty(self) -> None: + """When no tool satisfies any entry, result is empty.""" + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["a"])], + ) + assert ( + filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "only"}], + ) + == [] + ) + + class TestPrepareTools: """Tests for prepare_tools function.""" @@ -1404,9 +1887,9 @@ async def test_prepare_responses_params_includes_mcp_provider_data_headers( # The result should contain extra_headers with x-llamastack-provider-data dumped = result.model_dump() - assert ( - dumped["extra_headers"] is not None - ), "extra_headers should not be None when MCP tools have headers" + assert dumped["extra_headers"] is not None, ( + "extra_headers should not be None when MCP tools have headers" + ) assert "x-llamastack-provider-data" in dumped["extra_headers"] provider_data = json.loads( From 0b1f3bacaf7dcbdda908184f16ea9fe6c529a1a2 Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Thu, 26 Mar 2026 13:08:05 +0100 Subject: 
[PATCH 2/2] Adjust tools resolution --- docs/responses.md | 27 +- src/app/endpoints/responses.py | 4 +- src/app/main.py | 2 +- src/utils/responses.py | 188 +++++++--- .../e2e/features/steps/llm_query_response.py | 4 +- tests/unit/utils/test_responses.py | 350 ++++++++++++------ 6 files changed, 400 insertions(+), 175 deletions(-) diff --git a/docs/responses.md b/docs/responses.md index 153c939f9..9877921c6 100644 --- a/docs/responses.md +++ b/docs/responses.md @@ -280,9 +280,9 @@ Optional. **Tool selection strategy** that controls whether and how the model us **Specific tool objects (object with `type`):** -- `allowed_tools`: Restrict to a list of tool definitions; `mode` is `"auto"` or `"required"`, `tools` is a list of tool objects (same shapes as in [tools](#tools)). -- `file_search`: Force the model to use file search. -- `web_search`: Force the model to use web search (optionally with a variant such as `web_search_preview`). +- `allowed_tools`: Restrict to a list of tool definitions; `mode` is `"auto"` or `"required"`, `tools` is a list of key-value filters applied to the tools configured by the `tools` attribute. +- `file_search`: Force the model to use only file search. +- `web_search`: Force the model to use only web search. - `function`: Force a specific function; `name` (required) is the function name. - `mcp`: Force a tool on an MCP server; `server_label` (required), `name` (optional) tool name. - `custom`: Force a custom tool; `name` (required). @@ -297,7 +297,15 @@ Simple modes (string): use one of `"auto"`, `"required"`, or `"none"`. { "tool_choice": "none" } ``` -Restrict to specific tools with `allowed_tools` (mode `"auto"` or `"required"`, plus `tools` array): +Restrict tool usage to a specific subset using `allowed_tools`. You can control behavior with the `mode` field (`"auto"` or `"required"`) and explicitly list permitted tools in the `tools` array. 
+ +The `tools` array acts as a **key-value filter**: each object specifies matching criteria (such as `type`, `server_label`, or `name`), and only tools that satisfy all provided attributes are allowed. + +The example below limits tool usage to: +- the `file_search` tool +- specific MCP tools (`tool_1` and `tool_2`) available on `server_1` (multiple entries for the same server are combined as a union of their `name`s) + +If the `name` field is omitted for an MCP tool, the filter applies to all tools available on the specified server. ```json { @@ -305,8 +313,9 @@ Restrict to specific tools with `allowed_tools` (mode `"auto"` or `"required"`, "type": "allowed_tools", "mode": "required", "tools": [ - { "type": "file_search", "vector_store_ids": ["vs_123"] }, - { "type": "web_search" } + { "type": "file_search"}, + { "type": "mcp", "server_label": "server_1", "name": "tool_1" }, + { "type": "mcp", "server_label": "server_1", "name": "tool_2" } ] } } @@ -396,8 +405,8 @@ The following response attributes are inherited directly from the LLS OpenAPI sp | `temperature` | float | Temperature parameter used for generation | | `text` | object | Text response configuration object used | | `top_p` | float | Top-p sampling used | -| `tools` | array[object] | Tools available during generation | -| `tool_choice` | string or object | Tool selection used | +| `tools` | array[object] | Internally resolved tools available during generation | +| `tool_choice` | string | Internally resolved tool choice mode | | `truncation` | string | Truncation strategy applied (`"auto"` or `"disabled"`) | | `usage` | object | Token usage (input_tokens, output_tokens, total_tokens) | | `instructions` | string | System instructions used | @@ -517,6 +526,8 @@ Vector store IDs are configured within the `tools` as `file_search` tools rather **Vector store IDs:** Accepts **LCORE format** in requests and also outputs it in responses; LCORE translates to/from Llama Stack format internally. 
+The response includes `tools` and `tool_choice` fields that reflect the internally resolved configuration. More specifically, they contain the final set of tools and selection constraints after internal resolution and filtering. + ### LCORE-Specific Extensions The API introduces extensions that are not part of the OpenResponses specification: diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index e5fc455ed..1f8bb2be4 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -234,7 +234,7 @@ async def responses_endpoint_handler( request.headers, ) - #Build RAG context from Inline RAG sources + # Build RAG context from Inline RAG sources inline_rag_context = await build_rag_context( client, moderation_result.decision, @@ -242,7 +242,6 @@ async def responses_endpoint_handler( vector_store_ids, responses_request.solr, ) - if moderation_result.decision == "passed": responses_request.input = append_inline_rag_context_to_responses_input( responses_request.input, inline_rag_context.context_text @@ -663,7 +662,6 @@ async def handle_non_streaming_response( ) else: try: - print("API Params: ", api_params.model_dump(exclude_none=True)) api_response = cast( OpenAIResponseObject, await client.responses.create( diff --git a/src/app/main.py b/src/app/main.py index 76eefc964..d3e0b0c18 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -229,5 +229,5 @@ async def send_wrapper(message: Message) -> None: # RestApiMetricsMiddleware (registered last) is outermost. This ensures metrics # always observe a status code — including 500s synthesised by the exception # middleware — rather than seeing a raw exception with no response. 
-#app.add_middleware(GlobalExceptionMiddleware) +app.add_middleware(GlobalExceptionMiddleware) app.add_middleware(RestApiMetricsMiddleware) diff --git a/src/utils/responses.py b/src/utils/responses.py index 7c0357c68..1e6175134 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -420,16 +420,12 @@ def extract_vector_store_ids_from_tools( return vector_store_ids -def _tool_matches_allowed_entry(tool: InputTool, entry: dict[str, str]) -> bool: +def tool_matches_allowed_entry(tool: InputTool, entry: dict[str, str]) -> bool: """Return True if the tool satisfies every key in the allowlist entry. - ``OpenAIResponseInputToolChoiceAllowedTools.tools`` entries use string keys - and values (e.g. ``type``, ``server_label``, ``name``); each must match the - corresponding attribute on the tool. - Parameters: tool: A configured input tool. - entry: One allowlist entry from ``allowed_tools.tools``. + entry: One allowlist entry from allowed_tools.tools. Returns: True if all entry keys match the tool. @@ -445,28 +441,139 @@ def _tool_matches_allowed_entry(tool: InputTool, entry: dict[str, str]) -> bool: return True -def filter_tools_by_allowed_entries( - tools: list[InputTool], +def group_mcp_tools_by_server( + entries: list[dict[str, str]], +) -> dict[str, Optional[list[str]]]: + """Group MCP tool filters by server_label. + + Rules: + - Non-MCP entries are ignored. + - Entries without server_label are ignored. + - If any entry for a server has no "name", that server is unrestricted (None). + - Otherwise, collect unique tool names in first-seen order. 
+ + Returns: + Dict mapping: + server_label -> None (unrestricted) OR list of allowed tool names + """ + unrestricted_servers: set[str] = set() + server_to_names: dict[str, list[str]] = {} + for entry in entries: + if entry.get("type") != "mcp": + continue + server = entry.get("server_label") + if not server: + continue + # Unrestricted entry (no "name") + if "name" not in entry: + unrestricted_servers.add(server) + continue + # Skip collecting names if already unrestricted + if server in unrestricted_servers: + continue + name = entry["name"] + if server not in server_to_names: + server_to_names[server] = [] + + if name not in server_to_names[server]: + server_to_names[server].append(name) + + # Build final result + result: dict[str, Optional[list[str]]] = {} + for server in unrestricted_servers: + result[server] = None + + for server, names in server_to_names.items(): + if server not in unrestricted_servers: + result[server] = names + + return result + + +def mcp_strip_name_from_allowlist_entries( allowed_entries: list[dict[str, str]], -) -> list[InputTool]: - """Keep tools that match at least one allowlist entry. +) -> list[dict[str, str]]: + """Return a copy of entries where 'name' is removed only for MCP entries.""" + result: list[dict[str, str]] = [] + for entry in allowed_entries: + new_entry = entry.copy() + if new_entry.get("type") == "mcp": + new_entry.pop("name", None) - If ``allowed_entries`` is empty, no tools are kept (strict allowlist). + result.append(new_entry) + + return result - Parameters: - tools: Tools to filter (typically after translation / preparation). - allowed_entries: Entries from ``OpenAIResponseInputToolChoiceAllowedTools.tools``. + +def mcp_project_allowed_tools_to_names( + tool: InputToolMCP, names: list[str] +) -> list[str] | None: + """Intersect narrowed names with what the MCP tool already permits. Returns: - A sublist of ``tools`` matching the allowlist. + List of permitted tool names, or None if the intersection is empty. 
+ """ + if not names: + return None + name_set = set(names) + allowed = tool.allowed_tools + if allowed is None: + permitted = name_set + elif isinstance(allowed, list): + permitted = name_set & set(allowed) + else: + if allowed.tool_names is None: + permitted = name_set + else: + permitted = name_set & set(allowed.tool_names) + + if not permitted: + return None + + return list(permitted) + + +def filter_tools_by_allowed_entries( + tools: list[InputTool], + allowed_entries: list[dict[str, str]], +) -> list[InputTool]: + """Filter tools based on allowlist entries. + + - Keeps tools matching at least one entry. + - Applies MCP name narrowing when applicable. """ if not allowed_entries: return [] - return [ - t - for t in tools - if any(_tool_matches_allowed_entry(t, e) for e in allowed_entries) - ] + + mcp_names_by_server = group_mcp_tools_by_server(allowed_entries) + sanitized_entries = mcp_strip_name_from_allowlist_entries(allowed_entries) + filtered: list[InputTool] = [] + for tool in tools: + # Skip tools not matching any allowlist entry + if not any(tool_matches_allowed_entry(tool, e) for e in sanitized_entries): + continue + # Non-MCP tools pass through and are handled separately + if tool.type != "mcp": + filtered.append(tool) + continue + + mcp_tool = cast(InputToolMCP, tool) + server = mcp_tool.server_label + + narrowed_names = mcp_names_by_server.get(server) + # No filters specified for this MCP server + if narrowed_names is None: + filtered.append(tool) + continue + + # Apply intersection + permitted = mcp_project_allowed_tools_to_names(mcp_tool, narrowed_names) + if permitted is None: + continue + + filtered.append(mcp_tool.model_copy(update={"allowed_tools": permitted})) + + return filtered def resolve_vector_store_ids( @@ -1382,46 +1489,41 @@ async def resolve_tool_choice( mcp_headers: Optional[McpHeaders] = None, request_headers: Optional[Mapping[str, str]] = None, ) -> tuple[Optional[list[InputTool]], Optional[ToolChoice], Optional[list[str]]]: - 
"""Resolve tools and tool_choice for the Responses API. + """Resolve tools and tool choice for the Responses API. - If ``tool_choice`` is ``none``, always returns ``(None, None, None)`` — no - tools are sent to Llama Stack, even when the request included explicit - ``tools`` (e.g. file_search). + When tool choice disables tools, always return Nones so Llama Stack + sees no tools, even if the request listed tools. - If ``tool_choice`` is ``allowed_tools``, it is rewritten for downstream - services: tools are filtered to those matching the allowlist entries, and - ``tool_choice`` becomes ``auto`` or ``required`` per the allowlist ``mode``. + Allowed-tools mode: filter tools to the allowlist and narrow tool choice to + auto or required from the allowlist mode. - If the request includes tools and tool_choice is not ``none``, uses them - (after allowlist filtering) and derives vector_store_ids from the prepared - tools; otherwise loads tools via prepare_tools (using all configured vector - stores), then applies allowlist filtering when present. When no tools end - up configured, tool_choice is cleared to None. + Otherwise: use request tools (with filtering) and derive vector store IDs, or + load tools via prepare_tools, then filter. Clear tool choice when no tools + remain. Args: - tools: Tools from the request, or None to use LCORE-configured tools. - tool_choice: Requested tool choice (e.g. auto, required, none) or None. - token: User token for MCP/auth. - mcp_headers: Optional MCP headers to propagate. - request_headers: Optional request headers for tool resolution. + tools: Request tools, or None for LCORE-configured tools. + tool_choice: Requested strategy, or None. + token: User token for MCP and auth. + mcp_headers: Optional MCP headers. + request_headers: Optional headers for tool resolution. 
Returns: - A tuple of (prepared_tools, prepared_tool_choice, vector_store_ids): - prepared_tools is the list of tools to use, or None if none configured; - prepared_tool_choice is the resolved tool choice, or None when there - are no tools; vector_store_ids is extracted from tools (in user-facing format) - when provided, otherwise None (also None when tool_choice is ``none``). + Prepared tools, resolved tool choice, and vector store IDs (user-facing), + each possibly None. """ + # If tool_choice is "none", no tools are allowed if isinstance(tool_choice, ToolChoiceMode) and tool_choice == ToolChoiceMode.none: return None, None, None + # Extract the allowed filters if specified and overwrite tool choice mode allowed_filters: Optional[list[dict[str, str]]] = None if isinstance(tool_choice, AllowedTools): allowed_filters = tool_choice.tools tool_choice = ToolChoiceMode(tool_choice.mode) prepared_tools: Optional[list[InputTool]] = None - if tools: # explicitly specified in request + if tools is not None: # explicitly specified in request byok_rags = configuration.configuration.byok_rag prepared_tools = translate_tools_vector_store_ids(tools, byok_rags) if allowed_filters is not None: diff --git a/tests/e2e/features/steps/llm_query_response.py b/tests/e2e/features/steps/llm_query_response.py index 7ccd7f245..6808b18d3 100644 --- a/tests/e2e/features/steps/llm_query_response.py +++ b/tests/e2e/features/steps/llm_query_response.py @@ -57,7 +57,9 @@ def responses_output_should_include_item_type(context: Context, item_type: str) @then('The responses output should not include an item with type "{item_type}"') -def responses_output_should_not_include_item_type(context: Context, item_type: str) -> None: +def responses_output_should_not_include_item_type( + context: Context, item_type: str +) -> None: """Assert no ``output`` item has the given ``type``.""" assert context.response is not None, "Request needs to be performed first" response_json = cast(dict[str, Any], 
context.response.json()) diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index 59fe8d4c3..96b47f5c5 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -8,11 +8,15 @@ import pytest from fastapi import HTTPException +from llama_stack_api.openai_responses import ( + AllowedToolsFilter, + OpenAIResponseInputToolChoiceAllowedTools, +) from llama_stack_api.openai_responses import ( OpenAIResponseInputTool as InputTool, ) from llama_stack_api.openai_responses import ( - OpenAIResponseInputToolChoiceAllowedTools, + OpenAIResponseInputToolChoiceFileSearch as ToolChoiceFileSearch, ) from llama_stack_api.openai_responses import ( OpenAIResponseInputToolChoiceMode as ToolChoiceMode, @@ -889,59 +893,56 @@ def _passthrough_translate(mocker: MockerFixture) -> None: ) @pytest.mark.asyncio - async def test_tool_choice_none_clears_explicit_tools( - self, mocker: MockerFixture + @pytest.mark.parametrize( + "tools_arg", + [ + cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ), + None, + ], + ) + async def test_tool_choice_none_returns_none_tuple( + self, mocker: MockerFixture, tools_arg: Optional[list[InputTool]] ) -> None: - """ToolChoiceMode.none yields no tools even when the request listed tools.""" - mock_get_client = mocker.patch( - "utils.responses.AsyncLlamaStackClientHolder.get_client" - ) - tools = cast( - list[InputTool], - [InputToolFileSearch(vector_store_ids=["vs1"])], - ) - prepared_tools, prepared_choice, vector_store_ids = await resolve_tool_choice( - tools, + """ToolChoiceMode.none always yields (None, None, None).""" + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + mocker.patch("utils.responses.prepare_tools", new_callable=mocker.AsyncMock) + out = await resolve_tool_choice( + tools_arg, ToolChoiceMode.none, "token", ) - assert prepared_tools is None - assert prepared_choice is None - assert vector_store_ids is None - 
mock_get_client.assert_not_called() - - @pytest.mark.asyncio - async def test_tool_choice_none_without_explicit_tools_skips_client( - self, mocker: MockerFixture - ) -> None: - """ToolChoiceMode.none returns early without calling prepare_tools path.""" - mock_get_client = mocker.patch( - "utils.responses.AsyncLlamaStackClientHolder.get_client" - ) - mock_prepare = mocker.patch( - "utils.responses.prepare_tools", - new_callable=mocker.AsyncMock, - ) - out = await resolve_tool_choice(None, ToolChoiceMode.none, "token") assert out == (None, None, None) - mock_get_client.assert_not_called() - mock_prepare.assert_not_called() @pytest.mark.asyncio - async def test_explicit_tools_tool_choice_auto(self, mocker: MockerFixture) -> None: - """Explicit tools with ToolChoiceMode.auto pass through after BYOK translation.""" + @pytest.mark.parametrize( + "choice,expected_choice", + [ + (ToolChoiceMode.auto, ToolChoiceMode.auto), + (None, ToolChoiceMode.auto), + ], + ) + async def test_explicit_tools_auto_or_default_auto( + self, + mocker: MockerFixture, + choice: Optional[ToolChoiceMode], + expected_choice: ToolChoiceMode, + ) -> None: + """Explicit tools with auto or omitted tool_choice pass through with vs_ids.""" self._passthrough_translate(mocker) tools = cast( list[InputTool], [InputToolFileSearch(vector_store_ids=["vs1"])], ) - prepared, choice, vs_ids = await resolve_tool_choice( + prepared, resolved_choice, vs_ids = await resolve_tool_choice( tools, - ToolChoiceMode.auto, + choice, "token", ) assert prepared is not None and prepared[0].type == "file_search" - assert choice == ToolChoiceMode.auto + assert resolved_choice == expected_choice assert vs_ids == ["vs1"] @pytest.mark.asyncio @@ -963,18 +964,43 @@ async def test_explicit_tools_tool_choice_required( assert choice == ToolChoiceMode.required @pytest.mark.asyncio - async def test_explicit_tools_omitted_tool_choice_defaults_to_auto( + async def test_tool_choice_object_explicit_tools_pass_through( self, mocker: 
MockerFixture ) -> None: - """When tool_choice is None and tools are explicit, default is auto.""" + """Object-shaped tool_choice passes through with prepared tools and vector_store_ids when tools are provided.""" self._passthrough_translate(mocker) + tool_choice_obj = ToolChoiceFileSearch() tools = cast( list[InputTool], [InputToolFileSearch(vector_store_ids=["vs1"])], ) - prepared, choice, _vs = await resolve_tool_choice(tools, None, "token") - assert prepared is not None - assert choice == ToolChoiceMode.auto + prepared, choice, vs_ids = await resolve_tool_choice( + tools, + tool_choice_obj, + "token", + ) + assert prepared == tools + assert choice is tool_choice_obj + assert vs_ids == ["vs1"] + + @pytest.mark.asyncio + async def test_tool_choice_object_implicit_prepare_empty_returns_none_tuple( + self, mocker: MockerFixture + ) -> None: + """Object-shaped tool_choice is cleared when no tools are prepared.""" + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=None, + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + tool_choice_obj = ToolChoiceFileSearch() + prepared, choice, vs_ids = await resolve_tool_choice( + None, + tool_choice_obj, + "token", + ) + assert (prepared, choice, vs_ids) == (None, None, None) @pytest.mark.asyncio async def test_allowed_tools_required_filters_explicit_to_file_search( @@ -1020,7 +1046,7 @@ async def test_allowed_tools_auto_explicit_same_filter( prepared, choice, vs_ids = await resolve_tool_choice(tools, allowed, "token") assert prepared is not None and len(prepared) == 1 assert prepared[0].type == "mcp" - assert getattr(prepared[0], "server_label") == "keep" + assert prepared[0].server_label == "keep" assert choice == ToolChoiceMode.auto assert vs_ids is None @@ -1040,33 +1066,6 @@ async def test_allowed_tools_no_match_returns_none_tuple( ) assert await resolve_tool_choice(tools, allowed, "token") == (None, None, None) - @pytest.mark.asyncio - async def 
test_allowed_tools_multiple_allowlist_entries_or_semantics( - self, mocker: MockerFixture - ) -> None: - """Multiple allowlist rows match as OR: keep tools matching any entry.""" - self._passthrough_translate(mocker) - allowed = OpenAIResponseInputToolChoiceAllowedTools( - mode="auto", - tools=[ - {"type": "file_search"}, - {"type": "mcp", "server_label": "b"}, - ], - ) - tools = cast( - list[InputTool], - [ - InputToolFileSearch(vector_store_ids=["vs"]), - InputToolMCP(server_label="a", server_url="http://a"), - InputToolMCP(server_label="b", server_url="http://b"), - ], - ) - prepared, _choice, _vs = await resolve_tool_choice(tools, allowed, "token") - assert prepared is not None - assert len(prepared) == 2 - types = {t.type for t in prepared} - assert types == {"file_search", "mcp"} - @pytest.mark.asyncio async def test_allowed_tools_function_tool_filtered_by_type_and_name( self, mocker: MockerFixture @@ -1086,7 +1085,7 @@ async def test_allowed_tools_function_tool_filtered_by_type_and_name( ) prepared, choice, vs_ids = await resolve_tool_choice(tools, allowed, "token") assert prepared is not None and len(prepared) == 1 - assert getattr(prepared[0], "name") == "keep_fn" + assert prepared[0].name == "keep_fn" assert choice == ToolChoiceMode.required assert vs_ids is None @@ -1110,37 +1109,17 @@ async def test_allowed_tools_web_search_must_match_type_literal( prepared, choice, _vs = await resolve_tool_choice(tools, allowed, "token") assert prepared is not None and len(prepared) == 1 assert prepared[0].type == "web_search" - assert choice is not None and choice == ToolChoiceMode.auto - - @pytest.mark.asyncio - async def test_implicit_tool_choice_auto_calls_prepare_tools( - self, mocker: MockerFixture - ) -> None: - """No explicit tools: ToolChoiceMode.auto uses prepare_tools output.""" - fs = InputToolFileSearch(vector_store_ids=["vs1"]) - mocker.patch( - "utils.responses.prepare_tools", - new_callable=mocker.AsyncMock, - return_value=[fs], - ) - mock_get_client = 
mocker.patch( - "utils.responses.AsyncLlamaStackClientHolder.get_client" - ) - prepared, choice, vs_ids = await resolve_tool_choice( - None, - ToolChoiceMode.auto, - "token", - ) - assert prepared == [fs] assert choice == ToolChoiceMode.auto - assert vs_ids is None - mock_get_client.assert_called_once() @pytest.mark.asyncio - async def test_implicit_tool_choice_required_calls_prepare_tools( - self, mocker: MockerFixture + @pytest.mark.parametrize( + "mode_choice", + [ToolChoiceMode.auto, ToolChoiceMode.required], + ) + async def test_implicit_tool_choice_uses_prepare_tools( + self, mocker: MockerFixture, mode_choice: ToolChoiceMode ) -> None: - """No explicit tools: ToolChoiceMode.required passes through when tools exist.""" + """No explicit tools: prepared list and mode follow tool_choice when tools exist.""" fs = InputToolFileSearch(vector_store_ids=["vs1"]) mocker.patch( "utils.responses.prepare_tools", @@ -1148,13 +1127,14 @@ async def test_implicit_tool_choice_required_calls_prepare_tools( return_value=[fs], ) mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") - prepared, choice, _vs = await resolve_tool_choice( + prepared, choice, vs_ids = await resolve_tool_choice( None, - ToolChoiceMode.required, + mode_choice, "token", ) assert prepared == [fs] - assert choice == ToolChoiceMode.required + assert choice == mode_choice + assert vs_ids is None @pytest.mark.asyncio async def test_implicit_prepare_tools_returns_none_clears_tool_choice( @@ -1188,9 +1168,7 @@ async def test_allowed_tools_applies_after_prepare_tools( new_callable=mocker.AsyncMock, return_value=[fs, mcp], ) - mock_get_client = mocker.patch( - "utils.responses.AsyncLlamaStackClientHolder.get_client" - ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") allowed = OpenAIResponseInputToolChoiceAllowedTools( mode="auto", tools=[{"type": "mcp", "server_label": "s1"}], @@ -1205,7 +1183,27 @@ async def test_allowed_tools_applies_after_prepare_tools( assert 
prepared[0].type == "mcp" assert choice == ToolChoiceMode.auto assert vs_ids is None - mock_get_client.assert_called_once() + + @pytest.mark.asyncio + async def test_allowed_tools_implicit_filter_excludes_all_tools( + self, mocker: MockerFixture + ) -> None: + """Implicit tools: allowlist can remove every prepared tool.""" + mcp = InputToolMCP(server_label="s1", server_url="http://x") + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=[mcp], + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "file_search"}], + ) + prepared, choice, vs_ids = await resolve_tool_choice(None, allowed, "token") + assert prepared is None + assert choice is None + assert vs_ids is None @pytest.mark.asyncio async def test_allowed_tools_implicit_required_mode_after_prepare( @@ -1223,8 +1221,10 @@ async def test_allowed_tools_implicit_required_mode_after_prepare( mode="required", tools=[{"type": "mcp"}], ) - _prepared, choice, _vs = await resolve_tool_choice(None, allowed, "token") + prepared, choice, vs_ids = await resolve_tool_choice(None, allowed, "token") + assert prepared == [mcp] assert choice == ToolChoiceMode.required + assert vs_ids is None class TestFilterToolsByAllowedEntries: @@ -1248,7 +1248,7 @@ def test_empty_allowlist_keeps_nothing(self) -> None: list[InputTool], [InputToolFileSearch(vector_store_ids=["a"])], ) - assert filter_tools_by_allowed_entries(tools, []) == [] + assert not filter_tools_by_allowed_entries(tools, []) def test_mcp_type_only_matches_all_mcp_tools(self) -> None: """``{type: mcp}`` keeps every MCP tool regardless of server_label.""" @@ -1276,7 +1276,7 @@ def test_mcp_type_and_server_label_specific(self) -> None: [{"type": "mcp", "server_label": "keep"}], ) assert len(out) == 1 - assert getattr(out[0], "server_label") == "keep" + assert out[0].server_label == "keep" def 
test_function_type_and_name(self) -> None: """Function tools match on type and name.""" @@ -1292,7 +1292,7 @@ def test_function_type_and_name(self) -> None: [{"type": "function", "name": "fn_b"}], ) assert len(out) == 1 - assert getattr(out[0], "name") == "fn_b" + assert out[0].name == "fn_b" def test_web_search_type_literal_must_match(self) -> None: """web_search vs web_search_preview require distinct allowlist entries.""" @@ -1335,12 +1335,124 @@ def test_no_entry_matches_returns_empty(self) -> None: list[InputTool], [InputToolFileSearch(vector_store_ids=["a"])], ) - assert ( - filter_tools_by_allowed_entries( - tools, - [{"type": "mcp", "server_label": "only"}], - ) - == [] + assert not filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "only"}], + ) + + def test_mcp_name_grouped_by_server_narrows_allowed_tools(self) -> None: + """MCP ``name`` is grouped by ``server_label`` and applied after generic match.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="keep", + server_url="http://a", + allowed_tools=["alpha", "beta"], + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "keep", "name": "alpha"}], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["alpha"] + + def test_mcp_allowed_tools_none_projects_to_entry_names(self) -> None: + """``allowed_tools`` None permits any name; projection narrows to grouped names.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", + server_url="http://a", + allowed_tools=None, + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "s", "name": "gamma"}], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["gamma"] + + def test_mcp_server_without_name_in_allowlist_skips_projection(self) -> None: + """Any MCP entry without ``name`` for that server disables name narrowing.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", 
+ server_url="http://a", + allowed_tools=["a", "b"], + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [ + {"type": "mcp", "server_label": "s"}, + {"type": "mcp", "server_label": "s", "name": "a"}, + ], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["a", "b"] + + def test_mcp_allowed_tools_filter_tool_names_none(self) -> None: + """AllowedToolsFilter with ``tool_names`` None does not block grouped names.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", + server_url="http://a", + allowed_tools=AllowedToolsFilter(tool_names=None), + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "s", "name": "z"}], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["z"] + + def test_mcp_allowed_tools_filter_intersects_with_grouped_names(self) -> None: + """AllowedToolsFilter with explicit ``tool_names`` intersects grouped allowlist names.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", + server_url="http://a", + allowed_tools=AllowedToolsFilter(tool_names=["alpha", "beta"]), + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "s", "name": "alpha"}], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["alpha"] + + def test_mcp_name_not_permitted_drops_tool(self) -> None: + """Empty intersection between grouped names and ``allowed_tools`` drops the tool.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", + server_url="http://a", + allowed_tools=["only_this"], + ), + ], + ) + assert not filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "s", "name": "other"}], ) @@ -1887,9 +1999,9 @@ async def test_prepare_responses_params_includes_mcp_provider_data_headers( # The result should contain extra_headers with x-llamastack-provider-data dumped = result.model_dump() - assert dumped["extra_headers"] is not None, ( - "extra_headers should not 
be None when MCP tools have headers" - ) + assert ( + dumped["extra_headers"] is not None + ), "extra_headers should not be None when MCP tools have headers" assert "x-llamastack-provider-data" in dumped["extra_headers"] provider_data = json.loads(