diff --git a/docs/responses.md b/docs/responses.md index 153c939f9..9877921c6 100644 --- a/docs/responses.md +++ b/docs/responses.md @@ -280,9 +280,9 @@ Optional. **Tool selection strategy** that controls whether and how the model us **Specific tool objects (object with `type`):** -- `allowed_tools`: Restrict to a list of tool definitions; `mode` is `"auto"` or `"required"`, `tools` is a list of tool objects (same shapes as in [tools](#tools)). -- `file_search`: Force the model to use file search. -- `web_search`: Force the model to use web search (optionally with a variant such as `web_search_preview`). +- `allowed_tools`: Restrict to a list of tool definitions; `mode` is `"auto"` or `"required"`, `tools` is a list of key-valued filters for tools configured by `tools` attribute. +- `file_search`: Force the model to use file-only search. +- `web_search`: Force the model to use only web search. - `function`: Force a specific function; `name` (required) is the function name. - `mcp`: Force a tool on an MCP server; `server_label` (required), `name` (optional) tool name. - `custom`: Force a custom tool; `name` (required). @@ -297,7 +297,15 @@ Simple modes (string): use one of `"auto"`, `"required"`, or `"none"`. { "tool_choice": "none" } ``` -Restrict to specific tools with `allowed_tools` (mode `"auto"` or `"required"`, plus `tools` array): +Restrict tool usage to a specific subset using `allowed_tools`. You can control behavior with the `mode` field (`"auto"` or `"required"`) and explicitly list permitted tools in the `tools` array. + +The `tools` array acts as a **key-value filter**: each object specifies matching criteria (such as `type`, `server_label`, or `name`), and only tools that satisfy all provided attributes are allowed. + +The example below limits tool usage to: +- the `file_search` tool +- a specific MCP tools (`tool_1` and `tool_2`) available on `server_1` (for multiple `name`s act as union) + +If the `name` field is omitted for an MCP tool, the filter applies to all tools available on the specified server. ```json { @@ -305,8 +313,9 @@ Restrict to specific tools with `allowed_tools` (mode `"auto"` or `"required"`, "type": "allowed_tools", "mode": "required", "tools": [ - { "type": "file_search", "vector_store_ids": ["vs_123"] }, - { "type": "web_search" } + { "type": "file_search"}, + { "type": "mcp", "server_label": "server_1", "name": "tool_1" }, + { "type": "mcp", "server_label": "server_1", "name": "tool_2" } ] } } @@ -396,8 +405,8 @@ The following response attributes are inherited directly from the LLS OpenAPI sp | `temperature` | float | Temperature parameter used for generation | | `text` | object | Text response configuration object used | | `top_p` | float | Top-p sampling used | -| `tools` | array[object] | Tools available during generation | -| `tool_choice` | string or object | Tool selection used | +| `tools` | array[object] | Internally resolved tools available during generation | +| `tool_choice` | string | Internally resolved tool choice mode | | `truncation` | string | Truncation strategy applied (`"auto"` or `"disabled"`) | | `usage` | object | Token usage (input_tokens, output_tokens, total_tokens) | | `instructions` | string | System instructions used | @@ -517,6 +526,8 @@ Vector store IDs are configured within the `tools` as `file_search` tools rather **Vector store IDs:** Accepts **LCORE format** in requests and also outputs it in responses; LCORE translates to/from Llama Stack format internally. +The response includes `tools` and `tool_choice` fields that reflect the internally resolved configuration. More specifically, the final set of tools and selection constraints after internal resolution and filtering. + ### LCORE-Specific Extensions The API introduces extensions that are not part of the OpenResponses specification: diff --git a/src/utils/responses.py b/src/utils/responses.py index 3e68a26db..1e6175134 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -26,6 +26,9 @@ from llama_stack_api.openai_responses import ( OpenAIResponseInputToolChoice as ToolChoice, ) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolChoiceAllowedTools as AllowedTools, +) from llama_stack_api.openai_responses import ( OpenAIResponseInputToolChoiceMode as ToolChoiceMode, ) @@ -417,6 +420,162 @@ def extract_vector_store_ids_from_tools( return vector_store_ids +def tool_matches_allowed_entry(tool: InputTool, entry: dict[str, str]) -> bool: + """Return True if the tool satisfies every key in the allowlist entry. + + Parameters: + tool: A configured input tool. + entry: One allowlist entry from allowed_tools.tools. + + Returns: + True if all entry keys match the tool. + """ + for key, value in entry.items(): + if not hasattr(tool, key): + return False + attr = getattr(tool, key) + if attr is None: + return False + if attr != value and str(attr) != value: + return False + return True + + +def group_mcp_tools_by_server( + entries: list[dict[str, str]], +) -> dict[str, Optional[list[str]]]: + """Group MCP tool filters by server_label. + + Rules: + - Non-MCP entries are ignored. + - Entries without server_label are ignored. + - If any entry for a server has no "name", that server is unrestricted (None). + - Otherwise, collect unique tool names in first-seen order. + + Returns: + Dict mapping: + server_label -> None (unrestricted) OR list of allowed tool names + """ + unrestricted_servers: set[str] = set() + server_to_names: dict[str, list[str]] = {} + for entry in entries: + if entry.get("type") != "mcp": + continue + server = entry.get("server_label") + if not server: + continue + # Unrestricted entry (no "name") + if "name" not in entry: + unrestricted_servers.add(server) + continue + # Skip collecting names if already unrestricted + if server in unrestricted_servers: + continue + name = entry["name"] + if server not in server_to_names: + server_to_names[server] = [] + + if name not in server_to_names[server]: + server_to_names[server].append(name) + + # Build final result + result: dict[str, Optional[list[str]]] = {} + for server in unrestricted_servers: + result[server] = None + + for server, names in server_to_names.items(): + if server not in unrestricted_servers: + result[server] = names + + return result + + +def mcp_strip_name_from_allowlist_entries( + allowed_entries: list[dict[str, str]], +) -> list[dict[str, str]]: + """Return a copy of entries where 'name' is removed only for MCP entries.""" + result: list[dict[str, str]] = [] + for entry in allowed_entries: + new_entry = entry.copy() + if new_entry.get("type") == "mcp": + new_entry.pop("name", None) + + result.append(new_entry) + + return result + + +def mcp_project_allowed_tools_to_names( + tool: InputToolMCP, names: list[str] +) -> list[str] | None: + """Intersect narrowed names with what the MCP tool already permits. + + Returns: + List of permitted tool names, or None if the intersection is empty. + """ + if not names: + return None + name_set = set(names) + allowed = tool.allowed_tools + if allowed is None: + permitted = name_set + elif isinstance(allowed, list): + permitted = name_set & set(allowed) + else: + if allowed.tool_names is None: + permitted = name_set + else: + permitted = name_set & set(allowed.tool_names) + + if not permitted: + return None + + return list(permitted) + + +def filter_tools_by_allowed_entries( + tools: list[InputTool], + allowed_entries: list[dict[str, str]], +) -> list[InputTool]: + """Filter tools based on allowlist entries. + + - Keeps tools matching at least one entry. + - Applies MCP name narrowing when applicable. + """ + if not allowed_entries: + return [] + + mcp_names_by_server = group_mcp_tools_by_server(allowed_entries) + sanitized_entries = mcp_strip_name_from_allowlist_entries(allowed_entries) + filtered: list[InputTool] = [] + for tool in tools: + # Skip tools not matching any allowlist entry + if not any(tool_matches_allowed_entry(tool, e) for e in sanitized_entries): + continue + # Non-MCP tools pass through and are handled separately + if tool.type != "mcp": + filtered.append(tool) + continue + + mcp_tool = cast(InputToolMCP, tool) + server = mcp_tool.server_label + + narrowed_names = mcp_names_by_server.get(server) + # No filters specified for this MCP server + if narrowed_names is None: + filtered.append(tool) + continue + + # Apply intersection + permitted = mcp_project_allowed_tools_to_names(mcp_tool, narrowed_names) + if permitted is None: + continue + + filtered.append(mcp_tool.model_copy(update={"allowed_tools": permitted})) + + return filtered + + def resolve_vector_store_ids( vector_store_ids: list[str], byok_rags: list[ByokRag] ) -> list[str]: @@ -1330,54 +1489,69 @@ async def resolve_tool_choice( mcp_headers: Optional[McpHeaders] = None, request_headers: Optional[Mapping[str, str]] = None, ) -> tuple[Optional[list[InputTool]], Optional[ToolChoice], Optional[list[str]]]: - """Resolve tools and tool_choice for the Responses API. + """Resolve tools and tool choice for the Responses API. + + When tool choice disables tools, always return Nones so Llama Stack + sees no tools, even if the request listed tools. - If the request includes tools, uses them as-is and derives vector_store_ids - from tool configs; otherwise loads tools via prepare_tools (using all - configured vector stores) and honors tool_choice "none" via the no_tools - flag. When no tools end up configured, tool_choice is cleared to None. + Allowed-tools mode: filter tools to the allowlist and narrow tool choice to + auto or required from the allowlist mode. + + Otherwise: use request tools (with filtering) and derive vector store IDs, or + load tools via prepare_tools, then filter. Clear tool choice when no tools + remain. Args: - tools: Tools from the request, or None to use LCORE-configured tools. - tool_choice: Requested tool choice (e.g. auto, required, none) or None. - token: User token for MCP/auth. - mcp_headers: Optional MCP headers to propagate. - request_headers: Optional request headers for tool resolution. + tools: Request tools, or None for LCORE-configured tools. + tool_choice: Requested strategy, or None. + token: User token for MCP and auth. + mcp_headers: Optional MCP headers. + request_headers: Optional headers for tool resolution. Returns: - A tuple of (prepared_tools, prepared_tool_choice, vector_store_ids): - prepared_tools is the list of tools to use, or None if none configured; - prepared_tool_choice is the resolved tool choice, or None when there - are no tools; vector_store_ids is extracted from tools (in user-facing format) - when provided, otherwise None. + Prepared tools, resolved tool choice, and vector store IDs (user-facing), + each possibly None. """ + # If tool_choice is "none", no tools are allowed + if isinstance(tool_choice, ToolChoiceMode) and tool_choice == ToolChoiceMode.none: + return None, None, None + + # Extract the allowed filters if specified and overwrite tool choice mode + allowed_filters: Optional[list[dict[str, str]]] = None + if isinstance(tool_choice, AllowedTools): + allowed_filters = tool_choice.tools + tool_choice = ToolChoiceMode(tool_choice.mode) + prepared_tools: Optional[list[InputTool]] = None - client = AsyncLlamaStackClientHolder().get_client() - if tools: # explicitly specified in request - # Per-request override of vector stores (user-facing rag_ids) - vector_store_ids = extract_vector_store_ids_from_tools(tools) - # Translate user-facing rag_ids to llama-stack vector_store_ids in each file_search tool + if tools is not None: # explicitly specified in request byok_rags = configuration.configuration.byok_rag prepared_tools = translate_tools_vector_store_ids(tools, byok_rags) + if allowed_filters is not None: + prepared_tools = filter_tools_by_allowed_entries( + prepared_tools, allowed_filters + ) + if not prepared_tools: + return None, None, None + vector_store_ids_list = extract_vector_store_ids_from_tools(prepared_tools) + vector_store_ids = vector_store_ids_list if vector_store_ids_list else None prepared_tool_choice = tool_choice or ToolChoiceMode.auto else: - # Vector stores were not overwritten in request, use all configured vector stores vector_store_ids = None - # Get all tools configured in LCORE (returns None or non-empty list) - no_tools = ( - isinstance(tool_choice, ToolChoiceMode) - and tool_choice == ToolChoiceMode.none - ) - # Vector stores are prepared in llama-stack format + client = AsyncLlamaStackClientHolder().get_client() prepared_tools = await prepare_tools( client=client, - vector_store_ids=vector_store_ids, # allow all configured vector stores - no_tools=no_tools, + vector_store_ids=vector_store_ids, + no_tools=False, token=token, mcp_headers=mcp_headers, request_headers=request_headers, ) - # If there are no tools, tool_choice cannot be set at all - LLS implicit behavior + if allowed_filters is not None and prepared_tools: + prepared_tools = filter_tools_by_allowed_entries( + prepared_tools, allowed_filters + ) + if not prepared_tools: + prepared_tools = None prepared_tool_choice = tool_choice if prepared_tools else None return prepared_tools, prepared_tool_choice, vector_store_ids diff --git a/tests/e2e/features/responses.feature b/tests/e2e/features/responses.feature index 463fe2b9c..da23dce42 100644 --- a/tests/e2e/features/responses.feature +++ b/tests/e2e/features/responses.feature @@ -427,3 +427,149 @@ Feature: Responses endpoint API tests """ Then The status code of the response is 503 And The body of the response contains Unable to connect to Llama Stack + + +Scenario: Check if responses endpoint with tool_choice none answers knowledge question without file search usage + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. You MUST use the file_search tool to answer. Answer in lowercase.", + "tool_choice": "none" + } + """ + Then The status code of the response is 200 + And The responses output should not include any tool invocation item types + And The token metrics should have increased + + Scenario: Check if responses endpoint with tool_choice auto answers a knowledge question using file search + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. You MUST use the file_search tool to answer. Answer in lowercase.", + "tool_choice": "auto" + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The responses output_text should contain following fragments + | Fragments in LLM response | + | great work | + And The token metrics should have increased + + Scenario: Check if responses endpoint with tool_choice required still invokes document search for a basic question + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "Hello World!", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "tool_choice": "required" + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The token metrics should have increased + + Scenario: Check if responses endpoint with file search as the chosen tool answers using file search + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. You MUST use the file_search tool to answer. Answer in lowercase.", + "tool_choice": {"type": "file_search"} + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The responses output_text should contain following fragments + | Fragments in LLM response | + | great work | + And The token metrics should have increased + + Scenario: Check if responses endpoint with allowed tools in automatic mode answers knowledge question using file search + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. You MUST use the file_search tool to answer. Answer in lowercase.", + "tool_choice": { + "type": "allowed_tools", + "mode": "auto", + "tools": [{"type": "file_search"}] + } + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The responses output_text should contain following fragments + | Fragments in LLM response | + | great work | + And The token metrics should have increased + + Scenario: Check if responses endpoint with allowed tools in required mode invokes file search for a basic question + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "Hello world!", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "tool_choice": { + "type": "allowed_tools", + "mode": "required", + "tools": [{"type": "file_search"}] + } + } + """ + Then The status code of the response is 200 + And The responses output should include an item with type "file_search_call" + And The token metrics should have increased + + Scenario: Allowed tools auto mode with only MCP in allowlist does not use file search for knowledge question + Given The system is in default state + And I set the Authorization header to Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Ikpva + And I capture the current token metrics + When I use "responses" to ask question with authorization header + """ + { + "input": "What is the title of the article from Paul?", + "model": "{PROVIDER}/{MODEL}", + "stream": false, + "instructions": "You are an assistant. Answer in lowercase.", + "tool_choice": { + "type": "allowed_tools", + "mode": "auto", + "tools": [{"type": "mcp"}] + } + } + """ + Then The status code of the response is 200 + And The responses output should not include an item with type "file_search_call" + And The token metrics should have increased \ No newline at end of file diff --git a/tests/e2e/features/steps/llm_query_response.py b/tests/e2e/features/steps/llm_query_response.py index 01f90a148..6808b18d3 100644 --- a/tests/e2e/features/steps/llm_query_response.py +++ b/tests/e2e/features/steps/llm_query_response.py @@ -2,6 +2,7 @@ import json import os +from typing import Any, cast import requests from behave import step, then # pyright: ignore[reportAttributeAccessIssue] @@ -12,6 +13,75 @@ # Longer timeout for Prow/OpenShift with CPU-based vLLM DEFAULT_LLM_TIMEOUT = 180 if os.getenv("RUNNING_PROW") else 60 +# Responses API ``output`` item types that indicate tool listing or invocation. +_RESPONSE_TOOL_OUTPUT_ITEM_TYPES = frozenset( + { + "file_search_call", + "mcp_call", + "mcp_list_tools", + "function_call", + "web_search_call", + } +) + + +def _collect_output_item_types(response_body: dict[str, Any]) -> list[str]: + """Collect ``type`` from each top-level ``output`` item in a Responses API JSON body.""" + output = cast(list[dict[str, Any]], response_body["output"]) + return [item["type"] for item in output] + + +@then("The responses output should not include any tool invocation item types") +def responses_output_should_not_include_tool_items(context: Context) -> None: + """Assert no tool-related items appear in the Responses JSON ``output`` array.""" + assert context.response is not None, "Request needs to be performed first" + response_json = cast(dict[str, Any], context.response.json()) + types_found = _collect_output_item_types(response_json) + bad = [t for t in types_found if t in _RESPONSE_TOOL_OUTPUT_ITEM_TYPES] + assert not bad, ( + "Expected no tool-related output items, but found types " + f"{bad!r} among all output types {types_found!r}" + ) + + +@then('The responses output should include an item with type "{item_type}"') +def responses_output_should_include_item_type(context: Context, item_type: str) -> None: + """Assert at least one ``output`` item has the given ``type``.""" + assert context.response is not None, "Request needs to be performed first" + response_json = cast(dict[str, Any], context.response.json()) + types_found = _collect_output_item_types(response_json) + assert item_type in types_found, ( + f"Expected output item type {item_type!r} not found; " + f"had types {types_found!r}" + ) + + +@then('The responses output should not include an item with type "{item_type}"') +def responses_output_should_not_include_item_type( + context: Context, item_type: str +) -> None: + """Assert no ``output`` item has the given ``type``.""" + assert context.response is not None, "Request needs to be performed first" + response_json = cast(dict[str, Any], context.response.json()) + types_found = _collect_output_item_types(response_json) + assert item_type not in types_found, ( + f"Expected output item type {item_type!r} to be absent; " + f"but found types {types_found!r}" + ) + + +@then("The responses output should include an item with one of these types") +def responses_output_should_include_one_of_types(context: Context) -> None: + """Assert at least one output item type matches a row in the scenario table.""" + assert context.response is not None, "Request needs to be performed first" + assert context.table is not None, "Table with column 'item type' is required" + allowed = [row["item type"].strip() for row in context.table] + response_json = cast(dict[str, Any], context.response.json()) + types_found = _collect_output_item_types(response_json) + assert any( + a in types_found for a in allowed + ), f"Expected at least one of {allowed!r} in output types {types_found!r}" + @step("I wait for the response to be completed") def wait_for_complete_response(context: Context) -> None: @@ -161,6 +231,26 @@ def check_referenced_documents_present(context: Context) -> None: assert ( len(response_json["referenced_documents"]) > 0 ), "referenced_documents is empty — no documents were referenced" +@then("The responses output_text should contain following fragments") +def check_fragments_in_responses_output_text(context: Context) -> None: + """Check that fragments from the scenario table appear in JSON ``output_text``. + + Used for POST ``/v1/responses`` (query endpoint uses the ``response`` field). + """ + assert context.response is not None, "Request needs to be performed first" + response_json = context.response.json() + assert ( + "output_text" in response_json + ), f"Expected 'output_text' in JSON body, got keys: {list(response_json.keys())}" + output_text = response_json["output_text"] + + assert context.table is not None, "Fragments are not specified in table" + + for fragment in context.table: + expected = fragment["Fragments in LLM response"] + assert ( + expected in output_text + ), f"Fragment '{expected}' not found in output_text: '{output_text}'" @then("The response should contain following fragments") diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index 442dd123e..96b47f5c5 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -4,16 +4,35 @@ import json from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, cast import pytest from fastapi import HTTPException +from llama_stack_api.openai_responses import ( + AllowedToolsFilter, + OpenAIResponseInputToolChoiceAllowedTools, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputTool as InputTool, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolChoiceFileSearch as ToolChoiceFileSearch, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolChoiceMode as ToolChoiceMode, +) from llama_stack_api.openai_responses import ( OpenAIResponseInputToolFileSearch as InputToolFileSearch, ) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolFunction as InputToolFunction, +) from llama_stack_api.openai_responses import ( OpenAIResponseInputToolMCP as InputToolMCP, ) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolWebSearch as InputToolWebSearch, +) from llama_stack_api.openai_responses import ( OpenAIResponseMCPApprovalRequest as MCPApprovalRequest, ) @@ -54,6 +73,7 @@ extract_text_from_response_items, extract_token_usage, extract_vector_store_ids_from_tools, + filter_tools_by_allowed_entries, get_mcp_tools, get_rag_tools, get_topic_summary, @@ -62,6 +82,7 @@ parse_referenced_documents, prepare_responses_params, prepare_tools, + resolve_tool_choice, resolve_vector_store_ids, ) @@ -861,6 +882,580 @@ async def test_get_topic_summary_api_error(self, mocker: MockerFixture) -> None: await get_topic_summary("test question", mock_client, "model1") +class TestResolveToolChoice: + """Tests for resolve_tool_choice (ToolChoiceMode, AllowedTools, explicit/implicit tools).""" + + @staticmethod + def _passthrough_translate(mocker: MockerFixture) -> None: + mocker.patch( + "utils.responses.translate_tools_vector_store_ids", + side_effect=lambda t, _: t, + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "tools_arg", + [ + cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ), + None, + ], + ) + async def test_tool_choice_none_returns_none_tuple( + self, mocker: MockerFixture, tools_arg: Optional[list[InputTool]] + ) -> None: + """ToolChoiceMode.none always yields (None, None, None).""" + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + mocker.patch("utils.responses.prepare_tools", new_callable=mocker.AsyncMock) + out = await resolve_tool_choice( + tools_arg, + ToolChoiceMode.none, + "token", + ) + assert out == (None, None, None) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "choice,expected_choice", + [ + (ToolChoiceMode.auto, ToolChoiceMode.auto), + (None, ToolChoiceMode.auto), + ], + ) + async def test_explicit_tools_auto_or_default_auto( + self, + mocker: MockerFixture, + choice: Optional[ToolChoiceMode], + expected_choice: ToolChoiceMode, + ) -> None: + """Explicit tools with auto or omitted tool_choice pass through with vs_ids.""" + self._passthrough_translate(mocker) + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ) + prepared, resolved_choice, vs_ids = await resolve_tool_choice( + tools, + choice, + "token", + ) + assert prepared is not None and prepared[0].type == "file_search" + assert resolved_choice == expected_choice + assert vs_ids == ["vs1"] + + @pytest.mark.asyncio + async def test_explicit_tools_tool_choice_required( + self, mocker: MockerFixture + ) -> None: + """Explicit tools with ToolChoiceMode.required are preserved.""" + self._passthrough_translate(mocker) + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ) + prepared, choice, _vs = await resolve_tool_choice( + tools, + ToolChoiceMode.required, + "token", + ) + assert prepared is not None + assert choice == ToolChoiceMode.required + + @pytest.mark.asyncio + async def test_tool_choice_object_explicit_tools_pass_through( + self, mocker: MockerFixture + ) -> None: + """Object-shaped tool_choice passes through with prepared tools and vector_store_ids when tools are provided.""" + self._passthrough_translate(mocker) + tool_choice_obj = ToolChoiceFileSearch() + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ) + prepared, choice, vs_ids = await resolve_tool_choice( + tools, + tool_choice_obj, + "token", + ) + assert prepared == tools + assert choice is tool_choice_obj + assert vs_ids == ["vs1"] + + @pytest.mark.asyncio + async def test_tool_choice_object_implicit_prepare_empty_returns_none_tuple( + self, mocker: MockerFixture + ) -> None: + """Object-shaped tool_choice is cleared when no tools are prepared.""" + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=None, + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + tool_choice_obj = ToolChoiceFileSearch() + prepared, choice, vs_ids = await resolve_tool_choice( + None, + tool_choice_obj, + "token", + ) + assert (prepared, choice, vs_ids) == (None, None, None) + + @pytest.mark.asyncio + async def test_allowed_tools_required_filters_explicit_to_file_search( + self, mocker: MockerFixture + ) -> None: + """AllowedTools mode=required filters to allowlist; vector_store_ids from result.""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[{"type": "file_search"}], + ) + tools = cast( + list[InputTool], + [ + InputToolFileSearch(vector_store_ids=["vs1"]), + InputToolMCP(server_label="s1", server_url="http://example.com"), + ], + ) + prepared, choice, vs_ids = await resolve_tool_choice(tools, allowed, "token") + assert prepared is not None + assert len(prepared) == 1 + assert prepared[0].type == "file_search" + assert choice == ToolChoiceMode.required + assert vs_ids == ["vs1"] + + @pytest.mark.asyncio + async def test_allowed_tools_auto_explicit_same_filter( + self, mocker: MockerFixture + ) -> None: + """AllowedTools mode=auto maps to ToolChoiceMode.auto after filtering.""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "mcp", "server_label": "keep"}], + ) + tools = cast( + list[InputTool], + [ + InputToolMCP(server_label="keep", server_url="http://a"), + InputToolMCP(server_label="drop", server_url="http://b"), + ], + ) + prepared, choice, vs_ids = await resolve_tool_choice(tools, allowed, "token") + assert prepared is not None and len(prepared) == 1 + assert prepared[0].type == "mcp" + assert prepared[0].server_label == "keep" + assert choice == ToolChoiceMode.auto + assert vs_ids is None + + @pytest.mark.asyncio + async def test_allowed_tools_no_match_returns_none_tuple( + self, mocker: MockerFixture + ) -> None: + """When allowlist excludes all tools, return (None, None, None).""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "mcp", "server_label": "other"}], + ) + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["vs1"])], + ) + assert await resolve_tool_choice(tools, allowed, "token") == (None, None, None) + + @pytest.mark.asyncio + async def test_allowed_tools_function_tool_filtered_by_type_and_name( + self, mocker: MockerFixture + ) -> None: + """Allowlist can target function tools by type and name.""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[{"type": "function", "name": "keep_fn"}], + ) + tools = cast( + list[InputTool], + [ + InputToolFunction(name="keep_fn", parameters={}), + InputToolFunction(name="drop_fn", parameters={}), + ], + ) + prepared, choice, vs_ids = await resolve_tool_choice(tools, allowed, "token") + assert prepared is not None and len(prepared) == 1 + assert prepared[0].name == "keep_fn" + assert choice == ToolChoiceMode.required + assert vs_ids is None + + @pytest.mark.asyncio + async def test_allowed_tools_web_search_must_match_type_literal( + self, mocker: MockerFixture + ) -> None: + """Web search variants only match if allowlist type matches tool.type exactly.""" + self._passthrough_translate(mocker) + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "web_search"}], + ) + tools = cast( + list[InputTool], + [ + InputToolWebSearch(type="web_search"), + InputToolWebSearch(type="web_search_preview"), + ], + ) + prepared, choice, _vs = await resolve_tool_choice(tools, allowed, "token") + assert prepared is not None and len(prepared) == 1 + assert prepared[0].type == "web_search" + assert choice == ToolChoiceMode.auto + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "mode_choice", + [ToolChoiceMode.auto, ToolChoiceMode.required], + ) + async def test_implicit_tool_choice_uses_prepare_tools( + self, mocker: MockerFixture, mode_choice: ToolChoiceMode + ) -> None: + """No explicit tools: prepared list and mode follow tool_choice when tools exist.""" + fs = InputToolFileSearch(vector_store_ids=["vs1"]) + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=[fs], + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + prepared, choice, vs_ids = await resolve_tool_choice( + None, + mode_choice, + "token", + ) + assert prepared == [fs] + assert choice == mode_choice + assert vs_ids is None + + @pytest.mark.asyncio + async def test_implicit_prepare_tools_returns_none_clears_tool_choice( + self, mocker: MockerFixture + ) -> None: + """When prepare_tools returns None, tool_choice is cleared.""" + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=None, + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + prepared, choice, vs_ids = await resolve_tool_choice( + None, + ToolChoiceMode.auto, + "token", + ) + assert prepared is None + assert choice is None + assert vs_ids is None + + @pytest.mark.asyncio + async def test_allowed_tools_applies_after_prepare_tools( + self, mocker: MockerFixture + ) -> None: + """No explicit tools: AllowedTools filters prepare_tools output.""" + fs = InputToolFileSearch(vector_store_ids=["vs1"]) + mcp = InputToolMCP(server_label="s1", server_url="http://x") + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=[fs, mcp], + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "mcp", "server_label": "s1"}], + ) + prepared, choice, vs_ids = await resolve_tool_choice( + None, + allowed, + "token", + ) + assert prepared is not None + assert len(prepared) == 1 + assert prepared[0].type == "mcp" + assert choice == ToolChoiceMode.auto + assert vs_ids is None + + @pytest.mark.asyncio + async def test_allowed_tools_implicit_filter_excludes_all_tools( + self, mocker: MockerFixture + ) -> None: + """Implicit tools: allowlist can remove every prepared tool.""" + mcp = InputToolMCP(server_label="s1", server_url="http://x") + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=[mcp], + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="auto", + tools=[{"type": "file_search"}], + ) + prepared, choice, vs_ids = await resolve_tool_choice(None, allowed, "token") + assert prepared is None + assert choice is None + assert vs_ids is None + + @pytest.mark.asyncio + async def test_allowed_tools_implicit_required_mode_after_prepare( + self, mocker: MockerFixture + ) -> None: + """AllowedTools with mode=required after implicit prepare_tools.""" + mcp = InputToolMCP(server_label="s1", server_url="http://x") + mocker.patch( + "utils.responses.prepare_tools", + new_callable=mocker.AsyncMock, + return_value=[mcp], + ) + mocker.patch("utils.responses.AsyncLlamaStackClientHolder.get_client") + allowed = OpenAIResponseInputToolChoiceAllowedTools( + mode="required", + tools=[{"type": "mcp"}], + ) + prepared, choice, vs_ids = await resolve_tool_choice(None, allowed, "token") + assert prepared == [mcp] + assert choice == ToolChoiceMode.required + assert vs_ids is None + + +class TestFilterToolsByAllowedEntries: + """Tests for filter_tools_by_allowed_entries (per-type matching).""" + + def test_file_search_type_only_keeps_all_file_search(self) -> None: + """One entry ``{type: file_search}`` keeps every file_search tool.""" + tools = cast( + list[InputTool], + [ + InputToolFileSearch(vector_store_ids=["a"]), + InputToolFileSearch(vector_store_ids=["b"]), + ], + ) + out = filter_tools_by_allowed_entries(tools, [{"type": "file_search"}]) + assert len(out) == 2 + + def test_empty_allowlist_keeps_nothing(self) -> None: + """Empty allowlist removes every tool.""" + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["a"])], + ) + assert not filter_tools_by_allowed_entries(tools, []) + + def test_mcp_type_only_matches_all_mcp_tools(self) -> None: + """``{type: mcp}`` keeps every MCP tool regardless of server_label.""" + tools = cast( + list[InputTool], + [ + InputToolMCP(server_label="a", server_url="http://a"), + InputToolMCP(server_label="b", server_url="http://b"), + ], + ) + out = filter_tools_by_allowed_entries(tools, [{"type": "mcp"}]) + assert len(out) == 2 + + def test_mcp_type_and_server_label_specific(self) -> None: + """Restrict to one MCP server using type + server_label.""" + tools = cast( + list[InputTool], + [ + InputToolMCP(server_label="keep", server_url="http://a"), + InputToolMCP(server_label="other", server_url="http://b"), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "keep"}], + ) + assert len(out) == 1 + assert out[0].server_label == "keep" + + def test_function_type_and_name(self) -> None: + """Function tools match on type and name.""" + tools = cast( + list[InputTool], + [ + InputToolFunction(name="fn_a", parameters={}), + InputToolFunction(name="fn_b", parameters={}), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "function", "name": "fn_b"}], + ) + assert len(out) == 1 + assert out[0].name == "fn_b" + + def test_web_search_type_literal_must_match(self) -> None: + """web_search vs web_search_preview require distinct allowlist entries.""" + tools = cast( + list[InputTool], + [ + InputToolWebSearch(type="web_search"), + InputToolWebSearch(type="web_search_preview"), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "web_search_preview"}], + ) + assert len(out) == 1 + assert out[0].type == "web_search_preview" + + def test_multiple_allowlist_entries_or_semantics(self) -> None: + """A tool is kept if it matches any allowlist entry.""" + tools = cast( + list[InputTool], + [ + InputToolFileSearch(vector_store_ids=["x"]), + InputToolMCP(server_label="m", server_url="http://m"), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [ + {"type": "function", "name": "nope"}, + {"type": "mcp", "server_label": "m"}, + ], + ) + assert len(out) == 1 + assert out[0].type == "mcp" + + def test_no_entry_matches_returns_empty(self) -> None: + """When no tool satisfies any entry, result is empty.""" + tools = cast( + list[InputTool], + [InputToolFileSearch(vector_store_ids=["a"])], + ) + assert not filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "only"}], + ) + + def test_mcp_name_grouped_by_server_narrows_allowed_tools(self) -> None: + """MCP ``name`` is grouped by ``server_label`` and applied after generic match.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="keep", + server_url="http://a", + allowed_tools=["alpha", "beta"], + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "keep", "name": "alpha"}], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["alpha"] + + def test_mcp_allowed_tools_none_projects_to_entry_names(self) -> None: + """``allowed_tools`` None permits any name; projection narrows to grouped names.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", + server_url="http://a", + allowed_tools=None, + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "s", "name": "gamma"}], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["gamma"] + + def test_mcp_server_without_name_in_allowlist_skips_projection(self) -> None: + """Any MCP entry without ``name`` for that server disables name narrowing.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", + server_url="http://a", + allowed_tools=["a", "b"], + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [ + {"type": "mcp", "server_label": "s"}, + {"type": "mcp", "server_label": "s", "name": "a"}, + ], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["a", "b"] + + def test_mcp_allowed_tools_filter_tool_names_none(self) -> None: + """AllowedToolsFilter with ``tool_names`` None does not block grouped names.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", + server_url="http://a", + allowed_tools=AllowedToolsFilter(tool_names=None), + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "s", "name": "z"}], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["z"] + + def test_mcp_allowed_tools_filter_intersects_with_grouped_names(self) -> None: + """AllowedToolsFilter with explicit ``tool_names`` intersects grouped allowlist names.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", + server_url="http://a", + allowed_tools=AllowedToolsFilter(tool_names=["alpha", "beta"]), + ), + ], + ) + out = filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "s", "name": "alpha"}], + ) + assert len(out) == 1 + assert out[0].allowed_tools == ["alpha"] + + def test_mcp_name_not_permitted_drops_tool(self) -> None: + """Empty intersection between grouped names and ``allowed_tools`` drops the tool.""" + tools = cast( + list[InputTool], + [ + InputToolMCP( + server_label="s", + server_url="http://a", + allowed_tools=["only_this"], + ), + ], + ) + assert not filter_tools_by_allowed_entries( + tools, + [{"type": "mcp", "server_label": "s", "name": "other"}], + ) + + class TestPrepareTools: """Tests for prepare_tools function."""