diff --git a/config/sql-agent.php b/config/sql-agent.php index 40e9946..5561a63 100644 --- a/config/sql-agent.php +++ b/config/sql-agent.php @@ -167,9 +167,20 @@ 'default_limit' => env('SQL_AGENT_DEFAULT_LIMIT', 100), 'chat_history_length' => env('SQL_AGENT_CHAT_HISTORY', 10), - // Custom tool class names resolved from the container, e.g.: - // [App\SqlAgent\MyCustomTool::class] - 'tools' => [], + // Tool class names resolved from the container. Includes all built-in tools + // by default. Remove entries to disable specific tools, or add your own. + 'tools' => array_merge([ + \Knobik\SqlAgent\Tools\RunSqlTool::class, + \Knobik\SqlAgent\Tools\IntrospectSchemaTool::class, + \Knobik\SqlAgent\Tools\SearchKnowledgeTool::class, + \Knobik\SqlAgent\Tools\AskUserTool::class, + ], env('SQL_AGENT_LEARNING_ENABLED', true) ? [ + \Knobik\SqlAgent\Tools\SaveLearningTool::class, + \Knobik\SqlAgent\Tools\SaveQueryTool::class, + ] : []), + + // Timeout in seconds for the ask_user tool to wait for a user reply + 'ask_user_timeout' => env('SQL_AGENT_ASK_USER_TIMEOUT', 300), // MCP server names (from config/relay.php) whose tools should be // available to the agent. Requires prism-php/relay to be installed. diff --git a/docs/src/content/docs/guides/configuration.md b/docs/src/content/docs/guides/configuration.md index bcff4ab..6361311 100644 --- a/docs/src/content/docs/guides/configuration.md +++ b/docs/src/content/docs/guides/configuration.md @@ -220,6 +220,7 @@ Control how the agentic loop operates: 'max_iterations' => env('SQL_AGENT_MAX_ITERATIONS', 10), 'default_limit' => env('SQL_AGENT_DEFAULT_LIMIT', 100), 'chat_history_length' => env('SQL_AGENT_CHAT_HISTORY', 10), + 'ask_user_timeout' => env('SQL_AGENT_ASK_USER_TIMEOUT', 300), ], ``` @@ -228,20 +229,29 @@ Control how the agentic loop operates: | `max_iterations` | Maximum number of tool-calling rounds before the agent stops | `10` | | `default_limit` | `LIMIT` applied to queries that don't specify one | `100` | | `chat_history_length` | Number of previous messages included for conversational context | `10` | +| `ask_user_timeout` | Seconds to wait for a user reply when the `ask_user` tool is invoked | `300` | ### Custom Tools -You can extend the agent with your own tools by listing class names in the `tools` array: +All agent tools — including built-in ones — are registered via the `tools` array. You can add your own tools, remove built-in ones you don't need, or reorder them: ```php 'agent' => [ // ... other options ... - 'tools' => [ - \App\SqlAgent\CurrentDateTimeTool::class, - ], + 'tools' => array_merge([ + \Knobik\SqlAgent\Tools\RunSqlTool::class, + \Knobik\SqlAgent\Tools\IntrospectSchemaTool::class, + \Knobik\SqlAgent\Tools\SearchKnowledgeTool::class, + \Knobik\SqlAgent\Tools\AskUserTool::class, + ], env('SQL_AGENT_LEARNING_ENABLED', true) ? [ + \Knobik\SqlAgent\Tools\SaveLearningTool::class, + \Knobik\SqlAgent\Tools\SaveQueryTool::class, + ] : []), ], ``` +To disable a tool, simply remove its entry from the array. For example, remove `AskUserTool::class` to prevent the LLM from asking clarifying questions. The learning tools (`SaveLearningTool`, `SaveQueryTool`) are automatically included or excluded based on the `SQL_AGENT_LEARNING_ENABLED` environment variable. + Each class must extend `Prism\Prism\Tool` and is resolved from the Laravel container with full dependency injection support. See the [Custom Tools](/sql-agent/guides/custom-tools/) guide for detailed examples and best practices. ### MCP Server Tools (Relay) diff --git a/docs/src/content/docs/reference/tools.md b/docs/src/content/docs/reference/tools.md index a2e9a07..23e5e30 100644 --- a/docs/src/content/docs/reference/tools.md +++ b/docs/src/content/docs/reference/tools.md @@ -470,6 +470,138 @@ If a query pattern with the same question already exists, the tool returns an er --- +## `ask_user` + +Ask the user a clarifying question when their request is ambiguous. The agent pauses and presents a question card in the web UI. The card can include clickable suggestion buttons with optional descriptions, supports multi-select mode, and always includes a free-text input so the user can type a custom answer. Once the user responds, the agent continues with that answer in context. + +**Description sent to LLM:** +> Ask the user a clarifying question when their request is ambiguous. Use this when you need more information before proceeding. You may provide suggested options with optional descriptions. Set multiple=true to let the user pick more than one option. The user can always type a custom free-text response instead of picking a suggestion. + +:::note +Included in the default `agent.tools` config. Remove `AskUserTool::class` from the array to disable it. In non-streaming mode (`run()`), the tool returns a fallback message instructing the LLM to make its best guess. +::: + +### Parameters + +```json +{ + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The clarifying question to ask the user." + }, + "suggestions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "label": { + "type": "string", + "description": "The short label for this suggestion (displayed on the button)." + }, + "description": { + "type": "string", + "description": "Optional longer description explaining this option." + } + }, + "required": ["label"] + }, + "description": "Optional list of suggested answers. Each has a label and optional description." + }, + "multiple": { + "type": "boolean", + "description": "Set to true to allow the user to select multiple suggestions. Defaults to false." + } + }, + "required": ["question"] +} +``` + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `question` | string | Yes | — | The clarifying question to display. | +| `suggestions` | object[] | No | `[]` | Suggested answers, each with a `label` (string, required) and `description` (string, optional). | +| `multiple` | boolean | No | `false` | When `true`, the user can select multiple suggestions before submitting. | + +### Return Value + +The tool returns a plain string to the LLM: + +- `"User answered: "` — when the user clicks a suggestion, submits a multi-selection, or types a custom answer. +- A fallback message when the user doesn't respond in time or the connection is lost. + +For multi-select, the answer is a comma-separated list of selected labels (e.g., `"User answered: Revenue trends, Customer segments"`). + +### How It Works + +1. The LLM calls `ask_user` with a question, optional suggestions, and an optional `multiple` flag. +2. An `ask_user` SSE event is sent to the frontend with the question, suggestions, multiple flag, and a unique request ID. +3. The frontend renders a question card: + - **Single-select** (default): clicking a suggestion immediately submits it. + - **Multi-select** (`multiple: true`): suggestions have checkboxes that toggle on/off. A "Submit selection" button sends all selected labels. + - Suggestions with a `description` show the description text below the label. + - A free-text input is always available at the bottom. +4. The user's answer POSTs to `/ask-user-reply` and writes to the cache. +5. The tool polls the cache, picks up the answer, and returns it to the LLM. +6. The LLM continues with the user's answer in context. + +### Configuration + +The timeout for waiting on a user reply is configurable: + +```php +'agent' => [ + 'ask_user_timeout' => env('SQL_AGENT_ASK_USER_TIMEOUT', 300), // seconds +], +``` + +### Example Tool Calls + +**Simple single-select:** + +```json +{ + "question": "Which time period are you interested in?", + "suggestions": [ + { "label": "Last 7 days" }, + { "label": "Last 30 days" }, + { "label": "Last quarter" }, + { "label": "All time" } + ] +} +``` + +**With descriptions:** + +```json +{ + "question": "What kind of analysis would you like?", + "suggestions": [ + { "label": "Revenue trends", "description": "Monthly revenue breakdown with growth rates" }, + { "label": "Customer segments", "description": "RFM analysis grouping customers by behavior" }, + { "label": "Product performance", "description": "Sales volume and margin by product category" } + ] +} +``` + +**Multi-select:** + +```json +{ + "question": "Which metrics should I include in the report?", + "suggestions": [ + { "label": "Total revenue", "description": "Sum of all completed orders" }, + { "label": "Order count", "description": "Number of orders placed" }, + { "label": "Average order value" }, + { "label": "Customer count" } + ], + "multiple": true +} +``` + +--- + ## Tool Availability Not all tools are available in every configuration: @@ -481,7 +613,8 @@ Not all tools are available in every configuration: | `search_knowledge` | Yes | — | | `save_learning` | No | Requires `sql-agent.learning.enabled = true` | | `save_validated_query` | No | Requires `sql-agent.learning.enabled = true` | +| `ask_user` | Yes | Remove from `agent.tools` to disable | -When learning is disabled (`SQL_AGENT_LEARNING_ENABLED=false`), the `save_learning` and `save_validated_query` tools are not registered with the LLM, and the related instructions are removed from the system prompt. +All tools are registered via the `agent.tools` config array. Remove any entry to disable that tool. When learning is disabled (`SQL_AGENT_LEARNING_ENABLED=false`), the `save_learning` and `save_validated_query` tools are automatically skipped even if present in the array. -In addition to the built-in tools above, you can register your own tools via the `agent.tools` config option. See the [Custom Tools](/sql-agent/guides/custom-tools/) guide for details. +You can also register your own tools by adding class names to the `agent.tools` array. See the [Custom Tools](/sql-agent/guides/custom-tools/) guide for details. diff --git a/resources/prompts/system.blade.php b/resources/prompts/system.blade.php index 0699892..d67dfae 100644 --- a/resources/prompts/system.blade.php +++ b/resources/prompts/system.blade.php @@ -31,31 +31,11 @@ You cannot JOIN across databases. Run separate queries and combine results in your response. ## Available Tools - -### run_sql -Execute a SQL query. Specify the `connection` parameter to choose which database to query. Only {{ implode(' and ', config('sql-agent.sql.allowed_statements', ['SELECT', 'WITH'])) }} statements allowed. - -### introspect_schema -Get detailed schema information about tables, columns, relationships, and data types. Specify the `connection` parameter to choose which database to inspect. - -### search_knowledge -Search for relevant query patterns, learnings, and past discoveries about the database. - -@if(config('sql-agent.learning.enabled', true)) -### save_learning -Save a discovery to the knowledge base (type errors, date formats, column quirks, business logic). - -### save_validated_query -Save a successful query pattern for reuse. Use when a query correctly answers a common question. -@endif -@if(!empty($customTools)) -{{-- Custom tools registered via config('sql-agent.agent.tools') --}} -@foreach($customTools as $tool) +@foreach($tools as $tool) ### {{ $tool->name() }} {{ $tool->description() }} @endforeach -@endif ## Workflow diff --git a/resources/views/components/layouts/app.blade.php b/resources/views/components/layouts/app.blade.php index 6c17bd2..112cc12 100644 --- a/resources/views/components/layouts/app.blade.php +++ b/resources/views/components/layouts/app.blade.php @@ -311,6 +311,22 @@ .markdown-content tool[data-type="default"]::before { background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 24 24' stroke='%236b7280'%3E%3Cpath stroke-linecap='round' stroke-linejoin='round' stroke-width='2' d='M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z'/%3E%3Cpath stroke-linecap='round' stroke-linejoin='round' stroke-width='2' d='M15 12a3 3 0 11-6 0 3 3 0 016 0z'/%3E%3C/svg%3E"); } + .markdown-content tool[data-type="ask"] { + background: #fffbeb; + color: #b45309; + border-color: #fde68a; + } + .markdown-content tool[data-type="ask"]::before { + background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 24 24' stroke='%23b45309'%3E%3Cpath stroke-linecap='round' stroke-linejoin='round' stroke-width='2' d='M8.228 9c.549-1.165 2.03-2 3.772-2 2.21 0 4 1.343 4 3 0 1.4-1.278 2.575-3.006 2.907-.542.104-.994.54-.994 1.093m0 3h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z'/%3E%3C/svg%3E"); + } + .dark .markdown-content tool[data-type="ask"] { + background: rgba(120, 53, 15, 0.3); + color: #fbbf24; + border-color: rgba(251, 191, 36, 0.3); + } + .dark .markdown-content tool[data-type="ask"]::before { + background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 24 24' stroke='%23fbbf24'%3E%3Cpath stroke-linecap='round' stroke-linejoin='round' stroke-width='2' d='M8.228 9c.549-1.165 2.03-2 3.772-2 2.21 0 4 1.343 4 3 0 1.4-1.278 2.575-3.006 2.907-.542.104-.994.54-.994 1.093m0 3h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z'/%3E%3C/svg%3E"); + } .dark .markdown-content tool[data-type="sql"]::before { background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 24 24' stroke='%23fca5a5'%3E%3Cpath stroke-linecap='round' stroke-linejoin='round' stroke-width='2' d='M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z'/%3E%3C/svg%3E"); } diff --git a/resources/views/livewire/chat-component.blade.php b/resources/views/livewire/chat-component.blade.php index 873c666..44ceb3e 100644 --- a/resources/views/livewire/chat-component.blade.php +++ b/resources/views/livewire/chat-component.blade.php @@ -164,6 +164,91 @@ class="group w-full p-4 text-left bg-white dark:bg-gray-800 hover:bg-gray-50 dar + + {{-- Ask User Question Card --}} + {{-- Input Area --}} @@ -242,6 +327,9 @@ function chatStream() { pendingMessageId: null, conversationId: @json($conversationId), abortController: null, + askUserData: null, + askUserInput: '', + askUserSelected: [], // Show streaming UI while streaming or finishing get showStreamingUI() { @@ -402,6 +490,43 @@ function chatStream() { this.isFinishing = false; this.streamedContent = ''; this.pendingUserMessage = ''; + this.askUserData = null; + this.askUserInput = ''; + this.askUserSelected = []; + }, + + toggleAskUserSelection(label) { + const idx = this.askUserSelected.indexOf(label); + if (idx === -1) { + this.askUserSelected.push(label); + } else { + this.askUserSelected.splice(idx, 1); + } + }, + + async submitAskUserReply(answer) { + if (!this.askUserData || !answer.trim()) return; + + const requestId = this.askUserData.request_id; + this.askUserData = null; + this.askUserInput = ''; + this.askUserSelected = []; + + try { + await fetch('{{ route("sql-agent.ask-user-reply") }}', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-CSRF-TOKEN': '{{ csrf_token() }}', + }, + body: JSON.stringify({ + request_id: requestId, + answer: answer, + }), + }); + } catch (error) { + console.error('Failed to submit ask-user reply:', error); + } }, handleEvent(data) { @@ -417,6 +542,10 @@ function chatStream() { this.streamedContent += data.text; this.renderContent(); this.scrollToBottom(); + } else if (data.request_id !== undefined) { + // Ask-user event + this.askUserData = data; + this.scrollToBottom(); } else if (data.message !== undefined) { // Error event this.streamedContent = 'Error: ' + data.message; diff --git a/routes/web.php b/routes/web.php index 533a903..de296cf 100644 --- a/routes/web.php +++ b/routes/web.php @@ -1,6 +1,7 @@ name('query.execute'); + + // Ask-user reply endpoint + Route::post('/ask-user-reply', AskUserReplyController::class)->name('ask-user-reply'); }); } diff --git a/src/Agent/SqlAgent.php b/src/Agent/SqlAgent.php index 732b160..103d653 100644 --- a/src/Agent/SqlAgent.php +++ b/src/Agent/SqlAgent.php @@ -10,6 +10,7 @@ use Knobik\SqlAgent\Llm\StreamChunk; use Knobik\SqlAgent\Services\ConnectionRegistry; use Knobik\SqlAgent\Services\ContextBuilder; +use Knobik\SqlAgent\Tools\AskUserTool; use Knobik\SqlAgent\Tools\RunSqlTool; use Prism\Prism\Facades\Prism; use Prism\Prism\Streaming\Events\StepFinishEvent; @@ -305,7 +306,7 @@ protected function prepareLoop(string $question, array $history = []): AgentLoop $context = $this->contextBuilder->build($question); $extra = [ - 'customTools' => $this->getCustomTools(), + 'tools' => $this->getRegisteredTools(), 'multiConnection' => true, 'connections' => $this->connectionRegistry->all(), ]; @@ -335,7 +336,7 @@ protected function reset(): void // Reset tool state foreach ($this->toolRegistry->all() as $tool) { - if ($tool instanceof RunSqlTool) { + if ($tool instanceof RunSqlTool || $tool instanceof AskUserTool) { $tool->reset(); } } @@ -358,16 +359,13 @@ protected function prepareTools(?string $question = null): array } /** - * Get custom (non-built-in) tools from the registry. + * Get all registered tools for inclusion in the system prompt. * * @return Tool[] */ - protected function getCustomTools(): array + protected function getRegisteredTools(): array { - return array_values(array_filter( - $this->toolRegistry->all(), - fn (Tool $tool): bool => ! str_starts_with($tool::class, 'Knobik\\SqlAgent\\Tools\\'), - )); + return $this->toolRegistry->all(); } protected function collectToolCalls(): array diff --git a/src/Agent/ToolLabelResolver.php b/src/Agent/ToolLabelResolver.php index 11731a0..1782cff 100644 --- a/src/Agent/ToolLabelResolver.php +++ b/src/Agent/ToolLabelResolver.php @@ -14,6 +14,7 @@ class ToolLabelResolver 'search_knowledge' => 'Searching knowledge base', 'save_learning' => 'Saving learning', 'save_validated_query' => 'Saving query pattern', + 'ask_user' => 'Asking for clarification', ]; protected const TYPES = [ @@ -22,6 +23,7 @@ class ToolLabelResolver 'search_knowledge' => 'search', 'save_learning' => 'save', 'save_validated_query' => 'save', + 'ask_user' => 'ask', ]; public function getLabel(string $toolName): string diff --git a/src/Http/Actions/StreamAgentResponse.php b/src/Http/Actions/StreamAgentResponse.php index c94ea29..db91945 100644 --- a/src/Http/Actions/StreamAgentResponse.php +++ b/src/Http/Actions/StreamAgentResponse.php @@ -5,15 +5,19 @@ namespace Knobik\SqlAgent\Http\Actions; use Illuminate\Support\Facades\Log; +use Illuminate\Support\Str; use Knobik\SqlAgent\Agent\SqlAgent; +use Knobik\SqlAgent\Agent\ToolRegistry; use Knobik\SqlAgent\Enums\MessageRole; use Knobik\SqlAgent\Services\ConversationService; +use Knobik\SqlAgent\Tools\AskUserTool; class StreamAgentResponse { public function __construct( protected SqlAgent $agent, protected ConversationService $conversationService, + protected ToolRegistry $toolRegistry, ) {} public function __invoke(string $question, int $conversationId): void @@ -24,6 +28,8 @@ public function __invoke(string $question, int $conversationId): void $this->sendEvent('conversation', ['id' => $conversationId]); + $this->configureAskUserTool(); + $history = $this->conversationService->getHistory($conversationId); $fullContent = ''; @@ -128,6 +134,22 @@ protected function persistAndFinish( $this->sendEvent('done', $donePayload); } + protected function configureAskUserTool(): void + { + if (! $this->toolRegistry->has('ask_user')) { + return; + } + + $tool = $this->toolRegistry->get('ask_user'); + + if (! $tool instanceof AskUserTool) { + return; + } + + $tool->setRequestId(Str::uuid()->toString()); + $tool->setSendCallback(fn (array $data) => $this->sendEvent('ask_user', $data)); + } + protected function sendEvent(string $event, array $data): void { echo "event: {$event}\n"; diff --git a/src/Http/Controllers/AskUserReplyController.php b/src/Http/Controllers/AskUserReplyController.php new file mode 100644 index 0000000..abfc2e5 --- /dev/null +++ b/src/Http/Controllers/AskUserReplyController.php @@ -0,0 +1,20 @@ +getRequestId(), $request->getAnswer(), now()->addMinutes(10)); + + return response()->json(['status' => 'ok']); + } +} diff --git a/src/Http/Requests/AskUserReplyRequest.php b/src/Http/Requests/AskUserReplyRequest.php new file mode 100644 index 0000000..49f39d8 --- /dev/null +++ b/src/Http/Requests/AskUserReplyRequest.php @@ -0,0 +1,36 @@ + + */ + public function rules(): array + { + return [ + 'request_id' => 'required|string', + 'answer' => 'required|string|max:1000', + ]; + } + + public function getRequestId(): string + { + return $this->input('request_id'); + } + + public function getAnswer(): string + { + return $this->input('answer'); + } +} diff --git a/src/SqlAgentServiceProvider.php b/src/SqlAgentServiceProvider.php index c0ec34c..747454e 100644 --- a/src/SqlAgentServiceProvider.php +++ b/src/SqlAgentServiceProvider.php @@ -19,12 +19,6 @@ use Knobik\SqlAgent\Models\QueryPattern; use Knobik\SqlAgent\Search\SearchManager; use Knobik\SqlAgent\Services\ConnectionRegistry; -use Knobik\SqlAgent\Services\SchemaIntrospector; -use Knobik\SqlAgent\Tools\IntrospectSchemaTool; -use Knobik\SqlAgent\Tools\RunSqlTool; -use Knobik\SqlAgent\Tools\SaveLearningTool; -use Knobik\SqlAgent\Tools\SaveQueryTool; -use Knobik\SqlAgent\Tools\SearchKnowledgeTool; use Prism\Prism\Tool; class SqlAgentServiceProvider extends ServiceProvider @@ -46,30 +40,16 @@ public function register(): void $this->app->singleton(ToolRegistry::class, function ($app) { $registry = new ToolRegistry; - // Register built-in tools - $registry->registerMany([ - new RunSqlTool, - new IntrospectSchemaTool($app->make(SchemaIntrospector::class)), - new SearchKnowledgeTool($app->make(SearchManager::class)), - ]); - - if (config('sql-agent.learning.enabled')) { - $registry->registerMany([ - new SaveLearningTool, - new SaveQueryTool, - ]); - } - - // Register custom tools from config + // Register tools from config (includes built-in and user-added tools) foreach (config('sql-agent.agent.tools') as $toolClass) { if (! class_exists($toolClass)) { - throw new \InvalidArgumentException("Custom tool class [{$toolClass}] does not exist."); + throw new \InvalidArgumentException("Tool class [{$toolClass}] does not exist."); } $tool = $app->make($toolClass); if (! $tool instanceof Tool) { - throw new \InvalidArgumentException("Custom tool class [{$toolClass}] must extend ".Tool::class.'.'); + throw new \InvalidArgumentException("Tool class [{$toolClass}] must extend ".Tool::class.'.'); } $registry->register($tool); diff --git a/src/Tools/AskUserTool.php b/src/Tools/AskUserTool.php new file mode 100644 index 0000000..3c111a1 --- /dev/null +++ b/src/Tools/AskUserTool.php @@ -0,0 +1,137 @@ +as('ask_user') + ->for('Ask the user a clarifying question when their request is ambiguous. Use this when you need more information before proceeding, such as which time period, which metric, or which entity the user is referring to. You may provide suggested options with optional descriptions. Set multiple=true to let the user pick more than one option. The user can always type a custom free-text response instead of picking a suggestion.') + ->withStringParameter('question', 'The clarifying question to ask the user.') + ->withArrayParameter( + 'suggestions', + 'Optional list of suggested answers. Each suggestion has a label (shown on the button) and an optional description (shown below the label to provide context).', + new ObjectSchema( + 'suggestion', + 'A suggested answer option.', + [ + new StringSchema('label', 'The short label for this suggestion (displayed on the button).'), + new StringSchema('description', 'Optional longer description explaining this option.'), + ], + requiredFields: ['label'], + ), + required: false, + ) + ->withBooleanParameter('multiple', 'Set to true to allow the user to select multiple suggestions. Defaults to false (single-select).', required: false) + ->using($this); + } + + public function __invoke(string $question, ?array $suggestions = null, ?bool $multiple = null): string + { + $parsedSuggestions = $this->parseSuggestions($suggestions); + + if ($this->sendCallback === null) { + return 'User interaction is not available in non-streaming mode. Please make your best guess based on the available context and proceed.'; + } + + $cacheKey = "sql-agent:ask-user:{$this->requestId}:{$this->invocationCounter}"; + $this->invocationCounter++; + + ($this->sendCallback)([ + 'question' => $question, + 'suggestions' => $parsedSuggestions, + 'multiple' => $multiple ?? false, + 'request_id' => $cacheKey, + ]); + + $timeout = config('sql-agent.agent.ask_user_timeout'); + $elapsed = 0; + $pollInterval = 500_000; // 500ms in microseconds + + while ($elapsed < $timeout * 1_000_000) { + $answer = Cache::get($cacheKey); + + if ($answer !== null) { + Cache::forget($cacheKey); + + return "User answered: {$answer}"; + } + + if (connection_aborted()) { + return 'The user disconnected before answering. Please make your best guess based on the available context and proceed.'; + } + + usleep($pollInterval); + $elapsed += $pollInterval; + } + + return 'The user did not respond in time. Please make your best guess based on the available context and proceed.'; + } + + public function setSendCallback(?Closure $callback): self + { + $this->sendCallback = $callback; + + return $this; + } + + public function setRequestId(string $requestId): self + { + $this->requestId = $requestId; + + return $this; + } + + public function reset(): void + { + $this->invocationCounter = 0; + } + + /** + * @param array|null $suggestions + * @return array + */ + protected function parseSuggestions(?array $suggestions): array + { + if ($suggestions === null || empty($suggestions)) { + return []; + } + + $parsed = []; + + foreach ($suggestions as $suggestion) { + if (is_string($suggestion)) { + $trimmed = trim($suggestion); + if ($trimmed !== '') { + $parsed[] = ['label' => $trimmed, 'description' => null]; + } + } elseif (is_array($suggestion) && isset($suggestion['label']) && is_string($suggestion['label'])) { + $label = trim($suggestion['label']); + if ($label !== '') { + $description = isset($suggestion['description']) && is_string($suggestion['description']) + ? trim($suggestion['description']) + : null; + $parsed[] = ['label' => $label, 'description' => $description ?: null]; + } + } + } + + return $parsed; + } +} diff --git a/src/Tools/IntrospectSchemaTool.php b/src/Tools/IntrospectSchemaTool.php index 41daadc..c798962 100644 --- a/src/Tools/IntrospectSchemaTool.php +++ b/src/Tools/IntrospectSchemaTool.php @@ -26,7 +26,7 @@ public function __construct( $this ->as('introspect_schema') - ->for('Get detailed schema information about database tables. Can inspect a specific table or list all available tables.') + ->for('Get detailed schema information about tables, columns, relationships, and data types.') ->withStringParameter('table_name', 'Optional: The name of a specific table to inspect. If not provided, lists all tables.', required: false) ->withBooleanParameter('include_sample_data', 'Whether to include sample data from the table (up to 3 rows). This data is for understanding the schema only - never use it directly in responses to the user.', required: false) ->withEnumParameter( diff --git a/src/Tools/RunSqlTool.php b/src/Tools/RunSqlTool.php index bcaad2a..af15eed 100644 --- a/src/Tools/RunSqlTool.php +++ b/src/Tools/RunSqlTool.php @@ -35,7 +35,7 @@ public function __construct() $this ->as('run_sql') - ->for("Execute a SQL query against the database. Only {$allowed} statements are allowed. Returns query results as JSON.") + ->for("Execute a SQL query. Only {$allowed} statements allowed.") ->withStringParameter('sql', "The SQL query to execute. Must be a {$allowed} statement.") ->withEnumParameter( 'connection', diff --git a/src/Tools/SaveLearningTool.php b/src/Tools/SaveLearningTool.php index 03e9eac..69348c8 100644 --- a/src/Tools/SaveLearningTool.php +++ b/src/Tools/SaveLearningTool.php @@ -20,7 +20,7 @@ public function __construct() $this ->as('save_learning') - ->for('Save a new learning to the knowledge base. Use this when you discover something important about the database schema, business logic, or query patterns that would be useful for future queries.') + ->for('Save a discovery to the knowledge base (type errors, date formats, column quirks, business logic).') ->withStringParameter('title', 'A short, descriptive title for the learning (max 100 characters).') ->withStringParameter('description', 'A detailed description of what was learned and why it matters.') ->withEnumParameter('category', 'The category of this learning.', $categories) diff --git a/src/Tools/SaveQueryTool.php b/src/Tools/SaveQueryTool.php index bc237ac..aa873a3 100644 --- a/src/Tools/SaveQueryTool.php +++ b/src/Tools/SaveQueryTool.php @@ -15,7 +15,7 @@ public function __construct() { $this ->as('save_validated_query') - ->for('Save a validated query pattern to the knowledge base. Use this when you have successfully executed a SQL query that correctly answers a user question. This helps future queries by providing proven patterns.') + ->for('Save a successful query pattern for reuse. Use when a query correctly answers a common question.') ->withStringParameter('name', 'A short, descriptive name for the query pattern (max 100 characters).') ->withStringParameter('question', 'The natural language question this query answers.') ->withStringParameter('sql', 'The validated SQL query that correctly answers the question.') diff --git a/src/Tools/SearchKnowledgeTool.php b/src/Tools/SearchKnowledgeTool.php index 5824373..02ffc42 100644 --- a/src/Tools/SearchKnowledgeTool.php +++ b/src/Tools/SearchKnowledgeTool.php @@ -20,7 +20,7 @@ public function __construct( $this ->as('search_knowledge') - ->for('Search the knowledge base for relevant query patterns and learnings. Use this to find similar queries, understand business logic, or discover past learnings about the database.') + ->for('Search for relevant query patterns, learnings, and past discoveries about the database.') ->withStringParameter('query', 'The search query to find relevant knowledge.') ->withEnumParameter('type', "Filter results by index: 'all' (default) searches everything, or specify a specific index name.", $enumValues, required: false) ->withNumberParameter('limit', 'Maximum number of results to return.', required: false) diff --git a/tests/Feature/Http/AskUserReplyTest.php b/tests/Feature/Http/AskUserReplyTest.php new file mode 100644 index 0000000..1bae769 --- /dev/null +++ b/tests/Feature/Http/AskUserReplyTest.php @@ -0,0 +1,68 @@ +artisan('migrate'); + $this->user = Helpers::createAuthenticatedUser(); +}); + +describe('AskUserReplyController', function () { + it('writes answer to cache with valid data', function () { + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.ask-user-reply'), [ + 'request_id' => 'sql-agent:ask-user:test-uuid:0', + 'answer' => 'Last month', + ]); + + $response->assertOk(); + $response->assertJson(['status' => 'ok']); + + expect(Cache::get('sql-agent:ask-user:test-uuid:0'))->toBe('Last month'); + }); + + it('returns 422 when request_id is missing', function () { + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.ask-user-reply'), [ + 'answer' => 'Last month', + ]); + + $response->assertUnprocessable(); + $response->assertJsonValidationErrors('request_id'); + }); + + it('returns 422 when answer is missing', function () { + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.ask-user-reply'), [ + 'request_id' => 'sql-agent:ask-user:test-uuid:0', + ]); + + $response->assertUnprocessable(); + $response->assertJsonValidationErrors('answer'); + }); + + it('returns 422 when answer exceeds max length', function () { + $response = $this->actingAs($this->user) + ->postJson(route('sql-agent.ask-user-reply'), [ + 'request_id' => 'sql-agent:ask-user:test-uuid:0', + 'answer' => str_repeat('x', 1001), + ]); + + $response->assertUnprocessable(); + $response->assertJsonValidationErrors('answer'); + }); + + it('uses configured middleware', function () { + $route = app('router')->getRoutes()->getByName('sql-agent.ask-user-reply'); + + expect($route)->not->toBeNull(); + + $middleware = $route->middleware(); + + expect($middleware)->toContain('web'); + }); +}); diff --git a/tests/Unit/Agent/ToolLabelResolverTest.php b/tests/Unit/Agent/ToolLabelResolverTest.php index 3a0930b..49a7f2f 100644 --- a/tests/Unit/Agent/ToolLabelResolverTest.php +++ b/tests/Unit/Agent/ToolLabelResolverTest.php @@ -28,6 +28,10 @@ expect($this->resolver->getLabel('save_validated_query'))->toBe('Saving query pattern'); }); + it('returns label for ask_user', function () { + expect($this->resolver->getLabel('ask_user'))->toBe('Asking for clarification'); + }); + it('returns tool name for unknown tools', function () { expect($this->resolver->getLabel('custom_tool'))->toBe('custom_tool'); }); @@ -50,6 +54,10 @@ expect($this->resolver->getType('save_learning'))->toBe('save'); }); + it('returns ask type for ask_user', function () { + expect($this->resolver->getType('ask_user'))->toBe('ask'); + }); + it('returns default type for unknown tools', function () { expect($this->resolver->getType('unknown'))->toBe('default'); }); diff --git a/tests/Unit/CustomToolRegistrationTest.php b/tests/Unit/CustomToolRegistrationTest.php index 9abf1ab..0c44098 100644 --- a/tests/Unit/CustomToolRegistrationTest.php +++ b/tests/Unit/CustomToolRegistrationTest.php @@ -5,7 +5,8 @@ describe('Custom Tool Registration', function () { it('registers custom tools from config', function () { - config()->set('sql-agent.agent.tools', [FakeCustomTool::class]); + $tools = array_merge(config('sql-agent.agent.tools'), [FakeCustomTool::class]); + config()->set('sql-agent.agent.tools', $tools); $registry = app(ToolRegistry::class); @@ -14,7 +15,8 @@ }); it('custom tools appear alongside built-in tools', function () { - config()->set('sql-agent.agent.tools', [FakeCustomTool::class]); + $tools = array_merge(config('sql-agent.agent.tools'), [FakeCustomTool::class]); + config()->set('sql-agent.agent.tools', $tools); $registry = app(ToolRegistry::class); @@ -24,16 +26,18 @@ expect($registry->has('save_learning'))->toBeTrue(); expect($registry->has('save_validated_query'))->toBeTrue(); expect($registry->has('search_knowledge'))->toBeTrue(); + expect($registry->has('ask_user'))->toBeTrue(); // Custom tool should also be present expect($registry->has('fake_custom'))->toBeTrue(); - expect($registry->all())->toHaveCount(6); + expect($registry->all())->toHaveCount(7); }); it('resolves custom tools with constructor dependencies from the container', function () { $this->app->bind(FakeDependency::class, fn () => new FakeDependency('injected')); - config()->set('sql-agent.agent.tools', [FakeCustomToolWithDependency::class]); + $tools = array_merge(config('sql-agent.agent.tools'), [FakeCustomToolWithDependency::class]); + config()->set('sql-agent.agent.tools', $tools); $registry = app(ToolRegistry::class); @@ -44,13 +48,27 @@ config()->set('sql-agent.agent.tools', ['App\\NonExistent\\ToolClass']); app(ToolRegistry::class); - })->throws(InvalidArgumentException::class, 'Custom tool class [App\\NonExistent\\ToolClass] does not exist.'); + })->throws(InvalidArgumentException::class, 'Tool class [App\\NonExistent\\ToolClass] does not exist.'); it('throws exception for tool class that does not extend Tool', function () { config()->set('sql-agent.agent.tools', [NotATool::class]); app(ToolRegistry::class); })->throws(InvalidArgumentException::class, 'must extend'); + + it('allows disabling built-in tools by removing from config', function () { + // Remove AskUserTool from the tools array + $tools = array_filter( + config('sql-agent.agent.tools'), + fn ($t) => $t !== \Knobik\SqlAgent\Tools\AskUserTool::class, + ); + config()->set('sql-agent.agent.tools', $tools); + + $registry = app(ToolRegistry::class); + + expect($registry->has('ask_user'))->toBeFalse(); + expect($registry->has('run_sql'))->toBeTrue(); + }); }); // Test fixtures diff --git a/tests/Unit/RelayToolRegistrationTest.php b/tests/Unit/RelayToolRegistrationTest.php index b862eec..339e67d 100644 --- a/tests/Unit/RelayToolRegistrationTest.php +++ b/tests/Unit/RelayToolRegistrationTest.php @@ -46,13 +46,14 @@ ->andReturn([$relayTool]); Relay::swap($mock); - config()->set('sql-agent.agent.tools', [FakeCustomTool::class]); + $tools = array_merge(config('sql-agent.agent.tools'), [FakeCustomTool::class]); + config()->set('sql-agent.agent.tools', $tools); config()->set('sql-agent.agent.relay', ['calc-server']); $registry = app(ToolRegistry::class); - // Built-in (5) + custom (1) + relay (1) = 7 - expect($registry->all())->toHaveCount(7); + // Default tools (6) + custom (1) + relay (1) = 8 + expect($registry->all())->toHaveCount(8); expect($registry->has('run_sql'))->toBeTrue(); expect($registry->has('fake_custom'))->toBeTrue(); expect($registry->has('mcp_calculator'))->toBeTrue(); @@ -97,8 +98,8 @@ $registry = app(ToolRegistry::class); - // Only built-in tools - expect($registry->all())->toHaveCount(5); + // Built-in tools + default custom tools (AskUserTool) + expect($registry->all())->toHaveCount(6); }); }); diff --git a/tests/Unit/Tools/AskUserToolTest.php b/tests/Unit/Tools/AskUserToolTest.php new file mode 100644 index 0000000..7b27cb2 --- /dev/null +++ b/tests/Unit/Tools/AskUserToolTest.php @@ -0,0 +1,285 @@ +tool = new AskUserTool; +}); + +describe('AskUserTool registration', function () { + it('has the correct tool name', function () { + expect($this->tool->name())->toBe('ask_user'); + }); + + it('has question parameter', function () { + $params = $this->tool->parametersAsArray(); + + expect($params)->toHaveKey('question'); + }); + + it('has optional suggestions parameter', function () { + $params = $this->tool->parametersAsArray(); + + expect($params)->toHaveKey('suggestions'); + }); + + it('has optional multiple parameter', function () { + $params = $this->tool->parametersAsArray(); + + expect($params)->toHaveKey('multiple'); + }); + + it('has a description', function () { + expect($this->tool->description())->not->toBeEmpty(); + }); +}); + +describe('non-interactive mode', function () { + it('returns fallback when no callback is set', function () { + $result = ($this->tool)('What period?', [['label' => 'Last week'], ['label' => 'Last month']]); + + expect($result)->toContain('not available'); + expect($result)->toContain('best guess'); + }); +}); + +describe('suggestion parsing', function () { + it('handles null suggestions', function () { + $reflection = new ReflectionMethod($this->tool, 'parseSuggestions'); + $reflection->setAccessible(true); + + $result = $reflection->invoke($this->tool, null); + + expect($result)->toBe([]); + }); + + it('handles empty array suggestions', function () { + $reflection = new ReflectionMethod($this->tool, 'parseSuggestions'); + $reflection->setAccessible(true); + + $result = $reflection->invoke($this->tool, []); + + expect($result)->toBe([]); + }); + + it('parses object suggestions with label and description', function () { + $reflection = new ReflectionMethod($this->tool, 'parseSuggestions'); + $reflection->setAccessible(true); + + $result = $reflection->invoke($this->tool, [ + ['label' => 'Revenue trends', 'description' => 'Monthly revenue over time'], + ['label' => 'Customer segments'], + ]); + + expect($result)->toBe([ + ['label' => 'Revenue trends', 'description' => 'Monthly revenue over time'], + ['label' => 'Customer segments', 'description' => null], + ]); + }); + + it('trims labels and descriptions', function () { + $reflection = new ReflectionMethod($this->tool, 'parseSuggestions'); + $reflection->setAccessible(true); + + $result = $reflection->invoke($this->tool, [ + ['label' => ' Revenue ', 'description' => ' Monthly revenue '], + ]); + + expect($result)->toBe([ + ['label' => 'Revenue', 'description' => 'Monthly revenue'], + ]); + }); + + it('filters out entries with empty labels', function () { + $reflection = new ReflectionMethod($this->tool, 'parseSuggestions'); + $reflection->setAccessible(true); + + $result = $reflection->invoke($this->tool, [ + ['label' => ''], + ['label' => ' '], + ['label' => 'Valid'], + ]); + + expect($result)->toBe([ + ['label' => 'Valid', 'description' => null], + ]); + }); + + it('handles plain string suggestions as backward compatibility', function () { + $reflection = new ReflectionMethod($this->tool, 'parseSuggestions'); + $reflection->setAccessible(true); + + $result = $reflection->invoke($this->tool, ['First', 'Second', 'Third']); + + expect($result)->toBe([ + ['label' => 'First', 'description' => null], + ['label' => 'Second', 'description' => null], + ['label' => 'Third', 'description' => null], + ]); + }); + + it('filters non-string and non-array values', function () { + $reflection = new ReflectionMethod($this->tool, 'parseSuggestions'); + $reflection->setAccessible(true); + + $result = $reflection->invoke($this->tool, ['Valid', 123, null, ['label' => 'Also valid']]); + + expect($result)->toBe([ + ['label' => 'Valid', 'description' => null], + ['label' => 'Also valid', 'description' => null], + ]); + }); + + it('filters entries without a label key', function () { + $reflection = new ReflectionMethod($this->tool, 'parseSuggestions'); + $reflection->setAccessible(true); + + $result = $reflection->invoke($this->tool, [ + ['description' => 'No label here'], + ['label' => 'Has label'], + ]); + + expect($result)->toBe([ + ['label' => 'Has label', 'description' => null], + ]); + }); +}); + +describe('cache polling', function () { + it('returns answer when found in cache with suggestions', function () { + $sentData = null; + $this->tool->setRequestId('test-uuid'); + $this->tool->setSendCallback(function (array $data) use (&$sentData) { + $sentData = $data; + Cache::put($data['request_id'], 'Last month', now()->addMinutes(10)); + }); + + $result = ($this->tool)('What period?', [['label' => 'Last week'], ['label' => 'Last month']]); + + expect($result)->toBe('User answered: Last month'); + expect($sentData)->not->toBeNull(); + expect($sentData['question'])->toBe('What period?'); + expect($sentData['suggestions'])->toBe([ + ['label' => 'Last week', 'description' => null], + ['label' => 'Last month', 'description' => null], + ]); + expect($sentData['multiple'])->toBeFalse(); + expect($sentData['request_id'])->toStartWith('sql-agent:ask-user:test-uuid:'); + }); + + it('sends multiple flag when set to true', function () { + $sentData = null; + $this->tool->setRequestId('test-uuid'); + $this->tool->setSendCallback(function (array $data) use (&$sentData) { + $sentData = $data; + Cache::put($data['request_id'], 'A, B', now()->addMinutes(10)); + }); + + ($this->tool)('Pick items', [['label' => 'A'], ['label' => 'B']], true); + + expect($sentData['multiple'])->toBeTrue(); + }); + + it('returns answer when found in cache without suggestions', function () { + $sentData = null; + $this->tool->setRequestId('test-uuid'); + $this->tool->setSendCallback(function (array $data) use (&$sentData) { + $sentData = $data; + Cache::put($data['request_id'], 'Custom answer from user', now()->addMinutes(10)); + }); + + $result = ($this->tool)('What do you want to know?'); + + expect($result)->toBe('User answered: Custom answer from user'); + expect($sentData['suggestions'])->toBe([]); + expect($sentData['multiple'])->toBeFalse(); + }); + + it('cleans up cache key after reading', function () { + $cacheKey = null; + $this->tool->setRequestId('test-uuid'); + $this->tool->setSendCallback(function (array $data) use (&$cacheKey) { + $cacheKey = $data['request_id']; + Cache::put($data['request_id'], 'Yes', now()->addMinutes(10)); + }); + + ($this->tool)('Continue?', [['label' => 'Yes'], ['label' => 'No']]); + + expect(Cache::get($cacheKey))->toBeNull(); + }); + + it('returns timeout fallback when no answer received', function () { + config(['sql-agent.agent.ask_user_timeout' => 1]); // 1 second timeout + + $this->tool->setRequestId('test-uuid'); + $this->tool->setSendCallback(function (array $data) { + // Don't write to cache - simulate no answer + }); + + $result = ($this->tool)('What period?', [['label' => 'Last week']]); + + expect($result)->toContain('did not respond'); + expect($result)->toContain('best guess'); + }); + + it('increments invocation counter for unique cache keys', function () { + $cacheKeys = []; + $this->tool->setRequestId('test-uuid'); + $this->tool->setSendCallback(function (array $data) use (&$cacheKeys) { + $cacheKeys[] = $data['request_id']; + Cache::put($data['request_id'], 'Answer', now()->addMinutes(10)); + }); + + ($this->tool)('First question?', [['label' => 'A'], ['label' => 'B']]); + ($this->tool)('Second question?', [['label' => 'C'], ['label' => 'D']]); + + expect($cacheKeys)->toHaveCount(2); + expect($cacheKeys[0])->not->toBe($cacheKeys[1]); + expect($cacheKeys[0])->toBe('sql-agent:ask-user:test-uuid:0'); + expect($cacheKeys[1])->toBe('sql-agent:ask-user:test-uuid:1'); + }); +}); + +describe('reset', function () { + it('resets invocation counter', function () { + $sentData = []; + $this->tool->setRequestId('test-uuid'); + $this->tool->setSendCallback(function (array $data) use (&$sentData) { + $sentData[] = $data; + }); + + // First invocation uses counter 0 + Cache::shouldReceive('get')->once()->andReturn('answer'); + Cache::shouldReceive('forget')->once(); + ($this->tool)('Q1?', [['label' => 'A'], ['label' => 'B']]); + + expect($sentData[0]['request_id'])->toContain(':0'); + + // Reset and invoke again - counter should restart at 0 + $this->tool->reset(); + + Cache::shouldReceive('get')->once()->andReturn('answer'); + Cache::shouldReceive('forget')->once(); + ($this->tool)('Q2?', [['label' => 'A'], ['label' => 'B']]); + + expect($sentData[1]['request_id'])->toContain(':0'); + }); + + it('preserves callback and request id after reset', function () { + $called = false; + $this->tool->setRequestId('test-uuid'); + $this->tool->setSendCallback(function () use (&$called) { + $called = true; + }); + + $this->tool->reset(); + + Cache::shouldReceive('get')->once()->andReturn('answer'); + Cache::shouldReceive('forget')->once(); + $result = ($this->tool)('Question?', [['label' => 'A'], ['label' => 'B']]); + + expect($called)->toBeTrue(); + expect($result)->toContain('User answered: answer'); + }); +});