From 2b97cb98b586d1bd4d9d9496205a9a40394f1018 Mon Sep 17 00:00:00 2001 From: xxddff <772327379@qq.com> Date: Tue, 10 Feb 2026 17:35:54 +0900 Subject: [PATCH 01/15] Delete 'user' field from raw JSON Remove the 'user' field from the raw JSON as requested. --- .../codex/openai/responses/codex_openai-responses_request.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request.go b/internal/translator/codex/openai/responses/codex_openai-responses_request.go index 828c4d87..692cfaa6 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request.go @@ -27,6 +27,9 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p") rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier") + // Delete user field as requested + rawJSON, _ = sjson.DeleteBytes(rawJSON, "user") + // Convert role "system" to "developer" in input array to comply with Codex API requirements. rawJSON = convertSystemRoleToDeveloper(rawJSON) From 865af9f19ea90c2684b8e1703732a3451932f679 Mon Sep 17 00:00:00 2001 From: xxddff <772327379@qq.com> Date: Tue, 10 Feb 2026 17:38:49 +0900 Subject: [PATCH 02/15] Implement test for user field deletion Add test to verify deletion of user field in response --- .../codex_openai-responses_request_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go index ea413238..2d1d47a1 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go @@ -263,3 +263,20 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) { t.Errorf("Expected third role 'assistant', got '%s'", thirdRole.String()) } } + +func TestUserFieldDeletion(t *testing.T) { + inputJSON := []byte(`{ + "model": "gpt-5.2", + "user": "test-user", + "input": [{"role": "user", "content": "Hello"}] + }`) + + output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false) + outputStr := string(output) + + // Verify user field is deleted + userField := gjson.Get(outputStr, "user") + if userField.Exists() { + t.Error("user field should be deleted") + } +} From afe4c1bfb7dfd2d0259ebc306e098c2cff33038d Mon Sep 17 00:00:00 2001 From: xxddff <772327379@qq.com> Date: Tue, 10 Feb 2026 18:24:26 +0900 Subject: [PATCH 03/15] =?UTF-8?q?=E6=9B=B4=E6=96=B0internal/translator/cod?= =?UTF-8?q?ex/openai/responses/codex=5Fopenai-responses=5Frequest.go?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../codex/openai/responses/codex_openai-responses_request.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request.go b/internal/translator/codex/openai/responses/codex_openai-responses_request.go index 692cfaa6..f0407149 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request.go @@ -27,8 +27,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p") rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier") - // Delete user field as requested - rawJSON, _ = sjson.DeleteBytes(rawJSON, "user") + // Delete the user field as it is not supported by the Codex upstream. + rawJSON, _ = sjson.DeleteBytes(rawJSON, "user") // Convert role "system" to "developer" in input array to comply with Codex API requirements. rawJSON = convertSystemRoleToDeveloper(rawJSON) From bb9fe52f1e8aa592fd7a5b3c40bd9dd1b8f7c38d Mon Sep 17 00:00:00 2001 From: xxddff <772327379@qq.com> Date: Tue, 10 Feb 2026 18:24:58 +0900 Subject: [PATCH 04/15] Update internal/translator/codex/openai/responses/codex_openai-responses_request_test.go Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../openai/responses/codex_openai-responses_request_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go index 2d1d47a1..4f562486 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go @@ -276,7 +276,7 @@ func TestUserFieldDeletion(t *testing.T) { // Verify user field is deleted userField := gjson.Get(outputStr, "user") - if userField.Exists() { - t.Error("user field should be deleted") - } + if userField.Exists() { + t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw) + } } From 8192eeabc8d2bdebc2648397a9a8a78d6bb136f1 Mon Sep 17 00:00:00 2001 From: Skyuno Date: Sat, 7 Feb 2026 04:02:42 +0800 Subject: [PATCH 05/15] Revert "feat: inject web_search alternative hint instead of silently filtering" This reverts commit 3c7a5afdcca5e09e3ee9bea438e57bdd0d6c511d. --- .../kiro/claude/kiro_claude_request.go | 32 +++---------------- 1 file changed, 5 insertions(+), 27 deletions(-) diff --git a/internal/translator/kiro/claude/kiro_claude_request.go b/internal/translator/kiro/claude/kiro_claude_request.go index 425d9ae2..316bf9ff 100644 --- a/internal/translator/kiro/claude/kiro_claude_request.go +++ b/internal/translator/kiro/claude/kiro_claude_request.go @@ -219,26 +219,7 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA } // Convert Claude tools to Kiro format - kiroTools, hasWebSearch := convertClaudeToolsToKiro(tools) - - // If web_search was requested but filtered, inject alternative hint - if hasWebSearch { - webSearchHint := `[CRITICAL WEB ACCESS INSTRUCTION] -You have the Fetch/read_url_content tool available. When the user asks about current events, weather, news, or any information that requires web access: -- DO NOT say you cannot search the web -- DO NOT refuse to help with web-related queries -- IMMEDIATELY use the Fetch tool to access relevant URLs -- Use well-known official websites, documentation sites, or API endpoints -- Construct appropriate URLs based on the query context - -IMPORTANT: Always attempt to fetch information FIRST before declining. You CAN access the web via Fetch.` - if systemPrompt != "" { - systemPrompt += "\n" + webSearchHint - } else { - systemPrompt = webSearchHint - } - log.Infof("kiro: injected web_search alternative hint (tool was filtered)") - } + kiroTools := convertClaudeToolsToKiro(tools) // Thinking mode implementation: // Kiro API supports official thinking/reasoning mode via tag. @@ -527,13 +508,11 @@ func ensureKiroInputSchema(parameters interface{}) interface{} { } } -// convertClaudeToolsToKiro converts Claude tools to Kiro format. -// Returns the converted tools and a boolean indicating if web_search was filtered. -func convertClaudeToolsToKiro(tools gjson.Result) ([]KiroToolWrapper, bool) { +// convertClaudeToolsToKiro converts Claude tools to Kiro format +func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper { var kiroTools []KiroToolWrapper - hasWebSearch := false if !tools.IsArray() { - return kiroTools, hasWebSearch + return kiroTools } for _, tool := range tools.Array() { @@ -544,7 +523,6 @@ func convertClaudeToolsToKiro(tools gjson.Result) ([]KiroToolWrapper, bool) { nameLower := strings.ToLower(name) if nameLower == "web_search" || nameLower == "websearch" { log.Debugf("kiro: skipping unsupported tool: %s", name) - hasWebSearch = true continue } @@ -591,7 +569,7 @@ func convertClaudeToolsToKiro(tools gjson.Result) ([]KiroToolWrapper, bool) { // This prevents 500 errors when Claude Code sends too many tools kiroTools = compressToolsIfNeeded(kiroTools) - return kiroTools, hasWebSearch + return kiroTools } // processMessages processes Claude messages and builds Kiro history From fe6fc628edf65e1bf40c09cafe861020a0f7d620 Mon Sep 17 00:00:00 2001 From: Skyuno Date: Sat, 7 Feb 2026 04:09:47 +0800 Subject: [PATCH 06/15] Revert "fix: filter out web_search/websearch tools unsupported by Kiro API" This reverts commit 5dc936a9a45b459eb6a2a950492f24a5b4f39f0f. --- .../translator/kiro/claude/kiro_claude_request.go | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/internal/translator/kiro/claude/kiro_claude_request.go b/internal/translator/kiro/claude/kiro_claude_request.go index 316bf9ff..b3742f22 100644 --- a/internal/translator/kiro/claude/kiro_claude_request.go +++ b/internal/translator/kiro/claude/kiro_claude_request.go @@ -17,6 +17,7 @@ import ( "github.com/tidwall/gjson" ) + // Kiro API request structs - field order determines JSON key order // KiroPayload is the top-level request structure for Kiro API @@ -33,6 +34,7 @@ type KiroInferenceConfig struct { TopP float64 `json:"topP,omitempty"` } + // KiroConversationState holds the conversation context type KiroConversationState struct { ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field @@ -378,6 +380,7 @@ func hasThinkingTagInBody(body []byte) bool { return strings.Contains(bodyStr, "") || strings.Contains(bodyStr, "") } + // IsThinkingEnabledFromHeader checks if thinking mode is enabled via Anthropic-Beta header. // Claude CLI uses "Anthropic-Beta: interleaved-thinking-2025-05-14" to enable thinking. func IsThinkingEnabledFromHeader(headers http.Header) bool { @@ -517,15 +520,6 @@ func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper { for _, tool := range tools.Array() { name := tool.Get("name").String() - - // Filter out web_search/websearch tools (Kiro API doesn't support them) - // This matches the behavior in AIClient-2-API/claude-kiro.js - nameLower := strings.ToLower(name) - if nameLower == "web_search" || nameLower == "websearch" { - log.Debugf("kiro: skipping unsupported tool: %s", name) - continue - } - description := tool.Get("description").String() inputSchemaResult := tool.Get("input_schema") var inputSchema interface{} From 7b01ca0e2ecf71170765ac1c69f7a4854fbf3e4c Mon Sep 17 00:00:00 2001 From: Skyuno Date: Tue, 10 Feb 2026 21:59:15 +0800 Subject: [PATCH 07/15] fix(kiro): implement web search MCP integration for streaming and non-streaming paths Add complete web search functionality that routes pure web_search requests to the Kiro MCP endpoint instead of the normal GAR API. Executor changes (kiro_executor.go): - Add web_search detection in Execute() and ExecuteStream() entry points using HasWebSearchTool() to intercept pure web_search requests before normal processing - Add 'kiro' format passthrough in buildKiroPayloadForFormat() for pre-built payloads used by callKiroRawAndBuffer() - Implement handleWebSearchStream(): streaming search loop with MCP search -> InjectToolResultsClaude -> callKiroAndBuffer, supporting up to 5 search iterations with model-driven re-search - Implement handleWebSearch(): non-streaming path that performs single MCP search, injects tool results, calls normal Execute path, and appends server_tool_use indicators to response - Add helper methods: callKiroAndBuffer(), callKiroRawAndBuffer(), callKiroDirectStream(), sendFallbackText(), executeNonStreamFallback() Web search core logic (kiro_websearch.go) [NEW]: - Define MCP JSON-RPC 2.0 types (McpRequest, McpResponse, McpResult, McpContent, McpError) - Define WebSearchResults/WebSearchResult structs for parsing MCP search results - HasWebSearchTool(): detect pure web_search requests (single-tool array only) - ContainsWebSearchTool(): detect web_search in mixed-tool arrays - ExtractSearchQuery(): parse search query from Claude Code's tool_use message format - CreateMcpRequest(): build MCP tools/call request with Kiro-compatible ID format - InjectToolResultsClaude(): append assistant tool_use + user tool_result messages to Claude-format payload for GAR translation pipeline - InjectToolResults(): modify Kiro-format payload directly with toolResults in currentMessage context - InjectSearchIndicatorsInResponse(): prepend server_tool_use + web_search_tool_result content blocks to non-streaming response for Claude Code search count display - ReplaceWebSearchToolDescription(): swap restrictive Kiro tool description with minimal re-search-friendly version - StripWebSearchTool(): remove web_search from tools array - FormatSearchContextPrompt() / FormatToolResultText(): format search results for injection - SSE event generation: SseEvent type, GenerateWebSearchEvents() (11-event sequence), GenerateSearchIndicatorEvents() (server_tool_use + web_search_tool_result pairs) - Stream analysis: AnalyzeBufferedStream() to detect stop_reason and web_search tool_use in buffered chunks, FilterChunksForClient() to strip tool_use blocks and adjust indices, AdjustSSEChunk() / AdjustStreamIndices() for content block index offset management MCP API handler (kiro_websearch_handler.go) [NEW]: - WebSearchHandler struct with MCP endpoint, HTTP client, auth token, fingerprint, and custom auth attributes - FetchToolDescription(): sync.Once-guarded MCP tools/list call to cache web_search tool description - GetWebSearchDescription(): thread-safe cached description retrieval - CallMcpAPI(): MCP API caller with retry logic (exponential backoff, retryable on 502/503/504), AWS-aligned headers via setMcpHeaders() - ParseSearchResults(): extract WebSearchResults from MCP JSON-RPC response - setMcpHeaders(): set Content-Type, Kiro agent headers, dynamic fingerprint User-Agent, AWS SDK identifiers, Bearer auth, and custom auth attributes Claude request translation (kiro_claude_request.go): - Rename web_search -> remote_web_search in convertClaudeToolsToKiro() with dynamic description from GetWebSearchDescription() or hardcoded fallback - Rename web_search -> remote_web_search in BuildAssistantMessageStruct() for tool_use content blocks - Add remoteWebSearchDescription constant as fallback when MCP tools/list hasn't been fetched --- internal/runtime/executor/kiro_executor.go | 559 +++++++- .../kiro/claude/kiro_claude_request.go | 21 +- .../translator/kiro/claude/kiro_websearch.go | 1169 +++++++++++++++++ .../kiro/claude/kiro_websearch_handler.go | 270 ++++ 4 files changed, 2013 insertions(+), 6 deletions(-) create mode 100644 internal/translator/kiro/claude/kiro_websearch.go create mode 100644 internal/translator/kiro/claude/kiro_websearch_handler.go diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 26dbc2ec..c360b2de 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -519,8 +519,12 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, case "openai": log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String()) return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) + case "kiro": + // Body is already in Kiro format — pass through directly (used by callKiroRawAndBuffer) + log.Debugf("kiro: body already in Kiro format, passing through directly") + return body, false default: - // Default to Claude format (also handles "claude", "kiro", etc.) + // Default to Claude format log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String()) return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) } @@ -636,6 +640,13 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req rateLimiter.WaitForToken(tokenKey) log.Debugf("kiro: rate limiter cleared for token %s", tokenKey) + // Check for pure web_search request + // Route to MCP endpoint instead of normal Kiro API + if kiroclaude.HasWebSearchTool(req.Payload) { + log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint") + return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn) + } + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -1057,6 +1068,13 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut rateLimiter.WaitForToken(tokenKey) log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey) + // Check for pure web_search request + // Route to MCP endpoint instead of normal Kiro API + if kiroclaude.HasWebSearchTool(req.Payload) { + log.Infof("kiro: detected pure web_search request, routing to MCP endpoint") + return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn) + } + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -4096,6 +4114,539 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool { return isExpired } -// NOTE: Message merging functions moved to internal/translator/kiro/common/message_merge.go -// NOTE: Tool calling support functions moved to internal/translator/kiro/claude/kiro_claude_tools.go -// The executor now uses kiroclaude.* and kirocommon.* functions instead +const maxWebSearchIterations = 5 + +// handleWebSearchStream handles web_search requests: +// Step 1: tools/list (sync) → fetch/cache tool description +// Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop +// Note: We skip the "model decides to search" step because Claude Code already +// decided to use web_search. The Kiro tool description restricts non-coding +// topics, so asking the model again would cause it to refuse valid searches. +func (e *KiroExecutor) handleWebSearchStream( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (<-chan cliproxyexecutor.StreamChunk, error) { + // Extract search query from Claude Code's web_search tool_use + query := kiroclaude.ExtractSearchQuery(req.Payload) + if query == "" { + log.Warnf("kiro/websearch: failed to extract search query, falling back to normal flow") + return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn) + } + + // Build MCP endpoint based on region + region := kiroDefaultRegion + if auth != nil && auth.Metadata != nil { + if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { + region = r + } + } + mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region) + + // ── Step 1: tools/list (SYNC) — cache tool description ── + { + tokenKey := getTokenKey(auth) + fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) + var authAttrs map[string]string + if auth != nil { + authAttrs = auth.Attributes + } + kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) + } + + // Create output channel + out := make(chan cliproxyexecutor.StreamChunk) + + go func() { + defer close(out) + + // Send message_start event to client + messageStartEvent := kiroclaude.SseEvent{ + Event: "message_start", + Data: map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": kiroclaude.GenerateMessageID(), + "type": "message", + "role": "assistant", + "model": req.Model, + "content": []interface{}{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]interface{}{ + "input_tokens": len(req.Payload) / 4, + "output_tokens": 0, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + }, + }, + }, + } + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: []byte(messageStartEvent.ToSSEString())}: + } + + // ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ── + contentBlockIndex := 0 + currentQuery := query + + // Replace web_search tool description with a minimal one that allows re-search. + // The original tools/list description from Kiro restricts non-coding topics, + // but we've already decided to search. We keep the tool so the model can + // request additional searches when results are insufficient. + simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload)) + if simplifyErr != nil { + log.Warnf("kiro/websearch: failed to simplify web_search tool: %v, using original payload", simplifyErr) + simplifiedPayload = bytes.Clone(req.Payload) + } + + currentClaudePayload := simplifiedPayload + totalSearches := 0 + + // Generate toolUseId for the first iteration (Claude Code already decided to search) + currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) + + for iteration := 0; iteration < maxWebSearchIterations; iteration++ { + log.Infof("kiro/websearch: search iteration %d/%d — query: %s", + iteration+1, maxWebSearchIterations, currentQuery) + + // MCP search + _, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery) + tokenKey := getTokenKey(auth) + fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) + var authAttrs map[string]string + if auth != nil { + authAttrs = auth.Attributes + } + handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) + mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest) + + var searchResults *kiroclaude.WebSearchResults + if mcpErr != nil { + log.Warnf("kiro/websearch: MCP API call failed: %v, continuing with empty results", mcpErr) + } else { + searchResults = kiroclaude.ParseSearchResults(mcpResponse) + } + + resultCount := 0 + if searchResults != nil { + resultCount = len(searchResults.Results) + } + totalSearches++ + log.Infof("kiro/websearch: iteration %d — got %d search results", iteration+1, resultCount) + + // Send search indicator events to client + searchEvents := kiroclaude.GenerateSearchIndicatorEvents(currentQuery, currentToolUseId, searchResults, contentBlockIndex) + for _, event := range searchEvents { + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}: + } + } + contentBlockIndex += 2 + + // Inject tool_use + tool_result into Claude payload, then call GAR + var err error + currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults) + if err != nil { + log.Warnf("kiro/websearch: failed to inject tool results: %v", err) + e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) + break + } + + // Call GAR with modified Claude payload (full translation pipeline) + modifiedReq := req + modifiedReq.Payload = currentClaudePayload + kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn) + if kiroErr != nil { + log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr) + e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) + break + } + + // Analyze response + analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks) + log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v, query: %s, toolUseId: %s", + iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse, analysis.WebSearchQuery, analysis.WebSearchToolUseId) + + if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations { + // Model wants another search + filteredChunks := kiroclaude.FilterChunksForClient(kiroChunks, analysis.WebSearchToolUseIndex, contentBlockIndex) + for _, chunk := range filteredChunks { + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: chunk}: + } + } + + currentQuery = analysis.WebSearchQuery + currentToolUseId = analysis.WebSearchToolUseId + continue + } + + // Model returned final response — stream to client + for _, chunk := range kiroChunks { + if contentBlockIndex > 0 && len(chunk) > 0 { + adjusted, shouldForward := kiroclaude.AdjustSSEChunk(chunk, contentBlockIndex) + if !shouldForward { + continue + } + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}: + } + } else { + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: chunk}: + } + } + } + log.Infof("kiro/websearch: completed after %d search iteration(s), total searches: %d", iteration+1, totalSearches) + return + } + + log.Warnf("kiro/websearch: reached max iterations (%d), stopping search loop", maxWebSearchIterations) + }() + + return out, nil +} + +// callKiroAndBuffer calls the Kiro API and buffers all response chunks. +// Returns the buffered chunks for analysis before forwarding to client. +func (e *KiroExecutor) callKiroAndBuffer( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) ([][]byte, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + log.Debugf("kiro/websearch GAR request: %d bytes", len(body)) + + kiroModelID := e.mapModelToKiro(req.Model) + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + + tokenKey := "" + if auth != nil { + tokenKey = auth.ID + } + + kiroStream, err := e.executeStreamWithRetry( + ctx, auth, req, opts, accessToken, effectiveProfileArn, + nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey, + ) + if err != nil { + return nil, err + } + + // Buffer all chunks + var chunks [][]byte + for chunk := range kiroStream { + if chunk.Err != nil { + return chunks, chunk.Err + } + if len(chunk.Payload) > 0 { + chunks = append(chunks, bytes.Clone(chunk.Payload)) + } + } + + log.Debugf("kiro/websearch GAR response: %d chunks buffered", len(chunks)) + + return chunks, nil +} + +// callKiroRawAndBuffer calls the Kiro API with a pre-built Kiro payload (no translation). +// Used in the web search loop where the payload is modified directly in Kiro format. +func (e *KiroExecutor) callKiroRawAndBuffer( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, + kiroBody []byte, +) ([][]byte, error) { + kiroModelID := e.mapModelToKiro(req.Model) + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + + tokenKey := "" + if auth != nil { + tokenKey = auth.ID + } + log.Debugf("kiro/websearch GAR raw request: %d bytes", len(kiroBody)) + + kiroFormat := sdktranslator.FromString("kiro") + kiroStream, err := e.executeStreamWithRetry( + ctx, auth, req, opts, accessToken, effectiveProfileArn, + nil, kiroBody, kiroFormat, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey, + ) + if err != nil { + return nil, err + } + + // Buffer all chunks + var chunks [][]byte + for chunk := range kiroStream { + if chunk.Err != nil { + return chunks, chunk.Err + } + if len(chunk.Payload) > 0 { + chunks = append(chunks, bytes.Clone(chunk.Payload)) + } + } + + log.Debugf("kiro/websearch GAR raw response: %d chunks buffered", len(chunks)) + + return chunks, nil +} + +// callKiroDirectStream creates a direct streaming channel to Kiro API without search. +func (e *KiroExecutor) callKiroDirectStream( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (<-chan cliproxyexecutor.StreamChunk, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + + tokenKey := "" + if auth != nil { + tokenKey = auth.ID + } + + return e.executeStreamWithRetry( + ctx, auth, req, opts, accessToken, effectiveProfileArn, + nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey, + ) +} + +// sendFallbackText sends a simple text response when the Kiro API fails during the search loop. +func (e *KiroExecutor) sendFallbackText( + ctx context.Context, + out chan<- cliproxyexecutor.StreamChunk, + contentBlockIndex int, + query string, + searchResults *kiroclaude.WebSearchResults, +) { + // Generate a simple text summary from search results + summary := kiroclaude.FormatSearchContextPrompt(query, searchResults) + + events := []kiroclaude.SseEvent{ + { + Event: "content_block_start", + Data: map[string]interface{}{ + "type": "content_block_start", + "index": contentBlockIndex, + "content_block": map[string]interface{}{ + "type": "text", + "text": "", + }, + }, + }, + { + Event: "content_block_delta", + Data: map[string]interface{}{ + "type": "content_block_delta", + "index": contentBlockIndex, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": summary, + }, + }, + }, + { + Event: "content_block_stop", + Data: map[string]interface{}{ + "type": "content_block_stop", + "index": contentBlockIndex, + }, + }, + } + + for _, event := range events { + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}: + } + } + + // Send message_delta with end_turn and message_stop + msgDelta := kiroclaude.SseEvent{ + Event: "message_delta", + Data: map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": "end_turn", + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + "output_tokens": len(summary) / 4, + }, + }, + } + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgDelta.ToSSEString())}: + } + + msgStop := kiroclaude.SseEvent{ + Event: "message_stop", + Data: map[string]interface{}{ + "type": "message_stop", + }, + } + select { + case <-ctx.Done(): + return + case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgStop.ToSSEString())}: + } + +} + +// handleWebSearch handles web_search requests for non-streaming Execute path. +// Performs MCP search synchronously, injects results into the request payload, +// then calls the normal non-streaming Kiro API path which returns a proper +// Claude JSON response (not SSE chunks). +func (e *KiroExecutor) handleWebSearch( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (cliproxyexecutor.Response, error) { + // Extract search query from Claude Code's web_search tool_use + query := kiroclaude.ExtractSearchQuery(req.Payload) + if query == "" { + log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute") + // Fall through to normal non-streaming path + return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) + } + + // Build MCP endpoint based on region + region := kiroDefaultRegion + if auth != nil && auth.Metadata != nil { + if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { + region = r + } + } + mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region) + + // Step 1: Fetch/cache tool description (sync) + { + tokenKey := getTokenKey(auth) + fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) + var authAttrs map[string]string + if auth != nil { + authAttrs = auth.Attributes + } + kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) + } + + // Step 2: Perform MCP search + _, mcpRequest := kiroclaude.CreateMcpRequest(query) + tokenKey := getTokenKey(auth) + fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) + var authAttrs map[string]string + if auth != nil { + authAttrs = auth.Attributes + } + handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) + mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest) + + var searchResults *kiroclaude.WebSearchResults + if mcpErr != nil { + log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr) + } else { + searchResults = kiroclaude.ParseSearchResults(mcpResponse) + } + + resultCount := 0 + if searchResults != nil { + resultCount = len(searchResults.Results) + } + log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query) + + // Step 3: Inject search tool_use + tool_result into Claude payload + currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) + modifiedPayload, err := kiroclaude.InjectToolResultsClaude(bytes.Clone(req.Payload), currentToolUseId, query, searchResults) + if err != nil { + log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err) + return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) + } + + // Step 4: Call Kiro API via the normal non-streaming path (executeWithRetry) + // This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream + // to produce a proper Claude JSON response + modifiedReq := req + modifiedReq.Payload = modifiedPayload + + resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn) + if err != nil { + return resp, err + } + + // Step 5: Inject server_tool_use + web_search_tool_result into response + // so Claude Code can display "Did X searches in Ys" + indicators := []kiroclaude.SearchIndicator{ + { + ToolUseID: currentToolUseId, + Query: query, + Results: searchResults, + }, + } + injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators) + if injErr != nil { + log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr) + } else { + resp.Payload = injectedPayload + } + + return resp, nil +} + +// executeNonStreamFallback runs the standard non-streaming Execute path for a request. +// Used by handleWebSearch after injecting search results, or as a fallback. +func (e *KiroExecutor) executeNonStreamFallback( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + tokenKey := getTokenKey(auth) + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + var err error + defer reporter.trackFailure(ctx, &err) + + resp, err := e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey) + return resp, err +} diff --git a/internal/translator/kiro/claude/kiro_claude_request.go b/internal/translator/kiro/claude/kiro_claude_request.go index b3742f22..790928f4 100644 --- a/internal/translator/kiro/claude/kiro_claude_request.go +++ b/internal/translator/kiro/claude/kiro_claude_request.go @@ -17,6 +17,8 @@ import ( "github.com/tidwall/gjson" ) +// remoteWebSearchDescription is a minimal fallback for when dynamic fetch from MCP tools/list hasn't completed yet. +const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information." // Kiro API request structs - field order determines JSON key order @@ -34,7 +36,6 @@ type KiroInferenceConfig struct { TopP float64 `json:"topP,omitempty"` } - // KiroConversationState holds the conversation context type KiroConversationState struct { ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field @@ -380,7 +381,6 @@ func hasThinkingTagInBody(body []byte) bool { return strings.Contains(bodyStr, "") || strings.Contains(bodyStr, "") } - // IsThinkingEnabledFromHeader checks if thinking mode is enabled via Anthropic-Beta header. // Claude CLI uses "Anthropic-Beta: interleaved-thinking-2025-05-14" to enable thinking. func IsThinkingEnabledFromHeader(headers http.Header) bool { @@ -541,6 +541,18 @@ func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper { log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description) } + // Rename web_search → remote_web_search for Kiro API compatibility + if name == "web_search" { + name = "remote_web_search" + // Prefer dynamically fetched description, fall back to hardcoded constant + if cached := GetWebSearchDescription(); cached != "" { + description = cached + } else { + description = remoteWebSearchDescription + } + log.Debugf("kiro: renamed tool web_search → remote_web_search") + } + // Truncate long descriptions (individual tool limit) if len(description) > kirocommon.KiroMaxToolDescLen { truncLen := kirocommon.KiroMaxToolDescLen - 30 @@ -848,6 +860,11 @@ func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage }) } + // Rename web_search → remote_web_search to match convertClaudeToolsToKiro + if toolName == "web_search" { + toolName = "remote_web_search" + } + toolUses = append(toolUses, KiroToolUse{ ToolUseID: toolUseID, Name: toolName, diff --git a/internal/translator/kiro/claude/kiro_websearch.go b/internal/translator/kiro/claude/kiro_websearch.go new file mode 100644 index 00000000..25be730e --- /dev/null +++ b/internal/translator/kiro/claude/kiro_websearch.go @@ -0,0 +1,1169 @@ +// Package claude provides web search functionality for Kiro translator. +// This file implements detection and MCP request/response types for web search. +package claude + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// McpRequest represents a JSON-RPC 2.0 request to Kiro MCP API +type McpRequest struct { + ID string `json:"id"` + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params McpParams `json:"params"` +} + +// McpParams represents MCP request parameters +type McpParams struct { + Name string `json:"name"` + Arguments McpArguments `json:"arguments"` +} + +// McpArgumentsMeta represents the _meta field in MCP arguments +type McpArgumentsMeta struct { + IsValid bool `json:"_isValid"` + ActivePath []string `json:"_activePath"` + CompletedPaths [][]string `json:"_completedPaths"` +} + +// McpArguments represents MCP request arguments +type McpArguments struct { + Query string `json:"query"` + Meta *McpArgumentsMeta `json:"_meta,omitempty"` +} + +// McpResponse represents a JSON-RPC 2.0 response from Kiro MCP API +type McpResponse struct { + Error *McpError `json:"error,omitempty"` + ID string `json:"id"` + JSONRPC string `json:"jsonrpc"` + Result *McpResult `json:"result,omitempty"` +} + +// McpError represents an MCP error +type McpError struct { + Code *int `json:"code,omitempty"` + Message *string `json:"message,omitempty"` +} + +// McpResult represents MCP result +type McpResult struct { + Content []McpContent `json:"content"` + IsError bool `json:"isError"` +} + +// McpContent represents MCP content item +type McpContent struct { + ContentType string `json:"type"` + Text string `json:"text"` +} + +// WebSearchResults represents parsed search results +type WebSearchResults struct { + Results []WebSearchResult `json:"results"` + TotalResults *int `json:"totalResults,omitempty"` + Query *string `json:"query,omitempty"` + Error *string `json:"error,omitempty"` +} + +// WebSearchResult represents a single search result +type WebSearchResult struct { + Title string `json:"title"` + URL string `json:"url"` + Snippet *string `json:"snippet,omitempty"` + PublishedDate *int64 `json:"publishedDate,omitempty"` + ID *string `json:"id,omitempty"` + Domain *string `json:"domain,omitempty"` + MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"` + PublicDomain *bool `json:"publicDomain,omitempty"` +} + +// isWebSearchTool checks if a tool name or type indicates a web_search tool. +func isWebSearchTool(name, toolType string) bool { + return name == "web_search" || + strings.HasPrefix(toolType, "web_search") || + toolType == "web_search_20250305" +} + +// HasWebSearchTool checks if the request contains ONLY a web_search tool. +// Returns true only if tools array has exactly one tool named "web_search". +// Only intercept pure web_search requests (single-tool array). +func HasWebSearchTool(body []byte) bool { + tools := gjson.GetBytes(body, "tools") + if !tools.IsArray() { + return false + } + + toolsArray := tools.Array() + if len(toolsArray) != 1 { + return false + } + + // Check if the single tool is web_search + tool := toolsArray[0] + + // Check both name and type fields for web_search detection + name := strings.ToLower(tool.Get("name").String()) + toolType := strings.ToLower(tool.Get("type").String()) + + return isWebSearchTool(name, toolType) +} + +// ExtractSearchQuery extracts the search query from the request. +// Reads messages[0].content and removes "Perform a web search for the query: " prefix. +func ExtractSearchQuery(body []byte) string { + messages := gjson.GetBytes(body, "messages") + if !messages.IsArray() || len(messages.Array()) == 0 { + return "" + } + + firstMsg := messages.Array()[0] + content := firstMsg.Get("content") + + var text string + if content.IsArray() { + // Array format: [{"type": "text", "text": "..."}] + for _, block := range content.Array() { + if block.Get("type").String() == "text" { + text = block.Get("text").String() + break + } + } + } else { + // String format + text = content.String() + } + + // Remove prefix "Perform a web search for the query: " + const prefix = "Perform a web search for the query: " + if strings.HasPrefix(text, prefix) { + text = text[len(prefix):] + } + + return strings.TrimSpace(text) +} + +// generateRandomID8 generates an 8-character random lowercase alphanumeric string +func generateRandomID8() string { + u := uuid.New() + return strings.ToLower(strings.ReplaceAll(u.String(), "-", "")[:8]) +} + +// CreateMcpRequest creates an MCP request for web search. +// Returns (toolUseID, McpRequest) +// ID format: web_search_tooluse_{22 random}_{timestamp_millis}_{8 random} +func CreateMcpRequest(query string) (string, *McpRequest) { + random22 := GenerateToolUseID() + timestamp := time.Now().UnixMilli() + random8 := generateRandomID8() + + requestID := fmt.Sprintf("web_search_tooluse_%s_%d_%s", random22, timestamp, random8) + + // tool_use_id format: srvtoolu_{32 hex chars} + toolUseID := "srvtoolu_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:32] + + request := &McpRequest{ + ID: requestID, + JSONRPC: "2.0", + Method: "tools/call", + Params: McpParams{ + Name: "web_search", + Arguments: McpArguments{ + Query: query, + Meta: &McpArgumentsMeta{ + IsValid: true, + ActivePath: []string{"query"}, + CompletedPaths: [][]string{{"query"}}, + }, + }, + }, + } + + return toolUseID, request +} + +// GenerateMessageID generates a Claude-style message ID +func GenerateMessageID() string { + return "msg_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:24] +} + +// GenerateToolUseID generates a Kiro-style tool use ID (base62-like UUID) +func GenerateToolUseID() string { + return strings.ReplaceAll(uuid.New().String(), "-", "")[:22] +} + +// ContainsWebSearchTool checks if the request contains a web_search tool (among any tools). +// Unlike HasWebSearchTool, this detects web_search even in mixed-tool arrays. +func ContainsWebSearchTool(body []byte) bool { + tools := gjson.GetBytes(body, "tools") + if !tools.IsArray() { + return false + } + + for _, tool := range tools.Array() { + name := strings.ToLower(tool.Get("name").String()) + toolType := strings.ToLower(tool.Get("type").String()) + + if isWebSearchTool(name, toolType) { + return true + } + } + + return false +} + +// ReplaceWebSearchToolDescription replaces the web_search tool description with +// a minimal version that allows re-search without the restrictive "do not search +// non-coding topics" instruction from the original Kiro tools/list response. +// This keeps the tool available so the model can request additional searches. +func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) { + tools := gjson.GetBytes(body, "tools") + if !tools.IsArray() { + return body, nil + } + + var updated []json.RawMessage + for _, tool := range tools.Array() { + name := strings.ToLower(tool.Get("name").String()) + toolType := strings.ToLower(tool.Get("type").String()) + + if isWebSearchTool(name, toolType) { + // Replace with a minimal web_search tool definition + minimalTool := map[string]interface{}{ + "name": "web_search", + "description": "Search the web for information. Use this when the previous search results are insufficient or when you need additional information on a different aspect of the query. Provide a refined or different search query.", + "input_schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "The search query to execute", + }, + }, + "required": []string{"query"}, + "additionalProperties": false, + }, + } + minimalJSON, err := json.Marshal(minimalTool) + if err != nil { + return body, fmt.Errorf("failed to marshal minimal tool: %w", err) + } + updated = append(updated, json.RawMessage(minimalJSON)) + } else { + updated = append(updated, json.RawMessage(tool.Raw)) + } + } + + updatedJSON, err := json.Marshal(updated) + if err != nil { + return body, fmt.Errorf("failed to marshal updated tools: %w", err) + } + result, err := sjson.SetRawBytes(body, "tools", updatedJSON) + if err != nil { + return body, fmt.Errorf("failed to set updated tools: %w", err) + } + + return result, nil +} + +// StripWebSearchTool removes web_search tool entries from the request's tools array. +// If the tools array becomes empty after removal, it is removed entirely. +func StripWebSearchTool(body []byte) ([]byte, error) { + tools := gjson.GetBytes(body, "tools") + if !tools.IsArray() { + return body, nil + } + + var filtered []json.RawMessage + for _, tool := range tools.Array() { + name := strings.ToLower(tool.Get("name").String()) + toolType := strings.ToLower(tool.Get("type").String()) + + if !isWebSearchTool(name, toolType) { + filtered = append(filtered, json.RawMessage(tool.Raw)) + } + } + + var result []byte + var err error + + if len(filtered) == 0 { + // Remove tools array entirely + result, err = sjson.DeleteBytes(body, "tools") + if err != nil { + return body, fmt.Errorf("failed to delete tools: %w", err) + } + } else { + // Replace with filtered array + filteredJSON, marshalErr := json.Marshal(filtered) + if marshalErr != nil { + return body, fmt.Errorf("failed to marshal filtered tools: %w", marshalErr) + } + result, err = sjson.SetRawBytes(body, "tools", filteredJSON) + if err != nil { + return body, fmt.Errorf("failed to set filtered tools: %w", err) + } + } + + return result, nil +} + +// FormatSearchContextPrompt formats search results as a structured text block +// for injection into the system prompt. +func FormatSearchContextPrompt(query string, results *WebSearchResults) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("[Web Search Results for \"%s\"]\n", query)) + + if results != nil && len(results.Results) > 0 { + for i, r := range results.Results { + sb.WriteString(fmt.Sprintf("%d. %s - %s\n", i+1, r.Title, r.URL)) + if r.Snippet != nil && *r.Snippet != "" { + snippet := *r.Snippet + if len(snippet) > 500 { + snippet = snippet[:500] + "..." + } + sb.WriteString(fmt.Sprintf(" %s\n", snippet)) + } + } + } else { + sb.WriteString("No results found.\n") + } + + sb.WriteString("[End Web Search Results]") + return sb.String() +} + +// FormatToolResultText formats search results as JSON text for the toolResults content field. +// This matches the format observed in Kiro IDE HAR captures. +func FormatToolResultText(results *WebSearchResults) string { + if results == nil || len(results.Results) == 0 { + return "No search results found." + } + + text := fmt.Sprintf("Found %d search result(s):\n\n", len(results.Results)) + resultJSON, err := json.MarshalIndent(results.Results, "", " ") + if err != nil { + return text + "Error formatting results." + } + return text + string(resultJSON) +} + +// InjectToolResultsClaude modifies a Claude-format JSON payload to append +// tool_use (assistant) and tool_result (user) messages to the messages array. +// BuildKiroPayload correctly translates: +// - assistant tool_use → KiroAssistantResponseMessage.toolUses +// - user tool_result → KiroUserInputMessageContext.toolResults +// +// This produces the exact same GAR request format as the Kiro IDE (HAR captures). +// IMPORTANT: The web_search tool must remain in the "tools" array for this to work. +// Use ReplaceWebSearchToolDescription (not StripWebSearchTool) to keep the tool available. +func InjectToolResultsClaude(claudePayload []byte, toolUseId, query string, results *WebSearchResults) ([]byte, error) { + var payload map[string]interface{} + if err := json.Unmarshal(claudePayload, &payload); err != nil { + return claudePayload, fmt.Errorf("failed to parse claude payload: %w", err) + } + + messages, _ := payload["messages"].([]interface{}) + + // 1. Append assistant message with tool_use (matches HAR: assistantResponseMessage.toolUses) + assistantMsg := map[string]interface{}{ + "role": "assistant", + "content": []interface{}{ + map[string]interface{}{ + "type": "tool_use", + "id": toolUseId, + "name": "web_search", + "input": map[string]interface{}{"query": query}, + }, + }, + } + messages = append(messages, assistantMsg) + + // 2. Append user message with tool_result + search behavior instructions. + // NOTE: We embed search instructions HERE (not in system prompt) because + // BuildKiroPayload clears the system prompt when len(history) > 0, + // which is always true after injecting assistant + user messages. + now := time.Now() + searchGuidance := fmt.Sprintf(` +Current date: %s (%s) + +IMPORTANT: Evaluate the search results above carefully. If the results are: +- Mostly spam, SEO junk, or unrelated websites +- Missing actual information about the query topic +- Outdated or not matching the requested time frame + +Then you MUST use the web_search tool again with a refined query. Try: +- Rephrasing in English for better coverage +- Using more specific keywords +- Adding date context + +Do NOT apologize for bad results without first attempting a re-search. +`, now.Format("January 2, 2006"), now.Format("Monday")) + + userMsg := map[string]interface{}{ + "role": "user", + "content": []interface{}{ + map[string]interface{}{ + "type": "tool_result", + "tool_use_id": toolUseId, + "content": FormatToolResultText(results), + }, + map[string]interface{}{ + "type": "text", + "text": searchGuidance, + }, + }, + } + messages = append(messages, userMsg) + + payload["messages"] = messages + + result, err := json.Marshal(payload) + if err != nil { + return claudePayload, fmt.Errorf("failed to marshal updated payload: %w", err) + } + + log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, query=%s, messages=%d)", + toolUseId, query, len(messages)) + + return result, nil +} + +// InjectSearchIndicatorsInResponse prepends server_tool_use + web_search_tool_result +// content blocks into a non-streaming Claude JSON response. Claude Code counts +// server_tool_use blocks to display "Did X searches in Ys". +// +// Input response: {"content": [{"type":"text","text":"..."}], ...} +// Output response: {"content": [{"type":"server_tool_use",...}, {"type":"web_search_tool_result",...}, {"type":"text","text":"..."}], ...} +func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) { + if len(searches) == 0 { + return responsePayload, nil + } + + var resp map[string]interface{} + if err := json.Unmarshal(responsePayload, &resp); err != nil { + return responsePayload, fmt.Errorf("failed to parse response: %w", err) + } + + existingContent, _ := resp["content"].([]interface{}) + + // Build new content: search indicators first, then existing content + newContent := make([]interface{}, 0, len(searches)*2+len(existingContent)) + + for _, s := range searches { + // server_tool_use block + newContent = append(newContent, map[string]interface{}{ + "type": "server_tool_use", + "id": s.ToolUseID, + "name": "web_search", + "input": map[string]interface{}{"query": s.Query}, + }) + + // web_search_tool_result block + searchContent := make([]map[string]interface{}, 0) + if s.Results != nil { + for _, r := range s.Results.Results { + snippet := "" + if r.Snippet != nil { + snippet = *r.Snippet + } + searchContent = append(searchContent, map[string]interface{}{ + "type": "web_search_result", + "title": r.Title, + "url": r.URL, + "encrypted_content": snippet, + "page_age": nil, + }) + } + } + newContent = append(newContent, map[string]interface{}{ + "type": "web_search_tool_result", + "tool_use_id": s.ToolUseID, + "content": searchContent, + }) + } + + // Append existing content blocks + newContent = append(newContent, existingContent...) + resp["content"] = newContent + + result, err := json.Marshal(resp) + if err != nil { + return responsePayload, fmt.Errorf("failed to marshal response: %w", err) + } + + log.Infof("kiro/websearch: injected %d search indicator(s) into non-stream response", len(searches)) + return result, nil +} + +// SearchIndicator holds the data for one search operation to inject into a response. +type SearchIndicator struct { + ToolUseID string + Query string + Results *WebSearchResults +} + +// ══════════════════════════════════════════════════════════════════════════════ +// SSE Event Generation +// ══════════════════════════════════════════════════════════════════════════════ + +// SseEvent represents a Server-Sent Event +type SseEvent struct { + Event string + Data interface{} +} + +// ToSSEString converts the event to SSE wire format +func (e *SseEvent) ToSSEString() string { + dataBytes, _ := json.Marshal(e.Data) + return fmt.Sprintf("event: %s\ndata: %s\n\n", e.Event, string(dataBytes)) +} + +// GenerateWebSearchEvents generates the 11-event SSE sequence for web search. +// Events: message_start, content_block_start(server_tool_use), content_block_delta(input_json), +// content_block_stop, content_block_start(web_search_tool_result), content_block_stop, +// content_block_start(text), content_block_delta(text), content_block_stop, message_delta, message_stop +func GenerateWebSearchEvents( + model string, + query string, + toolUseID string, + searchResults *WebSearchResults, + inputTokens int, +) []SseEvent { + events := make([]SseEvent, 0, 15) + messageID := GenerateMessageID() + + // 1. message_start + events = append(events, SseEvent{ + Event: "message_start", + Data: map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": model, + "content": []interface{}{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]interface{}{ + "input_tokens": inputTokens, + "output_tokens": 0, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + }, + }, + }, + }) + + // 2. content_block_start (server_tool_use) + events = append(events, SseEvent{ + Event: "content_block_start", + Data: map[string]interface{}{ + "type": "content_block_start", + "index": 0, + "content_block": map[string]interface{}{ + "id": toolUseID, + "type": "server_tool_use", + "name": "web_search", + "input": map[string]interface{}{}, + }, + }, + }) + + // 3. content_block_delta (input_json_delta) + inputJSON, _ := json.Marshal(map[string]string{"query": query}) + events = append(events, SseEvent{ + Event: "content_block_delta", + Data: map[string]interface{}{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": string(inputJSON), + }, + }, + }) + + // 4. content_block_stop (server_tool_use) + events = append(events, SseEvent{ + Event: "content_block_stop", + Data: map[string]interface{}{ + "type": "content_block_stop", + "index": 0, + }, + }) + + // 5. content_block_start (web_search_tool_result) + searchContent := make([]map[string]interface{}, 0) + if searchResults != nil { + for _, r := range searchResults.Results { + snippet := "" + if r.Snippet != nil { + snippet = *r.Snippet + } + searchContent = append(searchContent, map[string]interface{}{ + "type": "web_search_result", + "title": r.Title, + "url": r.URL, + "encrypted_content": snippet, + "page_age": nil, + }) + } + } + events = append(events, SseEvent{ + Event: "content_block_start", + Data: map[string]interface{}{ + "type": "content_block_start", + "index": 1, + "content_block": map[string]interface{}{ + "type": "web_search_tool_result", + "tool_use_id": toolUseID, + "content": searchContent, + }, + }, + }) + + // 6. content_block_stop (web_search_tool_result) + events = append(events, SseEvent{ + Event: "content_block_stop", + Data: map[string]interface{}{ + "type": "content_block_stop", + "index": 1, + }, + }) + + // 7. content_block_start (text) + events = append(events, SseEvent{ + Event: "content_block_start", + Data: map[string]interface{}{ + "type": "content_block_start", + "index": 2, + "content_block": map[string]interface{}{ + "type": "text", + "text": "", + }, + }, + }) + + // 8. content_block_delta (text_delta) - generate search summary + summary := generateSearchSummary(query, searchResults) + + // Split text into chunks for streaming effect + chunkSize := 100 + runes := []rune(summary) + for i := 0; i < len(runes); i += chunkSize { + end := i + chunkSize + if end > len(runes) { + end = len(runes) + } + chunk := string(runes[i:end]) + events = append(events, SseEvent{ + Event: "content_block_delta", + Data: map[string]interface{}{ + "type": "content_block_delta", + "index": 2, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": chunk, + }, + }, + }) + } + + // 9. content_block_stop (text) + events = append(events, SseEvent{ + Event: "content_block_stop", + Data: map[string]interface{}{ + "type": "content_block_stop", + "index": 2, + }, + }) + + // 10. message_delta + outputTokens := (len(summary) + 3) / 4 // Simple estimation + events = append(events, SseEvent{ + Event: "message_delta", + Data: map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": "end_turn", + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + "output_tokens": outputTokens, + }, + }, + }) + + // 11. message_stop + events = append(events, SseEvent{ + Event: "message_stop", + Data: map[string]interface{}{ + "type": "message_stop", + }, + }) + + return events +} + +// generateSearchSummary generates a text summary of search results +func generateSearchSummary(query string, results *WebSearchResults) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Here are the search results for \"%s\":\n\n", query)) + + if results != nil && len(results.Results) > 0 { + for i, r := range results.Results { + sb.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, r.Title)) + if r.Snippet != nil { + snippet := *r.Snippet + if len(snippet) > 200 { + snippet = snippet[:200] + "..." + } + sb.WriteString(fmt.Sprintf(" %s\n", snippet)) + } + sb.WriteString(fmt.Sprintf(" Source: %s\n\n", r.URL)) + } + } else { + sb.WriteString("No results found.\n") + } + + sb.WriteString("\nPlease note that these are web search results and may not be fully accurate or up-to-date.") + + return sb.String() +} + +// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events +// (server_tool_use + web_search_tool_result) without text summary or message termination. +// These events trigger Claude Code's search indicator UI. +// The caller is responsible for sending message_start before and message_delta/stop after. +func GenerateSearchIndicatorEvents( + query string, + toolUseID string, + searchResults *WebSearchResults, + startIndex int, +) []SseEvent { + events := make([]SseEvent, 0, 4) + + // 1. content_block_start (server_tool_use) + events = append(events, SseEvent{ + Event: "content_block_start", + Data: map[string]interface{}{ + "type": "content_block_start", + "index": startIndex, + "content_block": map[string]interface{}{ + "id": toolUseID, + "type": "server_tool_use", + "name": "web_search", + "input": map[string]interface{}{}, + }, + }, + }) + + // 2. content_block_delta (input_json_delta) + inputJSON, _ := json.Marshal(map[string]string{"query": query}) + events = append(events, SseEvent{ + Event: "content_block_delta", + Data: map[string]interface{}{ + "type": "content_block_delta", + "index": startIndex, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": string(inputJSON), + }, + }, + }) + + // 3. content_block_stop (server_tool_use) + events = append(events, SseEvent{ + Event: "content_block_stop", + Data: map[string]interface{}{ + "type": "content_block_stop", + "index": startIndex, + }, + }) + + // 4. content_block_start (web_search_tool_result) + searchContent := make([]map[string]interface{}, 0) + if searchResults != nil { + for _, r := range searchResults.Results { + snippet := "" + if r.Snippet != nil { + snippet = *r.Snippet + } + searchContent = append(searchContent, map[string]interface{}{ + "type": "web_search_result", + "title": r.Title, + "url": r.URL, + "encrypted_content": snippet, + "page_age": nil, + }) + } + } + events = append(events, SseEvent{ + Event: "content_block_start", + Data: map[string]interface{}{ + "type": "content_block_start", + "index": startIndex + 1, + "content_block": map[string]interface{}{ + "type": "web_search_tool_result", + "tool_use_id": toolUseID, + "content": searchContent, + }, + }, + }) + + // 5. content_block_stop (web_search_tool_result) + events = append(events, SseEvent{ + Event: "content_block_stop", + Data: map[string]interface{}{ + "type": "content_block_stop", + "index": startIndex + 1, + }, + }) + + return events +} + +// ══════════════════════════════════════════════════════════════════════════════ +// Stream Analysis & Manipulation +// ══════════════════════════════════════════════════════════════════════════════ + +// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset. +// It also suppresses duplicate message_start events (returns shouldForward=false). +// This is used to combine search indicator events (indices 0,1) with Kiro model response events. +// +// The data parameter is a single SSE "data:" line payload (JSON). +// Returns: adjusted data, shouldForward (false = skip this event). +func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) { + if len(data) == 0 { + return data, true + } + + // Quick check: parse the JSON + var event map[string]interface{} + if err := json.Unmarshal(data, &event); err != nil { + // Not valid JSON, pass through + return data, true + } + + eventType, _ := event["type"].(string) + + // Suppress duplicate message_start events + if eventType == "message_start" { + return data, false + } + + // Adjust index for content_block events + switch eventType { + case "content_block_start", "content_block_delta", "content_block_stop": + if idx, ok := event["index"].(float64); ok { + event["index"] = int(idx) + offset + adjusted, err := json.Marshal(event) + if err != nil { + return data, true + } + return adjusted, true + } + } + + // Pass through all other events unchanged (message_delta, message_stop, ping, etc.) + return data, true +} + +// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs) +// and adjusts content block indices. Suppresses duplicate message_start events. +// Returns the adjusted chunk and whether it should be forwarded. +func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) { + chunkStr := string(chunk) + + // Fast path: if no "data:" prefix, pass through + if !strings.Contains(chunkStr, "data: ") { + return chunk, true + } + + var result strings.Builder + hasContent := false + + lines := strings.Split(chunkStr, "\n") + for i := 0; i < len(lines); i++ { + line := lines[i] + + if strings.HasPrefix(line, "data: ") { + dataPayload := strings.TrimPrefix(line, "data: ") + dataPayload = strings.TrimSpace(dataPayload) + + if dataPayload == "[DONE]" { + result.WriteString(line + "\n") + hasContent = true + continue + } + + adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset) + if !shouldForward { + // Skip this event and its preceding "event:" line + // Also skip the trailing empty line + continue + } + + result.WriteString("data: " + string(adjusted) + "\n") + hasContent = true + } else if strings.HasPrefix(line, "event: ") { + // Check if the next data line will be suppressed + if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") { + dataPayload := strings.TrimPrefix(lines[i+1], "data: ") + dataPayload = strings.TrimSpace(dataPayload) + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataPayload), &event); err == nil { + if eventType, ok := event["type"].(string); ok && eventType == "message_start" { + // Skip both the event: and data: lines + i++ // skip the data: line too + continue + } + } + } + result.WriteString(line + "\n") + hasContent = true + } else { + result.WriteString(line + "\n") + if strings.TrimSpace(line) != "" { + hasContent = true + } + } + } + + if !hasContent { + return nil, false + } + + return []byte(result.String()), true +} + +// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response. +type BufferedStreamResult struct { + // StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use") + StopReason string + // WebSearchQuery is the extracted query if the model requested another web_search + WebSearchQuery string + // WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults) + WebSearchToolUseId string + // HasWebSearchToolUse indicates whether the model requested web_search + HasWebSearchToolUse bool + // WebSearchToolUseIndex is the content_block index of the web_search tool_use + WebSearchToolUseIndex int +} + +// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use. +// This is used in the search loop to determine if the model wants another search round. +func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult { + result := BufferedStreamResult{WebSearchToolUseIndex: -1} + + // Track tool use state across chunks + var currentToolName string + var currentToolIndex int = -1 + var toolInputBuilder strings.Builder + + for _, chunk := range chunks { + chunkStr := string(chunk) + lines := strings.Split(chunkStr, "\n") + for _, line := range lines { + if !strings.HasPrefix(line, "data: ") { + continue + } + dataPayload := strings.TrimPrefix(line, "data: ") + dataPayload = strings.TrimSpace(dataPayload) + if dataPayload == "[DONE]" || dataPayload == "" { + continue + } + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataPayload), &event); err != nil { + continue + } + + eventType, _ := event["type"].(string) + + switch eventType { + case "message_delta": + // Extract stop_reason from message_delta + if delta, ok := event["delta"].(map[string]interface{}); ok { + if sr, ok := delta["stop_reason"].(string); ok && sr != "" { + result.StopReason = sr + } + } + + case "content_block_start": + // Detect tool_use content blocks + if cb, ok := event["content_block"].(map[string]interface{}); ok { + if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" { + if name, ok := cb["name"].(string); ok { + currentToolName = strings.ToLower(name) + if idx, ok := event["index"].(float64); ok { + currentToolIndex = int(idx) + } + // Capture tool use ID for toolResults handshake + if id, ok := cb["id"].(string); ok { + result.WebSearchToolUseId = id + } + toolInputBuilder.Reset() + } + } + } + + case "content_block_delta": + // Accumulate tool input JSON + if currentToolName != "" { + if delta, ok := event["delta"].(map[string]interface{}); ok { + if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" { + if partial, ok := delta["partial_json"].(string); ok { + toolInputBuilder.WriteString(partial) + } + } + } + } + + case "content_block_stop": + // Finalize tool use detection + if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" { + result.HasWebSearchToolUse = true + result.WebSearchToolUseIndex = currentToolIndex + // Extract query from accumulated input JSON + inputJSON := toolInputBuilder.String() + var input map[string]string + if err := json.Unmarshal([]byte(inputJSON), &input); err == nil { + if q, ok := input["query"]; ok { + result.WebSearchQuery = q + } + } + log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery) + } + currentToolName = "" + currentToolIndex = -1 + toolInputBuilder.Reset() + } + } + } + + return result +} + +// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use +// content blocks. This prevents the client from seeing "Tool use" prompts for web_search +// when the proxy is handling the search loop internally. +// Also suppresses message_start and message_delta/message_stop events since those +// are managed by the outer handleWebSearchStream. +func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte { + var filtered [][]byte + + for _, chunk := range chunks { + chunkStr := string(chunk) + lines := strings.Split(chunkStr, "\n") + + var resultBuilder strings.Builder + hasContent := false + + for i := 0; i < len(lines); i++ { + line := lines[i] + + if strings.HasPrefix(line, "data: ") { + dataPayload := strings.TrimPrefix(line, "data: ") + dataPayload = strings.TrimSpace(dataPayload) + + if dataPayload == "[DONE]" { + // Skip [DONE] — the outer loop manages stream termination + continue + } + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataPayload), &event); err != nil { + resultBuilder.WriteString(line + "\n") + hasContent = true + continue + } + + eventType, _ := event["type"].(string) + + // Skip message_start (outer loop sends its own) + if eventType == "message_start" { + continue + } + + // Skip message_delta and message_stop (outer loop manages these) + if eventType == "message_delta" || eventType == "message_stop" { + continue + } + + // Check if this event belongs to the web_search tool_use block + if wsToolIndex >= 0 { + if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex { + // Skip events for the web_search tool_use block + continue + } + } + + // Apply index offset for remaining events + if indexOffset > 0 { + switch eventType { + case "content_block_start", "content_block_delta", "content_block_stop": + if idx, ok := event["index"].(float64); ok { + event["index"] = int(idx) + indexOffset + adjusted, err := json.Marshal(event) + if err == nil { + resultBuilder.WriteString("data: " + string(adjusted) + "\n") + hasContent = true + continue + } + } + } + } + + resultBuilder.WriteString(line + "\n") + hasContent = true + } else if strings.HasPrefix(line, "event: ") { + // Check if the next data line will be suppressed + if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") { + nextData := strings.TrimPrefix(lines[i+1], "data: ") + nextData = strings.TrimSpace(nextData) + + var nextEvent map[string]interface{} + if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil { + nextType, _ := nextEvent["type"].(string) + if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" { + i++ // skip the data line + continue + } + if wsToolIndex >= 0 { + if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex { + i++ // skip the data line + continue + } + } + } + } + resultBuilder.WriteString(line + "\n") + hasContent = true + } else { + resultBuilder.WriteString(line + "\n") + if strings.TrimSpace(line) != "" { + hasContent = true + } + } + } + + if hasContent { + filtered = append(filtered, []byte(resultBuilder.String())) + } + } + + return filtered +} diff --git a/internal/translator/kiro/claude/kiro_websearch_handler.go b/internal/translator/kiro/claude/kiro_websearch_handler.go new file mode 100644 index 00000000..c64d8eb9 --- /dev/null +++ b/internal/translator/kiro/claude/kiro_websearch_handler.go @@ -0,0 +1,270 @@ +// Package claude provides web search handler for Kiro translator. +// This file implements the MCP API call and response handling. +package claude + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +// Cached web_search tool description fetched from MCP tools/list. +// Uses atomic.Pointer[sync.Once] for lock-free reads with retry-on-failure: +// - sync.Once prevents race conditions and deduplicates concurrent calls +// - On failure, a fresh sync.Once is swapped in to allow retry on next call +// - On success, sync.Once stays "done" forever — zero overhead for subsequent calls +var ( + cachedToolDescription atomic.Value // stores string + toolDescOnce atomic.Pointer[sync.Once] + fallbackFpOnce sync.Once + fallbackFp *kiroauth.Fingerprint +) + +func init() { + toolDescOnce.Store(&sync.Once{}) +} + +// FetchToolDescription calls MCP tools/list to get the web_search tool description +// and caches it. Safe to call concurrently — only one goroutine fetches at a time. +// If the fetch fails, subsequent calls will retry. On success, no further fetches occur. +// The httpClient parameter allows reusing a shared pooled HTTP client. +func FetchToolDescription(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) { + toolDescOnce.Load().Do(func() { + handler := NewWebSearchHandler(mcpEndpoint, authToken, httpClient, fp, authAttrs) + reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`) + log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody)) + + req, err := http.NewRequest("POST", mcpEndpoint, bytes.NewReader(reqBody)) + if err != nil { + log.Warnf("kiro/websearch: failed to create tools/list request: %v", err) + toolDescOnce.Store(&sync.Once{}) // allow retry + return + } + + // Reuse same headers as CallMcpAPI + handler.setMcpHeaders(req) + + resp, err := handler.HTTPClient.Do(req) + if err != nil { + log.Warnf("kiro/websearch: tools/list request failed: %v", err) + toolDescOnce.Store(&sync.Once{}) // allow retry + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil || resp.StatusCode != http.StatusOK { + log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode) + toolDescOnce.Store(&sync.Once{}) // allow retry + return + } + log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body)) + + // Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}} + var result struct { + Result *struct { + Tools []struct { + Name string `json:"name"` + Description string `json:"description"` + } `json:"tools"` + } `json:"result"` + } + if err := json.Unmarshal(body, &result); err != nil || result.Result == nil { + log.Warnf("kiro/websearch: failed to parse tools/list response") + toolDescOnce.Store(&sync.Once{}) // allow retry + return + } + + for _, tool := range result.Result.Tools { + if tool.Name == "web_search" && tool.Description != "" { + cachedToolDescription.Store(tool.Description) + log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description)) + return // success — sync.Once stays "done", no more fetches + } + } + + // web_search tool not found in response + toolDescOnce.Store(&sync.Once{}) // allow retry + }) +} + +// GetWebSearchDescription returns the cached web_search tool description, +// or empty string if not yet fetched. Lock-free via atomic.Value. +func GetWebSearchDescription() string { + if v := cachedToolDescription.Load(); v != nil { + return v.(string) + } + return "" +} + +// WebSearchHandler handles web search requests via Kiro MCP API +type WebSearchHandler struct { + McpEndpoint string + HTTPClient *http.Client + AuthToken string + Fingerprint *kiroauth.Fingerprint // optional, for dynamic headers + AuthAttrs map[string]string // optional, for custom headers from auth.Attributes +} + +// NewWebSearchHandler creates a new WebSearchHandler. +// If httpClient is nil, a default client with 30s timeout is used. +// If fingerprint is nil, a random one-off fingerprint is generated. +// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse. +func NewWebSearchHandler(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) *WebSearchHandler { + if httpClient == nil { + httpClient = &http.Client{ + Timeout: 30 * time.Second, + } + } + if fp == nil { + // Use a shared fallback fingerprint for callers without token context + fallbackFpOnce.Do(func() { + mgr := kiroauth.NewFingerprintManager() + fallbackFp = mgr.GetFingerprint("mcp-fallback") + }) + fp = fallbackFp + } + return &WebSearchHandler{ + McpEndpoint: mcpEndpoint, + HTTPClient: httpClient, + AuthToken: authToken, + Fingerprint: fp, + AuthAttrs: authAttrs, + } +} + +// setMcpHeaders sets standard MCP API headers on the request, +// aligned with the GAR request pattern in kiro_executor.go. +func (h *WebSearchHandler) setMcpHeaders(req *http.Request) { + fp := h.Fingerprint + + // 1. Content-Type & Accept (aligned with GAR) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + + // 2. Kiro-specific headers (aligned with GAR) + req.Header.Set("x-amzn-kiro-agent-mode", "vibe") + req.Header.Set("x-amzn-codewhisperer-optout", "true") + + // 3. Dynamic fingerprint headers + req.Header.Set("User-Agent", fp.BuildUserAgent()) + req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent()) + + // 4. AWS SDK identifiers (casing aligned with GAR) + req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + + // 5. Authentication + req.Header.Set("Authorization", "Bearer "+h.AuthToken) + + // 6. Custom headers from auth attributes + util.ApplyCustomHeadersFromAttrs(req, h.AuthAttrs) +} + +// mcpMaxRetries is the maximum number of retries for MCP API calls. +const mcpMaxRetries = 2 + +// CallMcpAPI calls the Kiro MCP API with the given request. +// Includes retry logic with exponential backoff for retryable errors, +// aligned with the GAR request retry pattern. +func (h *WebSearchHandler) CallMcpAPI(request *McpRequest) (*McpResponse, error) { + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal MCP request: %w", err) + } + log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.McpEndpoint, len(requestBody)) + + var lastErr error + for attempt := 0; attempt <= mcpMaxRetries; attempt++ { + if attempt > 0 { + backoff := time.Duration(1< 10*time.Second { + backoff = 10 * time.Second + } + log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr) + time.Sleep(backoff) + } + + req, err := http.NewRequest("POST", h.McpEndpoint, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + h.setMcpHeaders(req) + + resp, err := h.HTTPClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("MCP API request failed: %w", err) + continue // network error → retry + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + lastErr = fmt.Errorf("failed to read MCP response: %w", err) + continue // read error → retry + } + log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body)) + + // Retryable HTTP status codes (aligned with GAR: 502, 503, 504) + if resp.StatusCode >= 502 && resp.StatusCode <= 504 { + lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body)) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body)) + } + + var mcpResponse McpResponse + if err := json.Unmarshal(body, &mcpResponse); err != nil { + return nil, fmt.Errorf("failed to parse MCP response: %w", err) + } + + if mcpResponse.Error != nil { + code := -1 + if mcpResponse.Error.Code != nil { + code = *mcpResponse.Error.Code + } + msg := "Unknown error" + if mcpResponse.Error.Message != nil { + msg = *mcpResponse.Error.Message + } + return nil, fmt.Errorf("MCP error %d: %s", code, msg) + } + + return &mcpResponse, nil + } + + return nil, lastErr +} + +// ParseSearchResults extracts WebSearchResults from MCP response +func ParseSearchResults(response *McpResponse) *WebSearchResults { + if response == nil || response.Result == nil || len(response.Result.Content) == 0 { + return nil + } + + content := response.Result.Content[0] + if content.ContentType != "text" { + return nil + } + + var results WebSearchResults + if err := json.Unmarshal([]byte(content.Text), &results); err != nil { + log.Warnf("kiro/websearch: failed to parse search results: %v", err) + return nil + } + + return &results +} From 09b19f5c4ee58e552bb946c65ef6b38f40e01170 Mon Sep 17 00:00:00 2001 From: Skyuno Date: Wed, 11 Feb 2026 00:23:05 +0800 Subject: [PATCH 08/15] fix(kiro): filter orphaned tool_results from compacted conversations --- .../kiro/claude/kiro_claude_request.go | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/internal/translator/kiro/claude/kiro_claude_request.go b/internal/translator/kiro/claude/kiro_claude_request.go index 425d9ae2..c1c93f69 100644 --- a/internal/translator/kiro/claude/kiro_claude_request.go +++ b/internal/translator/kiro/claude/kiro_claude_request.go @@ -654,6 +654,57 @@ func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHisto } } + // POST-PROCESSING: Remove orphaned tool_results that have no matching tool_use + // in any assistant message. This happens when Claude Code compaction truncates + // the conversation and removes the assistant message containing the tool_use, + // but keeps the user message with the corresponding tool_result. + // Without this fix, Kiro API returns "Improperly formed request". + validToolUseIDs := make(map[string]bool) + for _, h := range history { + if h.AssistantResponseMessage != nil { + for _, tu := range h.AssistantResponseMessage.ToolUses { + validToolUseIDs[tu.ToolUseID] = true + } + } + } + + // Filter orphaned tool results from history user messages + for i, h := range history { + if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil { + ctx := h.UserInputMessage.UserInputMessageContext + if len(ctx.ToolResults) > 0 { + filtered := make([]KiroToolResult, 0, len(ctx.ToolResults)) + for _, tr := range ctx.ToolResults { + if validToolUseIDs[tr.ToolUseID] { + filtered = append(filtered, tr) + } else { + log.Debugf("kiro: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID) + } + } + ctx.ToolResults = filtered + if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 { + h.UserInputMessage.UserInputMessageContext = nil + } + } + } + } + + // Filter orphaned tool results from current message + if len(currentToolResults) > 0 { + filtered := make([]KiroToolResult, 0, len(currentToolResults)) + for _, tr := range currentToolResults { + if validToolUseIDs[tr.ToolUseID] { + filtered = append(filtered, tr) + } else { + log.Debugf("kiro: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID) + } + } + if len(filtered) != len(currentToolResults) { + log.Infof("kiro: dropped %d orphaned tool_result(s) from currentMessage (compaction artifact)", len(currentToolResults)-len(filtered)) + } + currentToolResults = filtered + } + return history, currentUserMsg, currentToolResults } From bcd2208b513d4ee115f9e96556e78c4c60d524c2 Mon Sep 17 00:00:00 2001 From: Anilcan Cakir Date: Tue, 10 Feb 2026 23:34:19 +0300 Subject: [PATCH 09/15] fix(auth): strip model suffix in GitHub Copilot executor before upstream call GitHub Copilot API rejects model names with suffixes (e.g. claude-opus-4.6(medium)). The OAuthModelAlias resolution correctly maps aliases like 'opus(medium)' to 'claude-opus-4.6(medium)' preserving the suffix, but the executor must strip the suffix before sending to the upstream API since Copilot only accepts bare model names. Update normalizeModel in github_copilot_executor to strip suffixes using thinking.ParseSuffix, matching the pattern used by other executors. Also add test coverage for: - OAuthModelAliasChannel github-copilot and kiro channel resolution - Suffix preservation in alias resolution for github-copilot - normalizeModel suffix stripping in github_copilot_executor --- .../executor/github_copilot_executor.go | 12 +++-- .../executor/github_copilot_executor_test.go | 54 +++++++++++++++++++ sdk/cliproxy/auth/oauth_model_alias_test.go | 36 +++++++++++++ 3 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 internal/runtime/executor/github_copilot_executor_test.go diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go index b43e1909..3681faf8 100644 --- a/internal/runtime/executor/github_copilot_executor.go +++ b/internal/runtime/executor/github_copilot_executor.go @@ -14,6 +14,7 @@ import ( "github.com/google/uuid" copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -471,9 +472,14 @@ func detectVisionContent(body []byte) bool { return false } -// normalizeModel is a no-op as GitHub Copilot accepts model names directly. -// Model mapping should be done at the registry level if needed. -func (e *GitHubCopilotExecutor) normalizeModel(_ string, body []byte) []byte { +// normalizeModel strips the suffix (e.g. "(medium)") from the model name +// before sending to GitHub Copilot, as the upstream API does not accept +// suffixed model identifiers. +func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte { + baseModel := thinking.ParseSuffix(model).ModelName + if baseModel != model { + body, _ = sjson.SetBytes(body, "model", baseModel) + } return body } diff --git a/internal/runtime/executor/github_copilot_executor_test.go b/internal/runtime/executor/github_copilot_executor_test.go new file mode 100644 index 00000000..ef077fd6 --- /dev/null +++ b/internal/runtime/executor/github_copilot_executor_test.go @@ -0,0 +1,54 @@ +package executor + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + model string + wantModel string + }{ + { + name: "suffix stripped", + model: "claude-opus-4.6(medium)", + wantModel: "claude-opus-4.6", + }, + { + name: "no suffix unchanged", + model: "claude-opus-4.6", + wantModel: "claude-opus-4.6", + }, + { + name: "different suffix stripped", + model: "gpt-4o(high)", + wantModel: "gpt-4o", + }, + { + name: "numeric suffix stripped", + model: "gemini-2.5-pro(8192)", + wantModel: "gemini-2.5-pro", + }, + } + + e := &GitHubCopilotExecutor{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + body := []byte(`{"model":"` + tt.model + `","messages":[]}`) + got := e.normalizeModel(tt.model, body) + + gotModel := gjson.GetBytes(got, "model").String() + if gotModel != tt.wantModel { + t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel) + } + }) + } +} diff --git a/sdk/cliproxy/auth/oauth_model_alias_test.go b/sdk/cliproxy/auth/oauth_model_alias_test.go index 2ff4000f..e12b6597 100644 --- a/sdk/cliproxy/auth/oauth_model_alias_test.go +++ b/sdk/cliproxy/auth/oauth_model_alias_test.go @@ -79,6 +79,24 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) { input: "gemini-2.5-pro(none)", want: "gemini-2.5-pro-exp-03-25(none)", }, + { + name: "github-copilot suffix preserved", + aliases: map[string][]internalconfig.OAuthModelAlias{ + "github-copilot": {{Name: "claude-opus-4.6", Alias: "opus"}}, + }, + channel: "github-copilot", + input: "opus(medium)", + want: "claude-opus-4.6(medium)", + }, + { + name: "github-copilot no suffix", + aliases: map[string][]internalconfig.OAuthModelAlias{ + "github-copilot": {{Name: "claude-opus-4.6", Alias: "opus"}}, + }, + channel: "github-copilot", + input: "opus", + want: "claude-opus-4.6", + }, { name: "kimi suffix preserved", aliases: map[string][]internalconfig.OAuthModelAlias{ @@ -174,6 +192,8 @@ func createAuthForChannel(channel string) *Auth { return &Auth{Provider: "kimi"} case "kiro": return &Auth{Provider: "kiro"} + case "github-copilot": + return &Auth{Provider: "github-copilot"} default: return &Auth{Provider: channel} } @@ -187,6 +207,22 @@ func TestOAuthModelAliasChannel_Kimi(t *testing.T) { } } +func TestOAuthModelAliasChannel_GitHubCopilot(t *testing.T) { + t.Parallel() + + if got := OAuthModelAliasChannel("github-copilot", ""); got != "github-copilot" { + t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "github-copilot") + } +} + +func TestOAuthModelAliasChannel_Kiro(t *testing.T) { + t.Parallel() + + if got := OAuthModelAliasChannel("kiro", ""); got != "kiro" { + t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "kiro") + } +} + func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) { t.Parallel() From 5ed2133ff9a96f5e51796ed2df6867a494a01bea Mon Sep 17 00:00:00 2001 From: RGBadmin Date: Wed, 11 Feb 2026 15:21:12 +0800 Subject: [PATCH 10/15] feat: add per-account excluded_models and priority parsing --- internal/watcher/synthesizer/file.go | 61 +++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index c80ebc66..20b2faec 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strconv" "strings" "time" @@ -92,6 +93,9 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e status = coreauth.StatusDisabled } + // Read per-account excluded models from the OAuth JSON file + perAccountExcluded := extractExcludedModelsFromMetadata(metadata) + a := &coreauth.Auth{ ID: id, Provider: provider, @@ -108,11 +112,22 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e CreatedAt: now, UpdatedAt: now, } - ApplyAuthExcludedModelsMeta(a, cfg, nil, "oauth") + // Read priority from auth file + if rawPriority, ok := metadata["priority"]; ok { + switch v := rawPriority.(type) { + case float64: + a.Attributes["priority"] = strconv.Itoa(int(v)) + case string: + if _, err := strconv.Atoi(v); err == nil { + a.Attributes["priority"] = v + } + } + } + ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth") if provider == "gemini-cli" { if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { for _, v := range virtuals { - ApplyAuthExcludedModelsMeta(v, cfg, nil, "oauth") + ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth") } out = append(out, a) out = append(out, virtuals...) @@ -167,6 +182,10 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an if authPath != "" { attrs["path"] = authPath } + // Propagate priority from primary auth to virtual auths + if priorityVal, hasPriority := primary.Attributes["priority"]; hasPriority && priorityVal != "" { + attrs["priority"] = priorityVal + } metadataCopy := map[string]any{ "email": email, "project_id": projectID, @@ -239,3 +258,41 @@ func buildGeminiVirtualID(baseID, projectID string) string { replacer := strings.NewReplacer("/", "_", "\\", "_", " ", "_") return fmt.Sprintf("%s::%s", baseID, replacer.Replace(project)) } + +// extractExcludedModelsFromMetadata reads per-account excluded models from the OAuth JSON metadata. +// Supports both "excluded_models" and "excluded-models" keys, and accepts both []string and []interface{}. +func extractExcludedModelsFromMetadata(metadata map[string]any) []string { + if metadata == nil { + return nil + } + // Try both key formats + raw, ok := metadata["excluded_models"] + if !ok { + raw, ok = metadata["excluded-models"] + } + if !ok || raw == nil { + return nil + } + switch v := raw.(type) { + case []string: + result := make([]string, 0, len(v)) + for _, s := range v { + if trimmed := strings.TrimSpace(s); trimmed != "" { + result = append(result, trimmed) + } + } + return result + case []interface{}: + result := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + if trimmed := strings.TrimSpace(s); trimmed != "" { + result = append(result, trimmed) + } + } + } + return result + default: + return nil + } +} From b93026d83a8da573f4871c8a483287d2ea8c02d6 Mon Sep 17 00:00:00 2001 From: RGBadmin Date: Wed, 11 Feb 2026 15:21:15 +0800 Subject: [PATCH 11/15] feat: merge per-account excluded_models with global config --- internal/watcher/synthesizer/helpers.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/internal/watcher/synthesizer/helpers.go b/internal/watcher/synthesizer/helpers.go index 621f3600..102dc77e 100644 --- a/internal/watcher/synthesizer/helpers.go +++ b/internal/watcher/synthesizer/helpers.go @@ -53,6 +53,8 @@ func (g *StableIDGenerator) Next(kind string, parts ...string) (string, string) // ApplyAuthExcludedModelsMeta applies excluded models metadata to an auth entry. // It computes a hash of excluded models and sets the auth_kind attribute. +// For OAuth entries, perKey (from the JSON file's excluded-models field) is merged +// with the global oauth-excluded-models config for the provider. func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) { if auth == nil || cfg == nil { return @@ -72,9 +74,13 @@ func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey } if authKindKey == "apikey" { add(perKey) - } else if cfg.OAuthExcludedModels != nil { - providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) - add(cfg.OAuthExcludedModels[providerKey]) + } else { + // For OAuth: merge per-account excluded models with global provider-level exclusions + add(perKey) + if cfg.OAuthExcludedModels != nil { + providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + add(cfg.OAuthExcludedModels[providerKey]) + } } combined := make([]string, 0, len(seen)) for k := range seen { @@ -88,6 +94,10 @@ func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey if hash != "" { auth.Attributes["excluded_models_hash"] = hash } + // Store the combined excluded models list so that routing can read it at runtime + if len(combined) > 0 { + auth.Attributes["excluded_models"] = strings.Join(combined, ",") + } if authKind != "" { auth.Attributes["auth_kind"] = authKind } From 4cbcc835d1e7fc616a23a6d516e9cc68b1282d40 Mon Sep 17 00:00:00 2001 From: RGBadmin Date: Wed, 11 Feb 2026 15:21:19 +0800 Subject: [PATCH 12/15] feat: read per-account excluded_models at routing time --- sdk/cliproxy/service.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 0ae05c08..b77de8c6 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -740,6 +740,26 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { provider = "openai-compatibility" } excluded := s.oauthExcludedModels(provider, authKind) + // Merge per-account excluded models from auth attributes (set by synthesizer) + if a.Attributes != nil { + if perAccount := strings.TrimSpace(a.Attributes["excluded_models"]); perAccount != "" { + parts := strings.Split(perAccount, ",") + seen := make(map[string]struct{}, len(excluded)+len(parts)) + for _, e := range excluded { + seen[strings.ToLower(strings.TrimSpace(e))] = struct{}{} + } + for _, p := range parts { + seen[strings.ToLower(strings.TrimSpace(p))] = struct{}{} + } + merged := make([]string, 0, len(seen)) + for k := range seen { + if k != "" { + merged = append(merged, k) + } + } + excluded = merged + } + } var models []*ModelInfo switch provider { case "gemini": From bf1634bda0fe3388a50e00ac227ad653639ec7e5 Mon Sep 17 00:00:00 2001 From: RGBadmin Date: Wed, 11 Feb 2026 15:57:15 +0800 Subject: [PATCH 13/15] refactor: simplify per-account excluded_models merge in routing --- sdk/cliproxy/service.go | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index b77de8c6..536329b5 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -740,24 +740,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { provider = "openai-compatibility" } excluded := s.oauthExcludedModels(provider, authKind) - // Merge per-account excluded models from auth attributes (set by synthesizer) + // The synthesizer pre-merges per-account and global exclusions into the "excluded_models" attribute. + // If this attribute is present, it represents the complete list of exclusions and overrides the global config. if a.Attributes != nil { - if perAccount := strings.TrimSpace(a.Attributes["excluded_models"]); perAccount != "" { - parts := strings.Split(perAccount, ",") - seen := make(map[string]struct{}, len(excluded)+len(parts)) - for _, e := range excluded { - seen[strings.ToLower(strings.TrimSpace(e))] = struct{}{} - } - for _, p := range parts { - seen[strings.ToLower(strings.TrimSpace(p))] = struct{}{} - } - merged := make([]string, 0, len(seen)) - for k := range seen { - if k != "" { - merged = append(merged, k) - } - } - excluded = merged + if val, ok := a.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" { + excluded = strings.Split(val, ",") } } var models []*ModelInfo From dc279de443f60594c01efae29011ea59503f6aef Mon Sep 17 00:00:00 2001 From: RGBadmin Date: Wed, 11 Feb 2026 15:57:16 +0800 Subject: [PATCH 14/15] refactor: reduce code duplication in extractExcludedModelsFromMetadata --- internal/watcher/synthesizer/file.go | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index 20b2faec..8f4ec6da 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -273,26 +273,25 @@ func extractExcludedModelsFromMetadata(metadata map[string]any) []string { if !ok || raw == nil { return nil } + var stringSlice []string switch v := raw.(type) { case []string: - result := make([]string, 0, len(v)) - for _, s := range v { - if trimmed := strings.TrimSpace(s); trimmed != "" { - result = append(result, trimmed) - } - } - return result + stringSlice = v case []interface{}: - result := make([]string, 0, len(v)) + stringSlice = make([]string, 0, len(v)) for _, item := range v { if s, ok := item.(string); ok { - if trimmed := strings.TrimSpace(s); trimmed != "" { - result = append(result, trimmed) - } + stringSlice = append(stringSlice, s) } } - return result default: return nil } + result := make([]string, 0, len(stringSlice)) + for _, s := range stringSlice { + if trimmed := strings.TrimSpace(s); trimmed != "" { + result = append(result, trimmed) + } + } + return result } From 4c133d3ea9dc77b740b5b454d7bc582a1045b37b Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 11 Feb 2026 20:35:13 +0800 Subject: [PATCH 15/15] test(sdk/watcher): add tests for excluded models merging and priority parsing logic - Added unit tests for combining OAuth excluded models across global and attribute-specific scopes. - Implemented priority attribute parsing with support for different formats and trimming. --- internal/watcher/synthesizer/file.go | 5 +- internal/watcher/synthesizer/file_test.go | 118 +++++++++++++++++++ internal/watcher/synthesizer/helpers_test.go | 25 ++++ sdk/cliproxy/service_excluded_models_test.go | 65 ++++++++++ 4 files changed, 211 insertions(+), 2 deletions(-) create mode 100644 sdk/cliproxy/service_excluded_models_test.go diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index 8f4ec6da..4e053117 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -118,8 +118,9 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e case float64: a.Attributes["priority"] = strconv.Itoa(int(v)) case string: - if _, err := strconv.Atoi(v); err == nil { - a.Attributes["priority"] = v + priority := strings.TrimSpace(v) + if _, errAtoi := strconv.Atoi(priority); errAtoi == nil { + a.Attributes["priority"] = priority } } } diff --git a/internal/watcher/synthesizer/file_test.go b/internal/watcher/synthesizer/file_test.go index 93025fba..105d9207 100644 --- a/internal/watcher/synthesizer/file_test.go +++ b/internal/watcher/synthesizer/file_test.go @@ -297,6 +297,117 @@ func TestFileSynthesizer_Synthesize_PrefixValidation(t *testing.T) { } } +func TestFileSynthesizer_Synthesize_PriorityParsing(t *testing.T) { + tests := []struct { + name string + priority any + want string + hasValue bool + }{ + { + name: "string with spaces", + priority: " 10 ", + want: "10", + hasValue: true, + }, + { + name: "number", + priority: 8, + want: "8", + hasValue: true, + }, + { + name: "invalid string", + priority: "1x", + hasValue: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + authData := map[string]any{ + "type": "claude", + "priority": tt.priority, + } + data, _ := json.Marshal(authData) + errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644) + if errWriteFile != nil { + t.Fatalf("failed to write auth file: %v", errWriteFile) + } + + synth := NewFileSynthesizer() + ctx := &SynthesisContext{ + Config: &config.Config{}, + AuthDir: tempDir, + Now: time.Now(), + IDGenerator: NewStableIDGenerator(), + } + + auths, errSynthesize := synth.Synthesize(ctx) + if errSynthesize != nil { + t.Fatalf("unexpected error: %v", errSynthesize) + } + if len(auths) != 1 { + t.Fatalf("expected 1 auth, got %d", len(auths)) + } + + value, ok := auths[0].Attributes["priority"] + if tt.hasValue { + if !ok { + t.Fatal("expected priority attribute to be set") + } + if value != tt.want { + t.Fatalf("expected priority %q, got %q", tt.want, value) + } + return + } + if ok { + t.Fatalf("expected priority attribute to be absent, got %q", value) + } + }) + } +} + +func TestFileSynthesizer_Synthesize_OAuthExcludedModelsMerged(t *testing.T) { + tempDir := t.TempDir() + authData := map[string]any{ + "type": "claude", + "excluded_models": []string{"custom-model", "MODEL-B"}, + } + data, _ := json.Marshal(authData) + errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644) + if errWriteFile != nil { + t.Fatalf("failed to write auth file: %v", errWriteFile) + } + + synth := NewFileSynthesizer() + ctx := &SynthesisContext{ + Config: &config.Config{ + OAuthExcludedModels: map[string][]string{ + "claude": {"shared", "model-b"}, + }, + }, + AuthDir: tempDir, + Now: time.Now(), + IDGenerator: NewStableIDGenerator(), + } + + auths, errSynthesize := synth.Synthesize(ctx) + if errSynthesize != nil { + t.Fatalf("unexpected error: %v", errSynthesize) + } + if len(auths) != 1 { + t.Fatalf("expected 1 auth, got %d", len(auths)) + } + + got := auths[0].Attributes["excluded_models"] + want := "custom-model,model-b,shared" + if got != want { + t.Fatalf("expected excluded_models %q, got %q", want, got) + } +} + func TestSynthesizeGeminiVirtualAuths_NilInputs(t *testing.T) { now := time.Now() @@ -533,6 +644,7 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) { "type": "gemini", "email": "multi@example.com", "project_id": "project-a, project-b, project-c", + "priority": " 10 ", } data, _ := json.Marshal(authData) err := os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0644) @@ -565,6 +677,9 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) { if primary.Status != coreauth.StatusDisabled { t.Errorf("expected primary status disabled, got %s", primary.Status) } + if gotPriority := primary.Attributes["priority"]; gotPriority != "10" { + t.Errorf("expected primary priority 10, got %q", gotPriority) + } // Remaining auths should be virtuals for i := 1; i < 4; i++ { @@ -575,6 +690,9 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) { if v.Attributes["gemini_virtual_parent"] != primary.ID { t.Errorf("expected virtual %d parent to be %s, got %s", i, primary.ID, v.Attributes["gemini_virtual_parent"]) } + if gotPriority := v.Attributes["priority"]; gotPriority != "10" { + t.Errorf("expected virtual %d priority 10, got %q", i, gotPriority) + } } } diff --git a/internal/watcher/synthesizer/helpers_test.go b/internal/watcher/synthesizer/helpers_test.go index 229c75bc..46b9c8a0 100644 --- a/internal/watcher/synthesizer/helpers_test.go +++ b/internal/watcher/synthesizer/helpers_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) @@ -200,6 +201,30 @@ func TestApplyAuthExcludedModelsMeta(t *testing.T) { } } +func TestApplyAuthExcludedModelsMeta_OAuthMergeWritesCombinedModels(t *testing.T) { + auth := &coreauth.Auth{ + Provider: "claude", + Attributes: make(map[string]string), + } + cfg := &config.Config{ + OAuthExcludedModels: map[string][]string{ + "claude": {"global-a", "shared"}, + }, + } + + ApplyAuthExcludedModelsMeta(auth, cfg, []string{"per", "SHARED"}, "oauth") + + const wantCombined = "global-a,per,shared" + if gotCombined := auth.Attributes["excluded_models"]; gotCombined != wantCombined { + t.Fatalf("expected excluded_models=%q, got %q", wantCombined, gotCombined) + } + + expectedHash := diff.ComputeExcludedModelsHash([]string{"global-a", "per", "shared"}) + if gotHash := auth.Attributes["excluded_models_hash"]; gotHash != expectedHash { + t.Fatalf("expected excluded_models_hash=%q, got %q", expectedHash, gotHash) + } +} + func TestAddConfigHeadersToAttrs(t *testing.T) { tests := []struct { name string diff --git a/sdk/cliproxy/service_excluded_models_test.go b/sdk/cliproxy/service_excluded_models_test.go new file mode 100644 index 00000000..198a5bed --- /dev/null +++ b/sdk/cliproxy/service_excluded_models_test.go @@ -0,0 +1,65 @@ +package cliproxy + +import ( + "strings" + "testing" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T) { + service := &Service{ + cfg: &config.Config{ + OAuthExcludedModels: map[string][]string{ + "gemini-cli": {"gemini-2.5-pro"}, + }, + }, + } + auth := &coreauth.Auth{ + ID: "auth-gemini-cli", + Provider: "gemini-cli", + Status: coreauth.StatusActive, + Attributes: map[string]string{ + "auth_kind": "oauth", + "excluded_models": "gemini-2.5-flash", + }, + } + + registry := GlobalModelRegistry() + registry.UnregisterClient(auth.ID) + t.Cleanup(func() { + registry.UnregisterClient(auth.ID) + }) + + service.registerModelsForAuth(auth) + + models := registry.GetAvailableModelsByProvider("gemini-cli") + if len(models) == 0 { + t.Fatal("expected gemini-cli models to be registered") + } + + for _, model := range models { + if model == nil { + continue + } + modelID := strings.TrimSpace(model.ID) + if strings.EqualFold(modelID, "gemini-2.5-flash") { + t.Fatalf("expected model %q to be excluded by auth attribute", modelID) + } + } + + seenGlobalExcluded := false + for _, model := range models { + if model == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(model.ID), "gemini-2.5-pro") { + seenGlobalExcluded = true + break + } + } + if !seenGlobalExcluded { + t.Fatal("expected global excluded model to be present when attribute override is set") + } +}