From 59bf78d2c430b02aa81238b8dcd4032b676e1021 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Sun, 22 Mar 2026 12:07:58 +0800 Subject: [PATCH 1/4] Unify adapter message normalization across Claude and Gemini --- internal/adapter/claude/handler_util_test.go | 57 ++++++-- internal/adapter/claude/handler_utils.go | 137 +++++++++++++++--- internal/adapter/gemini/convert_messages.go | 5 + .../adapter/gemini/convert_messages_test.go | 78 ++++++++++ internal/adapter/openai/message_normalize.go | 54 ++++++- .../adapter/openai/message_normalize_test.go | 58 +++++--- .../adapter/openai/responses_input_items.go | 35 ++++- internal/prompt/messages.go | 6 + 8 files changed, 372 insertions(+), 58 deletions(-) create mode 100644 internal/adapter/gemini/convert_messages_test.go diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go index 136f1ce..169b0b2 100644 --- a/internal/adapter/claude/handler_util_test.go +++ b/internal/adapter/claude/handler_util_test.go @@ -1,7 +1,6 @@ package claude import ( - "strings" "testing" ) @@ -48,10 +47,49 @@ func TestNormalizeClaudeMessagesToolResult(t *testing.T) { }, } got := normalizeClaudeMessages(msgs) + if len(got) != 1 { + t.Fatalf("expected one normalized message, got %d", len(got)) + } m := got[0].(map[string]any) + if m["role"] != "tool" { + t.Fatalf("expected tool role preserved, got %#v", m["role"]) + } content, _ := m["content"].(string) - if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") { - t.Fatalf("expected serialized tool result marker, got %q", content) + if content != "tool output" { + t.Fatalf("expected raw tool output content preserved, got %q", content) + } +} + +func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "assistant", + "content": []any{ + map[string]any{ + "type": "tool_use", + "id": "call_1", + "name": "search_web", + "input": map[string]any{"query": "latest"}, + }, + }, + }, + } + + got := normalizeClaudeMessages(msgs) + if len(got) != 1 { + t.Fatalf("expected one normalized tool-call message, got %d", len(got)) + } + m := got[0].(map[string]any) + if m["role"] != "assistant" { + t.Fatalf("expected assistant role, got %#v", m["role"]) + } + tc, _ := m["tool_calls"].([]any) + if len(tc) != 1 { + t.Fatalf("expected one tool call, got %#v", m["tool_calls"]) + } + call, _ := tc[0].(map[string]any) + if call["id"] != "call_1" { + t.Fatalf("expected call id preserved, got %#v", call) } } @@ -94,8 +132,9 @@ func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { } got := normalizeClaudeMessages(msgs) m := got[0].(map[string]any) - if m["content"] != "Hello\nWorld" { - t.Fatalf("expected only text parts joined, got %q", m["content"]) + content, _ := m["content"].(string) + if !containsStr(content, "Hello") || !containsStr(content, "World") || !containsStr(content, `"type":"image"`) { + t.Fatalf("expected text plus raw non-text block preserved, got %q", content) } } @@ -128,11 +167,11 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) { if !containsStr(prompt, "tool_use") { t.Fatalf("expected tool_use instruction in prompt") } - if !containsStr(prompt, "Never output [TOOL_CALL_HISTORY] or [TOOL_RESULT_HISTORY] markers yourself") { - t.Fatalf("expected marker guard instruction in prompt") + if containsStr(prompt, "TOOL_CALL_HISTORY") || containsStr(prompt, "TOOL_RESULT_HISTORY") { + t.Fatalf("expected legacy tool history markers removed from prompt") } - if containsStr(prompt, "tool_calls") { - t.Fatalf("expected prompt to avoid tool_calls JSON instruction") + if !containsStr(prompt, "Do not print tool-call JSON in text") { + t.Fatalf("expected prompt to keep no tool-call-json instruction") } } diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go index 0a1fa75..3702202 100644 --- a/internal/adapter/claude/handler_utils.go +++ b/internal/adapter/claude/handler_utils.go @@ -13,28 +13,52 @@ func normalizeClaudeMessages(messages []any) []any { if !ok { continue } - copied := cloneMap(msg) + role := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", msg["role"]))) switch content := msg["content"].(type) { case []any: - parts := make([]string, 0, len(content)) + textParts := make([]string, 0, len(content)) + flushText := func() { + if len(textParts) == 0 { + return + } + out = append(out, map[string]any{ + "role": role, + "content": strings.Join(textParts, "\n"), + }) + textParts = textParts[:0] + } for _, block := range content { b, ok := block.(map[string]any) if !ok { continue } - typeStr, _ := b["type"].(string) - if typeStr == "text" { + typeStr := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", b["type"]))) + switch typeStr { + case "text": if t, ok := b["text"].(string); ok { - parts = append(parts, t) + textParts = append(textParts, t) + } + case "tool_use": + flushText() + if toolMsg := normalizeClaudeToolUseToAssistant(b); toolMsg != nil { + out = append(out, toolMsg) + } + case "tool_result": + flushText() + if toolMsg := normalizeClaudeToolResultToToolMessage(b); toolMsg != nil { + out = append(out, toolMsg) + } + default: + if raw := strings.TrimSpace(formatClaudeBlockRaw(b)); raw != "" { + textParts = append(textParts, raw) } - } - if typeStr == "tool_result" { - parts = append(parts, formatClaudeToolResultForPrompt(b)) } } - copied["content"] = strings.Join(parts, "\n") + flushText() + default: + copied := cloneMap(msg) + out = append(out, copied) } - out = append(out, copied) } return out } @@ -52,9 +76,8 @@ func buildClaudeToolPrompt(tools []any) string { } parts = append(parts, "When you need a tool, respond with Claude-native tool use (tool_use) using the provided tool schema. Do not print tool-call JSON in text.", - "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.", - "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.", - "Never output [TOOL_CALL_HISTORY] or [TOOL_RESULT_HISTORY] markers yourself; they are system-side context only.", + "Tool roundtrip context is included directly in the conversation messages (assistant tool_use/tool_calls and tool results).", + "After receiving a valid tool result, continue with final answer instead of repeating the same call unless required fields are still missing.", ) return strings.Join(parts, "\n\n") } @@ -63,22 +86,94 @@ func formatClaudeToolResultForPrompt(block map[string]any) string { if block == nil { return "" } + payload := map[string]any{ + "type": "tool_result", + "content": block["content"], + } + if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])); toolCallID != "" { + payload["tool_call_id"] = toolCallID + } else if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])); toolCallID != "" { + payload["tool_call_id"] = toolCallID + } + if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" { + payload["name"] = name + } + b, err := json.Marshal(payload) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", payload)) + } + return string(b) +} + +func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any { + if block == nil { + return nil + } + name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])) + if name == "" { + return nil + } + callID := strings.TrimSpace(fmt.Sprintf("%v", block["id"])) + if callID == "" { + callID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])) + } + if callID == "" { + callID = "call_claude" + } + arguments := block["input"] + if arguments == nil { + arguments = map[string]any{} + } + argsJSON, err := json.Marshal(arguments) + if err != nil || len(argsJSON) == 0 { + argsJSON = []byte("{}") + } + return map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": callID, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": string(argsJSON), + }, + }, + }, + } +} + +func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any { + if block == nil { + return nil + } toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])) if toolCallID == "" { toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])) } if toolCallID == "" { - toolCallID = "unknown" + toolCallID = "call_claude" } - name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])) - if name == "" { - name = "unknown" + out := map[string]any{ + "role": "tool", + "tool_call_id": toolCallID, + "content": block["content"], + } + if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" { + out["name"] = name + } + return out +} + +func formatClaudeBlockRaw(block map[string]any) string { + if block == nil { + return "" } - content := strings.TrimSpace(fmt.Sprintf("%v", block["content"])) - if content == "" { - content = "null" + b, err := json.Marshal(block) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", block)) } - return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content) + return string(b) } func hasSystemMessage(messages []any) bool { diff --git a/internal/adapter/gemini/convert_messages.go b/internal/adapter/gemini/convert_messages.go index 1148a7a..79a4de1 100644 --- a/internal/adapter/gemini/convert_messages.go +++ b/internal/adapter/gemini/convert_messages.go @@ -107,6 +107,11 @@ func geminiMessagesFromRequest(req map[string]any) []any { msg["name"] = name } out = append(out, msg) + continue + } + + if raw := strings.TrimSpace(stringifyJSON(part)); raw != "" && raw != "null" { + textParts = append(textParts, raw) } } flushText() diff --git a/internal/adapter/gemini/convert_messages_test.go b/internal/adapter/gemini/convert_messages_test.go new file mode 100644 index 0000000..b66b2b3 --- /dev/null +++ b/internal/adapter/gemini/convert_messages_test.go @@ -0,0 +1,78 @@ +package gemini + +import ( + "strings" + "testing" +) + +func TestGeminiMessagesFromRequestPreservesFunctionRoundtrip(t *testing.T) { + req := map[string]any{ + "contents": []any{ + map[string]any{ + "role": "model", + "parts": []any{ + map[string]any{ + "functionCall": map[string]any{ + "id": "call_g1", + "name": "search_web", + "args": map[string]any{"query": "ai"}, + }, + }, + }, + }, + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{ + "functionResponse": map[string]any{ + "id": "call_g1", + "name": "search_web", + "response": "ok", + }, + }, + }, + }, + }, + } + + got := geminiMessagesFromRequest(req) + if len(got) != 2 { + t.Fatalf("expected two normalized messages, got %#v", got) + } + assistant, _ := got[0].(map[string]any) + if assistant["role"] != "assistant" { + t.Fatalf("expected assistant first, got %#v", assistant) + } + tc, _ := assistant["tool_calls"].([]any) + if len(tc) != 1 { + t.Fatalf("expected one tool call, got %#v", assistant["tool_calls"]) + } + toolMsg, _ := got[1].(map[string]any) + if toolMsg["role"] != "tool" || toolMsg["tool_call_id"] != "call_g1" { + t.Fatalf("expected tool message with call id, got %#v", toolMsg) + } +} + +func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T) { + req := map[string]any{ + "contents": []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "hello"}, + map[string]any{"inlineData": map[string]any{"mimeType": "image/png"}}, + }, + }, + }, + } + + got := geminiMessagesFromRequest(req) + if len(got) != 1 { + t.Fatalf("expected one normalized message, got %#v", got) + } + msg, _ := got[0].(map[string]any) + content, _ := msg["content"].(string) + if !strings.Contains(content, "hello") || !strings.Contains(content, "inlineData") { + t.Fatalf("expected unknown part preserved as raw json text, got %q", content) + } +} diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go index a831599..0e844c9 100644 --- a/internal/adapter/openai/message_normalize.go +++ b/internal/adapter/openai/message_normalize.go @@ -2,6 +2,7 @@ package openai import ( "encoding/json" + "fmt" "strings" "ds2api/internal/prompt" @@ -18,7 +19,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an role := strings.ToLower(strings.TrimSpace(asString(msg["role"]))) switch role { case "assistant": - content := normalizeOpenAIContentForPrompt(msg["content"]) + content := buildAssistantContentForPrompt(msg) if content == "" { continue } @@ -27,12 +28,9 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an "content": content, }) case "tool", "function": - content := normalizeOpenAIContentForPrompt(msg["content"]) - if content == "" { - content = "null" - } + content := buildToolContentForPrompt(msg) out = append(out, map[string]any{ - "role": "user", + "role": "tool", "content": content, }) case "user", "system", "developer": @@ -57,6 +55,50 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an return out } +func buildAssistantContentForPrompt(msg map[string]any) string { + content := normalizeOpenAIContentForPrompt(msg["content"]) + toolCalls := normalizeAssistantToolCallsForPrompt(msg["tool_calls"]) + if toolCalls == "" { + return strings.TrimSpace(content) + } + if strings.TrimSpace(content) == "" { + return toolCalls + } + return strings.TrimSpace(content + "\n" + toolCalls) +} + +func normalizeAssistantToolCallsForPrompt(v any) string { + calls, ok := v.([]any) + if !ok || len(calls) == 0 { + return "" + } + b, err := json.Marshal(calls) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", calls)) + } + return strings.TrimSpace(string(b)) +} + +func buildToolContentForPrompt(msg map[string]any) string { + payload := map[string]any{ + "content": msg["content"], + } + if id := strings.TrimSpace(asString(msg["tool_call_id"])); id != "" { + payload["tool_call_id"] = id + } + if id := strings.TrimSpace(asString(msg["id"])); id != "" { + payload["id"] = id + } + if name := strings.TrimSpace(asString(msg["name"])); name != "" { + payload["name"] = name + } + content := normalizeOpenAIContentForPrompt(payload) + if strings.TrimSpace(content) == "" { + return `{"content":"null"}` + } + return content +} + func normalizeOpenAIContentForPrompt(v any) string { return prompt.NormalizeContent(v) } diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go index fa17dfe..857e75c 100644 --- a/internal/adapter/openai/message_normalize_test.go +++ b/internal/adapter/openai/message_normalize_test.go @@ -34,11 +34,11 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *tes } normalized := normalizeOpenAIMessagesForPrompt(raw, "") - if len(normalized) != 3 { - t.Fatalf("expected 3 normalized messages, got %d", len(normalized)) + if len(normalized) != 4 { + t.Fatalf("expected 4 normalized messages with assistant tool_call history preserved, got %d", len(normalized)) } - toolContent, _ := normalized[2]["content"].(string) - if !strings.Contains(toolContent, `"temp":18`) { + toolContent, _ := normalized[3]["content"].(string) + if !strings.Contains(toolContent, `\"temp\":18`) { t.Fatalf("tool result should be transparently forwarded, got %q", toolContent) } if strings.Contains(toolContent, "[TOOL_RESULT_HISTORY]") { @@ -87,8 +87,8 @@ func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) { normalized := normalizeOpenAIMessagesForPrompt(raw, "") got, _ := normalized[0]["content"].(string) - if !strings.Contains(got, "line-1\nline-2") { - t.Fatalf("expected joined text blocks, got %q", got) + if !strings.Contains(got, `"line-1"`) || !strings.Contains(got, `"line-2"`) || !strings.Contains(got, `"name":"read_file"`) { + t.Fatalf("expected tool envelope to preserve content blocks and metadata, got %q", got) } } @@ -108,11 +108,11 @@ func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) { if len(normalized) != 1 { t.Fatalf("expected one normalized message, got %d", len(normalized)) } - if normalized[0]["role"] != "user" { - t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"]) + if normalized[0]["role"] != "tool" { + t.Fatalf("expected function role normalized as tool, got %#v", normalized[0]["role"]) } got, _ := normalized[0]["content"].(string) - if strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) { + if !strings.Contains(got, `"name":"legacy_tool"`) || !strings.Contains(got, `"ok":true`) { t.Fatalf("unexpected normalized function-role content: %q", got) } } @@ -135,12 +135,12 @@ func TestNormalizeOpenAIMessagesForPrompt_EmptyToolContentPreservedAsNull(t *tes if len(normalized) != 2 { t.Fatalf("expected tool completion turn to be preserved, got %#v", normalized) } - if normalized[0]["role"] != "user" { - t.Fatalf("expected tool role mapped to user, got %#v", normalized[0]["role"]) + if normalized[0]["role"] != "tool" { + t.Fatalf("expected tool role preserved, got %#v", normalized[0]["role"]) } got, _ := normalized[0]["content"].(string) - if got != "null" { - t.Fatalf("expected empty tool content to be preserved as null placeholder, got %q", got) + if !strings.Contains(got, `"content":""`) || !strings.Contains(got, `"name":"noop_tool"`) || !strings.Contains(got, `"tool_call_id":"call_5"`) { + t.Fatalf("expected tool metadata preserved in content envelope, got %q", got) } } @@ -170,8 +170,12 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantMultipleToolCallsRemainSepara } normalized := normalizeOpenAIMessagesForPrompt(raw, "") - if len(normalized) != 0 { - t.Fatalf("expected assistant tool_call-only message to be dropped in passthrough mode, got %#v", normalized) + if len(normalized) != 1 { + t.Fatalf("expected assistant tool_call-only message to be preserved, got %#v", normalized) + } + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, `"name":"search_web"`) || !strings.Contains(got, `"name":"eval_javascript"`) { + t.Fatalf("expected tool_calls payload preserved in assistant content, got %q", got) } } @@ -192,8 +196,12 @@ func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t * } normalized := normalizeOpenAIMessagesForPrompt(raw, "") - if len(normalized) != 0 { - t.Fatalf("expected no synthetic assistant message for tool_call-only content, got %#v", normalized) + if len(normalized) != 1 { + t.Fatalf("expected assistant tool_call-only content to be preserved, got %#v", normalized) + } + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, `{}{\"query\":\"测试工具调用\"}`) { + t.Fatalf("expected concatenated arguments preserved verbatim, got %q", got) } } @@ -214,8 +222,12 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDroppe } normalized := normalizeOpenAIMessagesForPrompt(raw, "") - if len(normalized) != 0 { - t.Fatalf("expected nameless assistant tool_calls to be dropped, got %#v", normalized) + if len(normalized) != 1 { + t.Fatalf("expected assistant tool_calls history to be preserved even when name missing, got %#v", normalized) + } + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "call_missing_name") { + t.Fatalf("expected raw tool_call payload preserved, got %q", got) } } @@ -237,8 +249,12 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLi } normalized := normalizeOpenAIMessagesForPrompt(raw, "") - if len(normalized) != 0 { - t.Fatalf("expected nil-content assistant tool_call-only message to be dropped, got %#v", normalized) + if len(normalized) != 1 { + t.Fatalf("expected nil-content assistant tool_call-only message to be preserved, got %#v", normalized) + } + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "send_file_to_user") { + t.Fatalf("expected tool call payload preserved, got %q", got) } } diff --git a/internal/adapter/openai/responses_input_items.go b/internal/adapter/openai/responses_input_items.go index 91e6081..c12b58d 100644 --- a/internal/adapter/openai/responses_input_items.go +++ b/internal/adapter/openai/responses_input_items.go @@ -19,6 +19,27 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) if role != "" { + if role == "assistant" { + out := map[string]any{ + "role": "assistant", + } + if toolCalls, ok := m["tool_calls"].([]any); ok && len(toolCalls) > 0 { + out["tool_calls"] = toolCalls + } + content := m["content"] + if content == nil { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + content = txt + } + } + if content != nil { + out["content"] = content + } + if _, hasToolCalls := out["tool_calls"]; hasToolCalls || out["content"] != nil { + return out + } + return nil + } content := m["content"] if content == nil { if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { @@ -28,10 +49,22 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str if content == nil { return nil } - return map[string]any{ + out := map[string]any{ "role": normalizeOpenAIRoleForPrompt(role), "content": content, } + if role == "tool" || role == "function" { + if callID := strings.TrimSpace(asString(m["tool_call_id"])); callID != "" { + out["tool_call_id"] = callID + } + if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { + out["tool_call_id"] = callID + } + if name := strings.TrimSpace(asString(m["name"])); name != "" { + out["name"] = name + } + } + return out } itemType := strings.ToLower(strings.TrimSpace(asString(m["type"]))) diff --git a/internal/prompt/messages.go b/internal/prompt/messages.go index e86c391..fca7b5c 100644 --- a/internal/prompt/messages.go +++ b/internal/prompt/messages.go @@ -36,6 +36,12 @@ func MessagesPrepare(messages []map[string]any) string { switch m.Role { case "assistant": parts = append(parts, "<|Assistant|>"+m.Text+"<|end▁of▁sentence|>") + case "tool": + if i > 0 { + parts = append(parts, "<|Tool|>"+m.Text) + } else { + parts = append(parts, m.Text) + } case "user", "system": if i > 0 { parts = append(parts, "<|User|>"+m.Text) From a50490562654490c774d35f4d0a320bfefc22a54 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Sun, 22 Mar 2026 12:47:00 +0800 Subject: [PATCH 2/4] Fix Claude/Gemini prompt flattening for tool history and binary parts --- internal/adapter/claude/handler_util_test.go | 15 ++- internal/adapter/claude/handler_utils.go | 122 ++++++++++++++++-- internal/adapter/gemini/convert_messages.go | 88 ++++++++++++- .../adapter/gemini/convert_messages_test.go | 8 +- 4 files changed, 218 insertions(+), 15 deletions(-) diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go index 169b0b2..3212cca 100644 --- a/internal/adapter/claude/handler_util_test.go +++ b/internal/adapter/claude/handler_util_test.go @@ -1,6 +1,7 @@ package claude import ( + "strings" "testing" ) @@ -91,6 +92,10 @@ func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) { if call["id"] != "call_1" { t.Fatalf("expected call id preserved, got %#v", call) } + content, _ := m["content"].(string) + if !containsStr(content, "search_web") || !containsStr(content, `"arguments":"{\"query\":\"latest\"}"`) { + t.Fatalf("expected assistant content to include serialized tool call for prompt roundtrip, got %q", content) + } } func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) { @@ -125,7 +130,7 @@ func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { "role": "user", "content": []any{ map[string]any{"type": "text", "text": "Hello"}, - map[string]any{"type": "image", "source": "data:..."}, + map[string]any{"type": "image", "source": map[string]any{"type": "base64", "data": strings.Repeat("A", 2048)}}, map[string]any{"type": "text", "text": "World"}, }, }, @@ -134,7 +139,13 @@ func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { m := got[0].(map[string]any) content, _ := m["content"].(string) if !containsStr(content, "Hello") || !containsStr(content, "World") || !containsStr(content, `"type":"image"`) { - t.Fatalf("expected text plus raw non-text block preserved, got %q", content) + t.Fatalf("expected text plus non-text block marker preserved, got %q", content) + } + if !containsStr(content, omittedBinaryMarker) { + t.Fatalf("expected binary payload omitted marker, got %q", content) + } + if containsStr(content, strings.Repeat("A", 100)) { + t.Fatalf("expected raw base64 payload not to be included, got %q", content) } } diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go index 3702202..50da3ec 100644 --- a/internal/adapter/claude/handler_utils.go +++ b/internal/adapter/claude/handler_utils.go @@ -6,6 +6,11 @@ import ( "strings" ) +const ( + maxClaudeRawPromptChars = 1024 + omittedBinaryMarker = "[omitted_binary_payload]" +) + func normalizeClaudeMessages(messages []any) []any { out := make([]any, 0, len(messages)) for _, m := range messages { @@ -49,7 +54,7 @@ func normalizeClaudeMessages(messages []any) []any { out = append(out, toolMsg) } default: - if raw := strings.TrimSpace(formatClaudeBlockRaw(b)); raw != "" { + if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" { textParts = append(textParts, raw) } } @@ -128,19 +133,21 @@ func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any { if err != nil || len(argsJSON) == 0 { argsJSON = []byte("{}") } - return map[string]any{ - "role": "assistant", - "tool_calls": []any{ - map[string]any{ - "id": callID, - "type": "function", - "function": map[string]any{ - "name": name, - "arguments": string(argsJSON), - }, + toolCalls := []any{ + map[string]any{ + "id": callID, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": string(argsJSON), }, }, } + return map[string]any{ + "role": "assistant", + "content": marshalCompactJSON(toolCalls), + "tool_calls": toolCalls, + } } func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any { @@ -176,6 +183,99 @@ func formatClaudeBlockRaw(block map[string]any) string { return string(b) } +func formatClaudeUnknownBlockForPrompt(block map[string]any) string { + if block == nil { + return "" + } + safe := sanitizeClaudeBlockForPrompt(block) + raw := strings.TrimSpace(formatClaudeBlockRaw(safe)) + if raw == "" { + return "" + } + if len(raw) > maxClaudeRawPromptChars { + return raw[:maxClaudeRawPromptChars] + "...(truncated)" + } + return raw +} + +func sanitizeClaudeBlockForPrompt(block map[string]any) map[string]any { + out := cloneMap(block) + for k, v := range out { + if looksLikeBinaryFieldName(k) { + out[k] = omittedBinaryMarker + continue + } + switch inner := v.(type) { + case map[string]any: + out[k] = sanitizeClaudeBlockForPrompt(inner) + case []any: + out[k] = sanitizeClaudeArrayForPrompt(inner) + case string: + out[k] = sanitizeClaudeStringForPrompt(k, inner) + } + } + return out +} + +func sanitizeClaudeArrayForPrompt(items []any) []any { + out := make([]any, 0, len(items)) + for _, item := range items { + switch v := item.(type) { + case map[string]any: + out = append(out, sanitizeClaudeBlockForPrompt(v)) + case []any: + out = append(out, sanitizeClaudeArrayForPrompt(v)) + default: + out = append(out, v) + } + } + return out +} + +func sanitizeClaudeStringForPrompt(key, value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + if looksLikeBinaryFieldName(key) || looksLikeBase64Payload(trimmed) { + return omittedBinaryMarker + } + if len(trimmed) > maxClaudeRawPromptChars { + return trimmed[:maxClaudeRawPromptChars] + "...(truncated)" + } + return trimmed +} + +func looksLikeBinaryFieldName(name string) bool { + n := strings.ToLower(strings.TrimSpace(name)) + return n == "data" || n == "bytes" || n == "base64" || n == "inline_data" || n == "inlinedata" +} + +func looksLikeBase64Payload(v string) bool { + if len(v) < 512 { + return false + } + compact := strings.TrimRight(v, "=") + if compact == "" { + return false + } + for _, ch := range compact { + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' { + continue + } + return false + } + return true +} + +func marshalCompactJSON(v any) string { + b, err := json.Marshal(v) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } + return string(b) +} + func hasSystemMessage(messages []any) bool { for _, m := range messages { msg, ok := m.(map[string]any) diff --git a/internal/adapter/gemini/convert_messages.go b/internal/adapter/gemini/convert_messages.go index 79a4de1..ec3f174 100644 --- a/internal/adapter/gemini/convert_messages.go +++ b/internal/adapter/gemini/convert_messages.go @@ -2,6 +2,8 @@ package gemini import "strings" +const maxGeminiRawPromptChars = 1024 + func geminiMessagesFromRequest(req map[string]any) []any { out := make([]any, 0, 8) if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" { @@ -110,7 +112,7 @@ func geminiMessagesFromRequest(req map[string]any) []any { continue } - if raw := strings.TrimSpace(stringifyJSON(part)); raw != "" && raw != "null" { + if raw := strings.TrimSpace(formatGeminiUnknownPartForPrompt(part)); raw != "" && raw != "null" { textParts = append(textParts, raw) } } @@ -156,3 +158,87 @@ func mapGeminiRole(v any) string { return "" } } + +func formatGeminiUnknownPartForPrompt(part map[string]any) string { + safe := sanitizeGeminiPartForPrompt(part) + raw := strings.TrimSpace(stringifyJSON(safe)) + if raw == "" { + return "" + } + if len(raw) > maxGeminiRawPromptChars { + return raw[:maxGeminiRawPromptChars] + "...(truncated)" + } + return raw +} + +func sanitizeGeminiPartForPrompt(part map[string]any) map[string]any { + out := make(map[string]any, len(part)) + for k, v := range part { + if looksLikeGeminiBinaryField(k) { + out[k] = "[omitted_binary_payload]" + continue + } + switch x := v.(type) { + case map[string]any: + out[k] = sanitizeGeminiPartForPrompt(x) + case []any: + out[k] = sanitizeGeminiArrayForPrompt(x) + case string: + out[k] = sanitizeGeminiStringForPrompt(k, x) + default: + out[k] = v + } + } + return out +} + +func sanitizeGeminiArrayForPrompt(items []any) []any { + out := make([]any, 0, len(items)) + for _, item := range items { + switch x := item.(type) { + case map[string]any: + out = append(out, sanitizeGeminiPartForPrompt(x)) + case []any: + out = append(out, sanitizeGeminiArrayForPrompt(x)) + default: + out = append(out, x) + } + } + return out +} + +func sanitizeGeminiStringForPrompt(key, value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + if looksLikeGeminiBinaryField(key) || looksLikeGeminiBase64(trimmed) { + return "[omitted_binary_payload]" + } + if len(trimmed) > maxGeminiRawPromptChars { + return trimmed[:maxGeminiRawPromptChars] + "...(truncated)" + } + return trimmed +} + +func looksLikeGeminiBinaryField(name string) bool { + n := strings.ToLower(strings.TrimSpace(name)) + return n == "data" || n == "bytes" || n == "inlinedata" || n == "inline_data" || n == "base64" +} + +func looksLikeGeminiBase64(v string) bool { + if len(v) < 512 { + return false + } + compact := strings.TrimRight(v, "=") + if compact == "" { + return false + } + for _, ch := range compact { + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' { + continue + } + return false + } + return true +} diff --git a/internal/adapter/gemini/convert_messages_test.go b/internal/adapter/gemini/convert_messages_test.go index b66b2b3..4c98778 100644 --- a/internal/adapter/gemini/convert_messages_test.go +++ b/internal/adapter/gemini/convert_messages_test.go @@ -60,7 +60,7 @@ func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T "role": "user", "parts": []any{ map[string]any{"text": "hello"}, - map[string]any{"inlineData": map[string]any{"mimeType": "image/png"}}, + map[string]any{"inlineData": map[string]any{"mimeType": "image/png", "data": strings.Repeat("A", 2048)}}, }, }, }, @@ -75,4 +75,10 @@ func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T if !strings.Contains(content, "hello") || !strings.Contains(content, "inlineData") { t.Fatalf("expected unknown part preserved as raw json text, got %q", content) } + if !strings.Contains(content, "[omitted_binary_payload]") { + t.Fatalf("expected inlineData payload to be redacted, got %q", content) + } + if strings.Contains(content, strings.Repeat("A", 100)) { + t.Fatalf("expected raw base64 payload not to be embedded, got %q", content) + } } From a6499cbece487915ac0ed01fcbe6afdd0ac19227 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Sun, 22 Mar 2026 13:05:41 +0800 Subject: [PATCH 3/4] Split Claude sanitize helpers to satisfy refactor line gate --- internal/adapter/claude/handler_utils.go | 98 ---------------- .../adapter/claude/handler_utils_sanitize.go | 105 ++++++++++++++++++ 2 files changed, 105 insertions(+), 98 deletions(-) create mode 100644 internal/adapter/claude/handler_utils_sanitize.go diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go index 50da3ec..ac94291 100644 --- a/internal/adapter/claude/handler_utils.go +++ b/internal/adapter/claude/handler_utils.go @@ -6,11 +6,6 @@ import ( "strings" ) -const ( - maxClaudeRawPromptChars = 1024 - omittedBinaryMarker = "[omitted_binary_payload]" -) - func normalizeClaudeMessages(messages []any) []any { out := make([]any, 0, len(messages)) for _, m := range messages { @@ -183,99 +178,6 @@ func formatClaudeBlockRaw(block map[string]any) string { return string(b) } -func formatClaudeUnknownBlockForPrompt(block map[string]any) string { - if block == nil { - return "" - } - safe := sanitizeClaudeBlockForPrompt(block) - raw := strings.TrimSpace(formatClaudeBlockRaw(safe)) - if raw == "" { - return "" - } - if len(raw) > maxClaudeRawPromptChars { - return raw[:maxClaudeRawPromptChars] + "...(truncated)" - } - return raw -} - -func sanitizeClaudeBlockForPrompt(block map[string]any) map[string]any { - out := cloneMap(block) - for k, v := range out { - if looksLikeBinaryFieldName(k) { - out[k] = omittedBinaryMarker - continue - } - switch inner := v.(type) { - case map[string]any: - out[k] = sanitizeClaudeBlockForPrompt(inner) - case []any: - out[k] = sanitizeClaudeArrayForPrompt(inner) - case string: - out[k] = sanitizeClaudeStringForPrompt(k, inner) - } - } - return out -} - -func sanitizeClaudeArrayForPrompt(items []any) []any { - out := make([]any, 0, len(items)) - for _, item := range items { - switch v := item.(type) { - case map[string]any: - out = append(out, sanitizeClaudeBlockForPrompt(v)) - case []any: - out = append(out, sanitizeClaudeArrayForPrompt(v)) - default: - out = append(out, v) - } - } - return out -} - -func sanitizeClaudeStringForPrompt(key, value string) string { - trimmed := strings.TrimSpace(value) - if trimmed == "" { - return "" - } - if looksLikeBinaryFieldName(key) || looksLikeBase64Payload(trimmed) { - return omittedBinaryMarker - } - if len(trimmed) > maxClaudeRawPromptChars { - return trimmed[:maxClaudeRawPromptChars] + "...(truncated)" - } - return trimmed -} - -func looksLikeBinaryFieldName(name string) bool { - n := strings.ToLower(strings.TrimSpace(name)) - return n == "data" || n == "bytes" || n == "base64" || n == "inline_data" || n == "inlinedata" -} - -func looksLikeBase64Payload(v string) bool { - if len(v) < 512 { - return false - } - compact := strings.TrimRight(v, "=") - if compact == "" { - return false - } - for _, ch := range compact { - if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' { - continue - } - return false - } - return true -} - -func marshalCompactJSON(v any) string { - b, err := json.Marshal(v) - if err != nil { - return strings.TrimSpace(fmt.Sprintf("%v", v)) - } - return string(b) -} - func hasSystemMessage(messages []any) bool { for _, m := range messages { msg, ok := m.(map[string]any) diff --git a/internal/adapter/claude/handler_utils_sanitize.go b/internal/adapter/claude/handler_utils_sanitize.go new file mode 100644 index 0000000..10980cb --- /dev/null +++ b/internal/adapter/claude/handler_utils_sanitize.go @@ -0,0 +1,105 @@ +package claude + +import ( + "encoding/json" + "fmt" + "strings" +) + +const ( + maxClaudeRawPromptChars = 1024 + omittedBinaryMarker = "[omitted_binary_payload]" +) + +func formatClaudeUnknownBlockForPrompt(block map[string]any) string { + if block == nil { + return "" + } + safe := sanitizeClaudeBlockForPrompt(block) + raw := strings.TrimSpace(formatClaudeBlockRaw(safe)) + if raw == "" { + return "" + } + if len(raw) > maxClaudeRawPromptChars { + return raw[:maxClaudeRawPromptChars] + "...(truncated)" + } + return raw +} + +func sanitizeClaudeBlockForPrompt(block map[string]any) map[string]any { + out := cloneMap(block) + for k, v := range out { + if looksLikeBinaryFieldName(k) { + out[k] = omittedBinaryMarker + continue + } + switch inner := v.(type) { + case map[string]any: + out[k] = sanitizeClaudeBlockForPrompt(inner) + case []any: + out[k] = sanitizeClaudeArrayForPrompt(inner) + case string: + out[k] = sanitizeClaudeStringForPrompt(k, inner) + } + } + return out +} + +func sanitizeClaudeArrayForPrompt(items []any) []any { + out := make([]any, 0, len(items)) + for _, item := range items { + switch v := item.(type) { + case map[string]any: + out = append(out, sanitizeClaudeBlockForPrompt(v)) + case []any: + out = append(out, sanitizeClaudeArrayForPrompt(v)) + default: + out = append(out, v) + } + } + return out +} + +func sanitizeClaudeStringForPrompt(key, value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + if looksLikeBinaryFieldName(key) || looksLikeBase64Payload(trimmed) { + return omittedBinaryMarker + } + if len(trimmed) > maxClaudeRawPromptChars { + return trimmed[:maxClaudeRawPromptChars] + "...(truncated)" + } + return trimmed +} + +func looksLikeBinaryFieldName(name string) bool { + n := strings.ToLower(strings.TrimSpace(name)) + return n == "data" || n == "bytes" || n == "base64" || n == "inline_data" || n == "inlinedata" +} + +func looksLikeBase64Payload(v string) bool { + if len(v) < 512 { + return false + } + compact := strings.TrimRight(v, "=") + if compact == "" { + return false + } + for _, ch := range compact { + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' { + continue + } + return false + } + return true +} + +func marshalCompactJSON(v any) string { + b, err := json.Marshal(v) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } + return string(b) +} From 6802a3d53e2c6386ed7aa7f72a68fc82ee8d5e2d Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Sun, 22 Mar 2026 13:42:01 +0800 Subject: [PATCH 4/4] Fix Claude tool block normalization and tool_result fidelity --- internal/adapter/claude/handler_util_test.go | 73 ++++++++++++++++++++ internal/adapter/claude/handler_utils.go | 29 ++++++-- 2 files changed, 98 insertions(+), 4 deletions(-) diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go index 3212cca..9ad10e3 100644 --- a/internal/adapter/claude/handler_util_test.go +++ b/internal/adapter/claude/handler_util_test.go @@ -98,6 +98,38 @@ func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) { } } +func TestNormalizeClaudeMessagesDoesNotPromoteUserToolUse(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "tool_use", + "id": "call_unsafe", + "name": "dangerous_tool", + "input": map[string]any{"value": "x"}, + }, + }, + }, + } + + got := normalizeClaudeMessages(msgs) + if len(got) != 1 { + t.Fatalf("expected one normalized message, got %d", len(got)) + } + m := got[0].(map[string]any) + if m["role"] != "user" { + t.Fatalf("expected user role preserved, got %#v", m["role"]) + } + if _, ok := m["tool_calls"]; ok { + t.Fatalf("expected no tool_calls promotion for user message, got %#v", m["tool_calls"]) + } + content, _ := m["content"].(string) + if !containsStr(content, `"type":"tool_use"`) || !containsStr(content, "dangerous_tool") { + t.Fatalf("expected raw tool_use block preserved in user content, got %q", content) + } +} + func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) { msgs := []any{"not a map", 42} got := normalizeClaudeMessages(msgs) @@ -149,6 +181,47 @@ func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { } } +func TestNormalizeClaudeMessagesToolResultNonTextPayloadStringified(t *testing.T) { + msgs := []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "tool_result", + "tool_use_id": "call_image_1", + "name": "vision_tool", + "content": []any{ + map[string]any{"type": "text", "text": "image analysis"}, + map[string]any{ + "type": "image", + "source": map[string]any{"type": "base64", "media_type": "image/png", "data": strings.Repeat("B", 2048)}, + }, + }, + }, + }, + }, + } + + got := normalizeClaudeMessages(msgs) + if len(got) != 1 { + t.Fatalf("expected one normalized message, got %d", len(got)) + } + m := got[0].(map[string]any) + if m["role"] != "tool" { + t.Fatalf("expected tool role, got %#v", m["role"]) + } + content, _ := m["content"].(string) + if !containsStr(content, `"type":"tool_result"`) || !containsStr(content, `"type":"image"`) { + t.Fatalf("expected non-text tool_result payload to be JSON stringified, got %q", content) + } + if !containsStr(content, omittedBinaryMarker) { + t.Fatalf("expected binary data to be sanitized with omitted marker, got %q", content) + } + if containsStr(content, strings.Repeat("B", 100)) { + t.Fatalf("expected raw base64 payload not to be included, got %q", content) + } +} + // ─── buildClaudeToolPrompt ─────────────────────────────────────────── func TestBuildClaudeToolPromptSingleTool(t *testing.T) { diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go index ac94291..97327b4 100644 --- a/internal/adapter/claude/handler_utils.go +++ b/internal/adapter/claude/handler_utils.go @@ -39,9 +39,15 @@ func normalizeClaudeMessages(messages []any) []any { textParts = append(textParts, t) } case "tool_use": - flushText() - if toolMsg := normalizeClaudeToolUseToAssistant(b); toolMsg != nil { - out = append(out, toolMsg) + if role == "assistant" { + flushText() + if toolMsg := normalizeClaudeToolUseToAssistant(b); toolMsg != nil { + out = append(out, toolMsg) + } + continue + } + if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" { + textParts = append(textParts, raw) } case "tool_result": flushText() @@ -159,7 +165,7 @@ func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any out := map[string]any{ "role": "tool", "tool_call_id": toolCallID, - "content": block["content"], + "content": normalizeClaudeToolResultContent(block["content"]), } if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" { out["name"] = name @@ -167,6 +173,21 @@ func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any return out } +func normalizeClaudeToolResultContent(content any) any { + if text, ok := content.(string); ok { + return text + } + payload := map[string]any{ + "type": "tool_result", + "content": content, + } + b, err := json.Marshal(sanitizeClaudeBlockForPrompt(payload)) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", content)) + } + return string(b) +} + func formatClaudeBlockRaw(block map[string]any) string { if block == nil { return ""