diff --git a/.gitignore b/.gitignore index 2b9c215a..02493d24 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ cliproxy # Configuration config.yaml .env - +.mcp.json # Generated content bin/* logs/* diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 5cca03ba..0153a381 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -796,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) { c.JSON(404, gin.H{"error": "channel not found"}) return } - delete(h.cfg.OAuthModelAlias, channel) - if len(h.cfg.OAuthModelAlias) == 0 { - h.cfg.OAuthModelAlias = nil - } + // Set to nil instead of deleting the key so that the "explicitly disabled" + // marker survives config reload and prevents SanitizeOAuthModelAlias from + // re-injecting default aliases (fixes #222). + h.cfg.OAuthModelAlias[channel] = nil h.persist(c) } diff --git a/internal/config/config.go b/internal/config/config.go index 50b3cbd5..88e1c605 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -767,7 +767,13 @@ func (cfg *Config) SanitizeOAuthModelAlias() { out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias)) for rawChannel, aliases := range cfg.OAuthModelAlias { channel := strings.ToLower(strings.TrimSpace(rawChannel)) - if channel == "" || len(aliases) == 0 { + if channel == "" { + continue + } + // Preserve channels that were explicitly set to empty/nil – they act + // as "disabled" markers so default injection won't re-add them (#222). 
+ if len(aliases) == 0 { + out[channel] = nil continue } seenAlias := make(map[string]struct{}, len(aliases)) diff --git a/internal/config/oauth_model_alias_test.go b/internal/config/oauth_model_alias_test.go index 7497eec8..5cf05502 100644 --- a/internal/config/oauth_model_alias_test.go +++ b/internal/config/oauth_model_alias_test.go @@ -128,6 +128,50 @@ func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) { } } +func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) { + // When user explicitly deletes kiro aliases (key exists with nil value), + // defaults should NOT be re-injected on subsequent sanitize calls (#222). + cfg := &Config{ + OAuthModelAlias: map[string][]OAuthModelAlias{ + "kiro": nil, // explicitly deleted + "codex": {{Name: "gpt-5", Alias: "g5"}}, + }, + } + + cfg.SanitizeOAuthModelAlias() + + kiroAliases := cfg.OAuthModelAlias["kiro"] + if len(kiroAliases) != 0 { + t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases)) + } + // The key itself must still be present to prevent re-injection on next reload + if _, exists := cfg.OAuthModelAlias["kiro"]; !exists { + t.Fatal("expected kiro key to be preserved as nil marker after sanitization") + } + // Other channels should be unaffected + if len(cfg.OAuthModelAlias["codex"]) != 1 { + t.Fatal("expected codex aliases to be preserved") + } +} + +func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) { + // Same as above but with empty slice instead of nil (PUT with empty body). 
+ cfg := &Config{ + OAuthModelAlias: map[string][]OAuthModelAlias{ + "kiro": {}, // explicitly set to empty + }, + } + + cfg.SanitizeOAuthModelAlias() + + if len(cfg.OAuthModelAlias["kiro"]) != 0 { + t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"])) + } + if _, exists := cfg.OAuthModelAlias["kiro"]; !exists { + t.Fatal("expected kiro key to be preserved") + } +} + func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) { // When OAuthModelAlias is nil, kiro defaults should still be injected cfg := &Config{} diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 24765740..da82b8d0 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -1007,7 +1007,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo { exec := &AntigravityExecutor{cfg: cfg} token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth) - if errToken != nil || token == "" { + if errToken != nil { + log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken) + return nil + } + if token == "" { + log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID) return nil } if updatedAuth != nil { @@ -1021,6 +1026,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c modelsURL := baseURL + antigravityModelsPath httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`))) if errReq != nil { + log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq) return nil } httpReq.Header.Set("Content-Type", "application/json") @@ -1033,12 +1039,14 @@ func FetchAntigravityModels(ctx 
context.Context, auth *cliproxyauth.Auth, cfg *c httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { + log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo) return nil } if idx+1 < len(baseURLs) { log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) continue } + log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo) return nil } @@ -1051,6 +1059,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) continue } + log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead) return nil } if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { @@ -1058,11 +1067,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) continue } + log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes)) return nil } result := gjson.GetBytes(bodyBytes, "models") if !result.Exists() { + log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes)) return nil } diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go index 09066186..695680e8 100644 --- a/internal/runtime/executor/github_copilot_executor.go +++ b/internal/runtime/executor/github_copilot_executor.go @@ -110,7 +110,7 @@ func (e 
*GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. defer reporter.trackFailure(ctx, &err) from := opts.SourceFormat - useResponses := useGitHubCopilotResponsesEndpoint(from) + useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model) to := sdktranslator.FromString("openai") if useResponses { to = sdktranslator.FromString("openai-response") @@ -133,6 +133,12 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. return resp, err } + if useResponses { + body = normalizeGitHubCopilotResponsesInput(body) + body = normalizeGitHubCopilotResponsesTools(body) + } else { + body = normalizeGitHubCopilotChatTools(body) + } requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel) body, _ = sjson.SetBytes(body, "stream", false) @@ -209,7 +215,12 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. } var param any - converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + converted := "" + if useResponses && from.String() == "claude" { + converted = translateGitHubCopilotResponsesNonStreamToClaude(data) + } else { + converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + } resp = cliproxyexecutor.Response{Payload: []byte(converted)} reporter.ensurePublished(ctx) return resp, nil @@ -226,7 +237,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox defer reporter.trackFailure(ctx, &err) from := opts.SourceFormat - useResponses := useGitHubCopilotResponsesEndpoint(from) + useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model) to := sdktranslator.FromString("openai") if useResponses { to = sdktranslator.FromString("openai-response") @@ -249,6 +260,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx 
context.Context, auth *cliprox return nil, err } + if useResponses { + body = normalizeGitHubCopilotResponsesInput(body) + body = normalizeGitHubCopilotResponsesTools(body) + } else { + body = normalizeGitHubCopilotChatTools(body) + } requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel) body, _ = sjson.SetBytes(body, "stream", true) @@ -349,7 +366,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox } } - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + var chunks []string + if useResponses && from.String() == "claude" { + chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m) + } else { + chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + } for i := range chunks { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} } @@ -503,8 +525,12 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte return body } -func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool { - return sourceFormat.String() == "openai-response" +func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool { + if sourceFormat.String() == "openai-response" { + return true + } + baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName) + return strings.Contains(baseModel, "codex") } // flattenAssistantContent converts assistant message content from array format @@ -539,6 +565,411 @@ func flattenAssistantContent(body []byte) []byte { return result } +func normalizeGitHubCopilotChatTools(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() { + filtered := "[]" + if tools.IsArray() { + for _, tool := range tools.Array() { + if 
tool.Get("type").String() != "function" { + continue + } + filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw) + } + } + body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered)) + } + + toolChoice := gjson.GetBytes(body, "tool_choice") + if !toolChoice.Exists() { + return body + } + if toolChoice.Type == gjson.String { + switch toolChoice.String() { + case "auto", "none", "required": + return body + } + } + body, _ = sjson.SetBytes(body, "tool_choice", "auto") + return body +} + +func normalizeGitHubCopilotResponsesInput(body []byte) []byte { + input := gjson.GetBytes(body, "input") + if input.Exists() { + if input.Type == gjson.String { + return body + } + inputString := input.Raw + if input.Type != gjson.JSON { + inputString = input.String() + } + body, _ = sjson.SetBytes(body, "input", inputString) + return body + } + + var parts []string + if system := gjson.GetBytes(body, "system"); system.Exists() { + if text := strings.TrimSpace(collectTextFromNode(system)); text != "" { + parts = append(parts, text) + } + } + if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() { + for _, msg := range messages.Array() { + if text := strings.TrimSpace(collectTextFromNode(msg.Get("content"))); text != "" { + parts = append(parts, text) + } + } + } + body, _ = sjson.SetBytes(body, "input", strings.Join(parts, "\n")) + return body +} + +func normalizeGitHubCopilotResponsesTools(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() { + filtered := "[]" + if tools.IsArray() { + for _, tool := range tools.Array() { + toolType := tool.Get("type").String() + // Accept OpenAI format (type="function") and Claude format + // (no type field, but has top-level name + input_schema). 
+ if toolType != "" && toolType != "function" { + continue + } + name := tool.Get("name").String() + if name == "" { + name = tool.Get("function.name").String() + } + if name == "" { + continue + } + normalized := `{"type":"function","name":""}` + normalized, _ = sjson.Set(normalized, "name", name) + if desc := tool.Get("description").String(); desc != "" { + normalized, _ = sjson.Set(normalized, "description", desc) + } else if desc = tool.Get("function.description").String(); desc != "" { + normalized, _ = sjson.Set(normalized, "description", desc) + } + if params := tool.Get("parameters"); params.Exists() { + normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw) + } else if params = tool.Get("function.parameters"); params.Exists() { + normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw) + } else if params = tool.Get("input_schema"); params.Exists() { + normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw) + } + filtered, _ = sjson.SetRaw(filtered, "-1", normalized) + } + } + body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered)) + } + + toolChoice := gjson.GetBytes(body, "tool_choice") + if !toolChoice.Exists() { + return body + } + if toolChoice.Type == gjson.String { + switch toolChoice.String() { + case "auto", "none", "required": + return body + default: + body, _ = sjson.SetBytes(body, "tool_choice", "auto") + return body + } + } + if toolChoice.Type == gjson.JSON { + choiceType := toolChoice.Get("type").String() + if choiceType == "function" { + name := toolChoice.Get("name").String() + if name == "" { + name = toolChoice.Get("function.name").String() + } + if name != "" { + normalized := `{"type":"function","name":""}` + normalized, _ = sjson.Set(normalized, "name", name) + body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized)) + return body + } + } + } + body, _ = sjson.SetBytes(body, "tool_choice", "auto") + return body +} + +func collectTextFromNode(node gjson.Result) string { + if 
!node.Exists() { + return "" + } + if node.Type == gjson.String { + return node.String() + } + if node.IsArray() { + var parts []string + for _, item := range node.Array() { + if item.Type == gjson.String { + if text := item.String(); text != "" { + parts = append(parts, text) + } + continue + } + if text := item.Get("text").String(); text != "" { + parts = append(parts, text) + continue + } + if nested := collectTextFromNode(item.Get("content")); nested != "" { + parts = append(parts, nested) + } + } + return strings.Join(parts, "\n") + } + if node.Type == gjson.JSON { + if text := node.Get("text").String(); text != "" { + return text + } + if nested := collectTextFromNode(node.Get("content")); nested != "" { + return nested + } + return node.Raw + } + return node.String() +} + +type githubCopilotResponsesStreamToolState struct { + Index int + ID string + Name string +} + +type githubCopilotResponsesStreamState struct { + MessageStarted bool + MessageStopSent bool + TextBlockStarted bool + TextBlockIndex int + NextContentIndex int + HasToolUse bool + OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState + ItemIDToTool map[string]*githubCopilotResponsesStreamToolState +} + +func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string { + root := gjson.ParseBytes(data) + out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` + out, _ = sjson.Set(out, "id", root.Get("id").String()) + out, _ = sjson.Set(out, "model", root.Get("model").String()) + + hasToolUse := false + if output := root.Get("output"); output.Exists() && output.IsArray() { + for _, item := range output.Array() { + switch item.Get("type").String() { + case "message": + if content := item.Get("content"); content.Exists() && content.IsArray() { + for _, part := range content.Array() { + if part.Get("type").String() != "output_text" { + continue + } + text := 
part.Get("text").String() + if text == "" { + continue + } + block := `{"type":"text","text":""}` + block, _ = sjson.Set(block, "text", text) + out, _ = sjson.SetRaw(out, "content.-1", block) + } + } + case "function_call": + hasToolUse = true + toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` + toolID := item.Get("call_id").String() + if toolID == "" { + toolID = item.Get("id").String() + } + toolUse, _ = sjson.Set(toolUse, "id", toolID) + toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String()) + if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) { + argObj := gjson.Parse(args) + if argObj.IsObject() { + toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw) + } + } + out, _ = sjson.SetRaw(out, "content.-1", toolUse) + } + } + } + + inputTokens := root.Get("usage.input_tokens").Int() + outputTokens := root.Get("usage.output_tokens").Int() + out, _ = sjson.Set(out, "usage.input_tokens", inputTokens) + out, _ = sjson.Set(out, "usage.output_tokens", outputTokens) + if hasToolUse { + out, _ = sjson.Set(out, "stop_reason", "tool_use") + } else { + out, _ = sjson.Set(out, "stop_reason", "end_turn") + } + return out +} + +func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string { + if *param == nil { + *param = &githubCopilotResponsesStreamState{ + TextBlockIndex: -1, + OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState), + ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState), + } + } + state := (*param).(*githubCopilotResponsesStreamState) + + if !bytes.HasPrefix(line, dataTag) { + return nil + } + payload := bytes.TrimSpace(line[5:]) + if bytes.Equal(payload, []byte("[DONE]")) { + return nil + } + if !gjson.ValidBytes(payload) { + return nil + } + + event := gjson.GetBytes(payload, "type").String() + results := make([]string, 0, 4) + ensureMessageStart := func() { + if state.MessageStarted { + return + } + messageStart := 
`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}` + messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String()) + messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String()) + results = append(results, "event: message_start\ndata: "+messageStart+"\n\n") + state.MessageStarted = true + } + startTextBlockIfNeeded := func() { + if state.TextBlockStarted { + return + } + if state.TextBlockIndex < 0 { + state.TextBlockIndex = state.NextContentIndex + state.NextContentIndex++ + } + contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex) + results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n") + state.TextBlockStarted = true + } + stopTextBlockIfNeeded := func() { + if !state.TextBlockStarted { + return + } + contentBlockStop := `{"type":"content_block_stop","index":0}` + contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex) + results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n") + state.TextBlockStarted = false + state.TextBlockIndex = -1 + } + resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState { + if itemID != "" { + if tool, ok := state.ItemIDToTool[itemID]; ok { + return tool + } + } + if tool, ok := state.OutputIndexToTool[outputIndex]; ok { + if itemID != "" { + state.ItemIDToTool[itemID] = tool + } + return tool + } + return nil + } + + switch event { + case "response.created": + ensureMessageStart() + case "response.output_text.delta": + ensureMessageStart() + startTextBlockIfNeeded() + delta := gjson.GetBytes(payload, "delta").String() + if delta != "" { + 
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}` + contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex) + contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta) + results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n") + } + case "response.output_item.added": + if gjson.GetBytes(payload, "item.type").String() != "function_call" { + break + } + ensureMessageStart() + stopTextBlockIfNeeded() + state.HasToolUse = true + tool := &githubCopilotResponsesStreamToolState{ + Index: state.NextContentIndex, + ID: gjson.GetBytes(payload, "item.call_id").String(), + Name: gjson.GetBytes(payload, "item.name").String(), + } + if tool.ID == "" { + tool.ID = gjson.GetBytes(payload, "item.id").String() + } + state.NextContentIndex++ + outputIndex := int(gjson.GetBytes(payload, "output_index").Int()) + state.OutputIndexToTool[outputIndex] = tool + if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" { + state.ItemIDToTool[itemID] = tool + } + contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}` + contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index) + contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID) + contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name) + results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n") + case "response.output_item.delta": + item := gjson.GetBytes(payload, "item") + if item.Get("type").String() != "function_call" { + break + } + tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int())) + if tool == nil { + break + } + partial := gjson.GetBytes(payload, "delta").String() + if partial == "" { + partial = item.Get("arguments").String() + } + if partial == "" { + break + } + inputDelta := 
`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}` + inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index) + inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial) + results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n") + case "response.output_item.done": + if gjson.GetBytes(payload, "item.type").String() != "function_call" { + break + } + tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int())) + if tool == nil { + break + } + contentBlockStop := `{"type":"content_block_stop","index":0}` + contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index) + results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n") + case "response.completed": + ensureMessageStart() + stopTextBlockIfNeeded() + if !state.MessageStopSent { + stopReason := "end_turn" + if state.HasToolUse { + stopReason = "tool_use" + } + messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` + messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason) + messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", gjson.GetBytes(payload, "response.usage.input_tokens").Int()) + messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", gjson.GetBytes(payload, "response.usage.output_tokens").Int()) + results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n") + results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n") + state.MessageStopSent = true + } + } + + return results +} + // isHTTPSuccess checks if the status code indicates success (2xx). 
func isHTTPSuccess(statusCode int) bool { return statusCode >= 200 && statusCode < 300 diff --git a/internal/runtime/executor/github_copilot_executor_test.go b/internal/runtime/executor/github_copilot_executor_test.go index ef077fd6..2895c8a7 100644 --- a/internal/runtime/executor/github_copilot_executor_test.go +++ b/internal/runtime/executor/github_copilot_executor_test.go @@ -1,8 +1,10 @@ package executor import ( + "strings" "testing" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" "github.com/tidwall/gjson" ) @@ -52,3 +54,189 @@ func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) { }) } } + +func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) { + t.Parallel() + if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") { + t.Fatal("expected openai-response source to use /responses") + } +} + +func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) { + t.Parallel() + if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") { + t.Fatal("expected codex model to use /responses") + } +} + +func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) { + t.Parallel() + if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") { + t.Fatal("expected default openai source with non-codex model to use /chat/completions") + } +} + +func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`) + got := normalizeGitHubCopilotChatTools(body) + tools := gjson.GetBytes(got, "tools").Array() + if len(tools) != 1 { + t.Fatalf("tools len = %d, want 1", len(tools)) + } + if tools[0].Get("type").String() != "function" { + t.Fatalf("tool type = %q, want function", tools[0].Get("type").String()) + } +} + +func 
TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`) + got := normalizeGitHubCopilotChatTools(body) + if gjson.GetBytes(got, "tool_choice").String() != "auto" { + t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw) + } +} + +func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) { + t.Parallel() + body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`) + got := normalizeGitHubCopilotResponsesInput(body) + in := gjson.GetBytes(got, "input") + if in.Type != gjson.String { + t.Fatalf("input type = %v, want string", in.Type) + } + if !strings.Contains(in.String(), "sys text") || !strings.Contains(in.String(), "user text") || !strings.Contains(in.String(), "assistant text") { + t.Fatalf("input = %q, want merged text", in.String()) + } +} + +func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) { + t.Parallel() + body := []byte(`{"input":{"foo":"bar"}}`) + got := normalizeGitHubCopilotResponsesInput(body) + in := gjson.GetBytes(got, "input") + if in.Type != gjson.String { + t.Fatalf("input type = %v, want string", in.Type) + } + if !strings.Contains(in.String(), "foo") { + t.Fatalf("input = %q, want stringified object", in.String()) + } +} + +func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`) + got := normalizeGitHubCopilotResponsesTools(body) + tools := gjson.GetBytes(got, "tools").Array() + if len(tools) != 1 { + t.Fatalf("tools len = %d, want 1", len(tools)) + } + if tools[0].Get("name").String() != "sum" { + 
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String()) + } + if !tools[0].Get("parameters").Exists() { + t.Fatal("expected parameters to be preserved") + } +} + +func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`) + got := normalizeGitHubCopilotResponsesTools(body) + tools := gjson.GetBytes(got, "tools").Array() + if len(tools) != 2 { + t.Fatalf("tools len = %d, want 2", len(tools)) + } + if tools[0].Get("type").String() != "function" { + t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String()) + } + if tools[0].Get("name").String() != "Bash" { + t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String()) + } + if tools[0].Get("description").String() != "Run commands" { + t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String()) + } + if !tools[0].Get("parameters").Exists() { + t.Fatal("expected parameters to be set from input_schema") + } + if tools[0].Get("parameters.properties.command").Exists() != true { + t.Fatal("expected parameters.properties.command to exist") + } + if tools[1].Get("name").String() != "Read" { + t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String()) + } +} + +func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) { + t.Parallel() + body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`) + got := normalizeGitHubCopilotResponsesTools(body) + if gjson.GetBytes(got, "tool_choice.type").String() != "function" { + t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String()) + } + if gjson.GetBytes(got, 
"tool_choice.name").String() != "sum" { + t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String()) + } +} + +func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) { + t.Parallel() + body := []byte(`{"tool_choice":{"type":"function"}}`) + got := normalizeGitHubCopilotResponsesTools(body) + if gjson.GetBytes(got, "tool_choice").String() != "auto" { + t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw) + } +} + +func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) { + t.Parallel() + resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`) + out := translateGitHubCopilotResponsesNonStreamToClaude(resp) + if gjson.Get(out, "type").String() != "message" { + t.Fatalf("type = %q, want message", gjson.Get(out, "type").String()) + } + if gjson.Get(out, "content.0.type").String() != "text" { + t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String()) + } + if gjson.Get(out, "content.0.text").String() != "hello" { + t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String()) + } +} + +func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) { + t.Parallel() + resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`) + out := translateGitHubCopilotResponsesNonStreamToClaude(resp) + if gjson.Get(out, "content.0.type").String() != "tool_use" { + t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String()) + } + if gjson.Get(out, "content.0.name").String() != "sum" { + t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String()) + } + if 
gjson.Get(out, "stop_reason").String() != "tool_use" { + t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String()) + } +} + +func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) { + t.Parallel() + var param any + + created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m) + if len(created) == 0 || !strings.Contains(created[0], "message_start") { + t.Fatalf("created events = %#v, want message_start", created) + } + + delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m) + joinedDelta := strings.Join(delta, "") + if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") { + t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta) + } + + completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m) + joinedCompleted := strings.Join(completed, "") + if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") { + t.Fatalf("completed events = %#v, want message_delta + message_stop", completed) + } +} diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index c360b2de..41a5830c 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -16,6 +16,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -385,6 +386,35 @@ func buildKiroEndpointConfigs(region string) []kiroEndpointConfig { } } +// resolveKiroAPIRegion determines the AWS region for Kiro API calls. +// Region priority: +// 1. auth.Metadata["api_region"] - explicit API region override +// 2. 
ProfileARN region - extracted from arn:aws:service:REGION:account:resource +// 3. kiroDefaultRegion (us-east-1) - fallback +// Note: OIDC "region" is NOT used - it's for token refresh, not API calls +func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string { + if auth == nil || auth.Metadata == nil { + return kiroDefaultRegion + } + // Priority 1: Explicit api_region override + if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { + log.Debugf("kiro: using region %s (source: api_region)", r) + return r + } + // Priority 2: Extract from ProfileARN + if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" { + if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" { + log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion) + return arnRegion + } + } + // Note: OIDC "region" field is NOT used for API endpoint + // Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2) + // Using OIDC region for API calls causes DNS failures + log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion) + return kiroDefaultRegion +} + // kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region. // Prefer using buildKiroEndpointConfigs(region) for dynamic region support. 
var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion) @@ -403,30 +433,8 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { return kiroEndpointConfigs } - // Determine API region with priority: api_region > profile_arn > region > default - region := kiroDefaultRegion - regionSource := "default" - - if auth.Metadata != nil { - // Priority 1: Explicit api_region override - if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { - region = r - regionSource = "api_region" - } else { - // Priority 2: Extract from ProfileARN - if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" { - if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" { - region = arnRegion - regionSource = "profile_arn" - } - } - // Note: OIDC "region" field is NOT used for API endpoint - // Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2) - // Using OIDC region for API calls causes DNS failures - } - } - - log.Debugf("kiro: using region %s (source: %s)", region, regionSource) + // Determine API region using shared resolution logic + region := resolveKiroAPIRegion(auth) // Build endpoint configs for the specified region endpointConfigs := buildKiroEndpointConfigs(region) @@ -520,7 +528,7 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String()) return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) case "kiro": - // Body is already in Kiro format — pass through directly (used by callKiroRawAndBuffer) + // Body is already in Kiro format — pass through directly log.Debugf("kiro: body already in Kiro format, passing through directly") return body, false default: @@ -640,17 +648,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req 
rateLimiter.WaitForToken(tokenKey) log.Debugf("kiro: rate limiter cleared for token %s", tokenKey) - // Check for pure web_search request - // Route to MCP endpoint instead of normal Kiro API - if kiroclaude.HasWebSearchTool(req.Payload) { - log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint") - return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn) - } - - reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) - defer reporter.trackFailure(ctx, &err) - - // Check if token is expired before making request + // Check if token is expired before making request (covers both normal and web_search paths) if e.isTokenExpired(accessToken) { log.Infof("kiro: access token expired, attempting recovery") @@ -679,6 +677,16 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req } } + // Check for pure web_search request + // Route to MCP endpoint instead of normal Kiro API + if kiroclaude.HasWebSearchTool(req.Payload) { + log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint") + return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn) + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + from := opts.SourceFormat to := sdktranslator.FromString("kiro") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) @@ -1025,8 +1033,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
// Build response in Claude format for Kiro translator // stopReason is extracted from upstream response by parseEventStream - kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason) - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) + requestedModel := payloadRequestedModel(opts, req.Model) + kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, requestedModel, usageInfo, stopReason) + out := sdktranslator.TranslateNonStream(ctx, to, from, requestedModel, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) resp = cliproxyexecutor.Response{Payload: []byte(out)} return resp, nil } @@ -1068,17 +1077,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut rateLimiter.WaitForToken(tokenKey) log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey) - // Check for pure web_search request - // Route to MCP endpoint instead of normal Kiro API - if kiroclaude.HasWebSearchTool(req.Payload) { - log.Infof("kiro: detected pure web_search request, routing to MCP endpoint") - return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn) - } - - reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) - defer reporter.trackFailure(ctx, &err) - - // Check if token is expired before making request + // Check if token is expired before making request (covers both normal and web_search paths) if e.isTokenExpired(accessToken) { log.Infof("kiro: access token expired, attempting recovery before stream request") @@ -1107,6 +1106,16 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut } } + // Check for pure web_search request + // Route to MCP endpoint instead of normal Kiro API + if kiroclaude.HasWebSearchTool(req.Payload) { + log.Infof("kiro: detected pure web_search request, routing to MCP endpoint") + return e.handleWebSearchStream(ctx, auth, req, 
opts, accessToken, profileArn) + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + from := opts.SourceFormat to := sdktranslator.FromString("kiro") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) @@ -1423,7 +1432,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox // So we always enable thinking parsing for Kiro responses log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled) - e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled) + e.streamToChannel(ctx, resp.Body, out, from, payloadRequestedModel(opts, req.Model), opts.OriginalRequest, body, reporter, thinkingEnabled) }(httpResp, thinkingEnabled) return out, nil @@ -4114,6 +4123,238 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool { return isExpired } +// ══════════════════════════════════════════════════════════════════════════════ +// Web Search Handler (MCP API) +// ══════════════════════════════════════════════════════════════════════════════ + +// fetchToolDescription caching: +// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time, +// with automatic retry on failure: +// - On failure, fetched stays false so subsequent calls will retry +// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path) +// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(), +// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests. +var ( + toolDescMu sync.Mutex + toolDescFetched atomic.Bool +) + +// fetchToolDescription calls MCP tools/list to get the web_search tool description +// and caches it. Safe to call concurrently — only one goroutine fetches at a time. +// If the fetch fails, subsequent calls will retry. 
On success, no further fetches occur. +// The httpClient parameter allows reusing a shared pooled HTTP client. +func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) { + // Fast path: already fetched successfully, no lock needed + if toolDescFetched.Load() { + return + } + + toolDescMu.Lock() + defer toolDescMu.Unlock() + + // Double-check after acquiring lock + if toolDescFetched.Load() { + return + } + + handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs) + reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`) + log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody)) + + req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody)) + if err != nil { + log.Warnf("kiro/websearch: failed to create tools/list request: %v", err) + return + } + + // Reuse same headers as callMcpAPI + handler.setMcpHeaders(req) + + resp, err := handler.httpClient.Do(req) + if err != nil { + log.Warnf("kiro/websearch: tools/list request failed: %v", err) + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil || resp.StatusCode != http.StatusOK { + log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode) + return + } + log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body)) + + // Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}} + var result struct { + Result *struct { + Tools []struct { + Name string `json:"name"` + Description string `json:"description"` + } `json:"tools"` + } `json:"result"` + } + if err := json.Unmarshal(body, &result); err != nil || result.Result == nil { + log.Warnf("kiro/websearch: failed to parse tools/list response") + return + } + + for _, tool := range result.Result.Tools { + if tool.Name == "web_search" && tool.Description != "" { 
+ kiroclaude.SetWebSearchDescription(tool.Description) + toolDescFetched.Store(true) // success — no more fetches + log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description)) + return + } + } + + // web_search tool not found in response + log.Warnf("kiro/websearch: web_search tool not found in tools/list response") +} + +// webSearchHandler handles web search requests via Kiro MCP API +type webSearchHandler struct { + ctx context.Context + mcpEndpoint string + httpClient *http.Client + authToken string + auth *cliproxyauth.Auth // for applyDynamicFingerprint + authAttrs map[string]string // optional, for custom headers from auth.Attributes +} + +// newWebSearchHandler creates a new webSearchHandler. +// If httpClient is nil, a default client with 30s timeout is used. +// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse. +func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler { + if httpClient == nil { + httpClient = &http.Client{ + Timeout: 30 * time.Second, + } + } + return &webSearchHandler{ + ctx: ctx, + mcpEndpoint: mcpEndpoint, + httpClient: httpClient, + authToken: authToken, + auth: auth, + authAttrs: authAttrs, + } +} + +// setMcpHeaders sets standard MCP API headers on the request, +// aligned with the GAR request pattern. +func (h *webSearchHandler) setMcpHeaders(req *http.Request) { + // 1. Content-Type & Accept (aligned with GAR) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + + // 2. Kiro-specific headers (aligned with GAR) + req.Header.Set("x-amzn-kiro-agent-mode", "vibe") + req.Header.Set("x-amzn-codewhisperer-optout", "true") + + // 3. User-Agent: Reuse applyDynamicFingerprint for consistency + applyDynamicFingerprint(req, h.auth) + + // 4. 
AWS SDK identifiers + req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + + // 5. Authentication + req.Header.Set("Authorization", "Bearer "+h.authToken) + + // 6. Custom headers from auth attributes + util.ApplyCustomHeadersFromAttrs(req, h.authAttrs) +} + +// mcpMaxRetries is the maximum number of retries for MCP API calls. +const mcpMaxRetries = 2 + +// callMcpAPI calls the Kiro MCP API with the given request. +// Includes retry logic with exponential backoff for retryable errors. +func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) { + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal MCP request: %w", err) + } + log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody)) + + var lastErr error + for attempt := 0; attempt <= mcpMaxRetries; attempt++ { + if attempt > 0 { + backoff := time.Duration(1<<(attempt-1)) * time.Second + if backoff > 10*time.Second { + backoff = 10 * time.Second + } + log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr) + select { + case <-h.ctx.Done(): + return nil, h.ctx.Err() + case <-time.After(backoff): + } + } + + req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + h.setMcpHeaders(req) + + resp, err := h.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("MCP API request failed: %w", err) + continue // network error → retry + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + lastErr = fmt.Errorf("failed to read MCP response: %w", err) + continue // read error → retry + } + log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body)) + + // Retryable HTTP status codes (aligned with GAR: 502, 503, 504) + if 
resp.StatusCode >= 502 && resp.StatusCode <= 504 { + lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body)) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body)) + } + + var mcpResponse kiroclaude.McpResponse + if err := json.Unmarshal(body, &mcpResponse); err != nil { + return nil, fmt.Errorf("failed to parse MCP response: %w", err) + } + + if mcpResponse.Error != nil { + code := -1 + if mcpResponse.Error.Code != nil { + code = *mcpResponse.Error.Code + } + msg := "Unknown error" + if mcpResponse.Error.Message != nil { + msg = *mcpResponse.Error.Message + } + return nil, fmt.Errorf("MCP error %d: %s", code, msg) + } + + return &mcpResponse, nil + } + + return nil, lastErr +} + +// webSearchAuthAttrs extracts auth attributes for MCP calls. +// Used by handleWebSearch and handleWebSearchStream to pass custom headers. +func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string { + if auth != nil { + return auth.Attributes + } + return nil +} + const maxWebSearchIterations = 5 // handleWebSearchStream handles web_search requests: @@ -4136,58 +4377,63 @@ func (e *KiroExecutor) handleWebSearchStream( return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn) } - // Build MCP endpoint based on region - region := kiroDefaultRegion - if auth != nil && auth.Metadata != nil { - if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { - region = r - } - } - mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region) + // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback) + region := resolveKiroAPIRegion(auth) + mcpEndpoint := kiroclaude.BuildMcpEndpoint(region) // ── Step 1: tools/list (SYNC) — cache tool description ── { - tokenKey := getTokenKey(auth) - fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) - var authAttrs map[string]string - if auth 
!= nil { - authAttrs = auth.Attributes - } - kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) + authAttrs := webSearchAuthAttrs(auth) + fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) } // Create output channel out := make(chan cliproxyexecutor.StreamChunk) + // Usage reporting: track web search requests like normal streaming requests + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + go func() { + var wsErr error + defer reporter.trackFailure(ctx, &wsErr) defer close(out) - // Send message_start event to client - messageStartEvent := kiroclaude.SseEvent{ - Event: "message_start", - Data: map[string]interface{}{ - "type": "message_start", - "message": map[string]interface{}{ - "id": kiroclaude.GenerateMessageID(), - "type": "message", - "role": "assistant", - "model": req.Model, - "content": []interface{}{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]interface{}{ - "input_tokens": len(req.Payload) / 4, - "output_tokens": 0, - "cache_creation_input_tokens": 0, - "cache_read_input_tokens": 0, - }, - }, - }, + // Estimate input tokens using tokenizer (matching streamToChannel pattern) + var totalUsage usage.Detail + if enc, tokErr := getTokenizer(req.Model); tokErr == nil { + if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 { + totalUsage.InputTokens = inp + } else { + totalUsage.InputTokens = int64(len(req.Payload) / 4) + } + } else { + totalUsage.InputTokens = int64(len(req.Payload) / 4) } + if totalUsage.InputTokens == 0 && len(req.Payload) > 0 { + totalUsage.InputTokens = 1 + } + var accumulatedOutputLen int + defer func() { + if wsErr != nil { + return // let trackFailure handle failure reporting + } + totalUsage.OutputTokens = int64(accumulatedOutputLen / 4) + if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 { + 
totalUsage.OutputTokens = 1 + } + reporter.publish(ctx, totalUsage) + }() + + // Send message_start event to client (aligned with streamToChannel pattern) + // Use payloadRequestedModel to return user's original model alias + msgStart := kiroclaude.BuildClaudeMessageStartEvent( + payloadRequestedModel(opts, req.Model), + totalUsage.InputTokens, + ) select { case <-ctx.Done(): return - case out <- cliproxyexecutor.StreamChunk{Payload: []byte(messageStartEvent.ToSSEString())}: + case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}: } // ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ── @@ -4211,19 +4457,15 @@ func (e *KiroExecutor) handleWebSearchStream( currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) for iteration := 0; iteration < maxWebSearchIterations; iteration++ { - log.Infof("kiro/websearch: search iteration %d/%d — query: %s", - iteration+1, maxWebSearchIterations, currentQuery) + log.Infof("kiro/websearch: search iteration %d/%d", + iteration+1, maxWebSearchIterations) // MCP search _, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery) - tokenKey := getTokenKey(auth) - fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) - var authAttrs map[string]string - if auth != nil { - authAttrs = auth.Attributes - } - handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) - mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest) + + authAttrs := webSearchAuthAttrs(auth) + handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest) var searchResults *kiroclaude.WebSearchResults if mcpErr != nil { @@ -4245,7 +4487,7 @@ func (e *KiroExecutor) handleWebSearchStream( select { case <-ctx.Done(): return - case out <- cliproxyexecutor.StreamChunk{Payload: 
[]byte(event.ToSSEString())}: + case out <- cliproxyexecutor.StreamChunk{Payload: event}: } } contentBlockIndex += 2 @@ -4255,8 +4497,9 @@ func (e *KiroExecutor) handleWebSearchStream( currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults) if err != nil { log.Warnf("kiro/websearch: failed to inject tool results: %v", err) + wsErr = fmt.Errorf("failed to inject tool results: %w", err) e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) - break + return } // Call GAR with modified Claude payload (full translation pipeline) @@ -4265,14 +4508,15 @@ func (e *KiroExecutor) handleWebSearchStream( kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn) if kiroErr != nil { log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr) + wsErr = fmt.Errorf("Kiro API failed at iteration %d: %w", iteration+1, kiroErr) e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) - break + return } // Analyze response analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks) - log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v, query: %s, toolUseId: %s", - iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse, analysis.WebSearchQuery, analysis.WebSearchToolUseId) + log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v", + iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse) if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations { // Model wants another search @@ -4297,12 +4541,14 @@ func (e *KiroExecutor) handleWebSearchStream( if !shouldForward { continue } + accumulatedOutputLen += len(adjusted) select { case <-ctx.Done(): return case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}: } } else { + accumulatedOutputLen += len(chunk) select { case <-ctx.Done(): return @@ -4320,8 
+4566,103 @@ func (e *KiroExecutor) handleWebSearchStream( return out, nil } +// handleWebSearch handles web_search requests for non-streaming Execute path. +// Performs MCP search synchronously, injects results into the request payload, +// then calls the normal non-streaming Kiro API path which returns a proper +// Claude JSON response (not SSE chunks). +func (e *KiroExecutor) handleWebSearch( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (cliproxyexecutor.Response, error) { + // Extract search query from Claude Code's web_search tool_use + query := kiroclaude.ExtractSearchQuery(req.Payload) + if query == "" { + log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute") + // Fall through to normal non-streaming path + return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) + } + + // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback) + region := resolveKiroAPIRegion(auth) + mcpEndpoint := kiroclaude.BuildMcpEndpoint(region) + + // Step 1: Fetch/cache tool description (sync) + { + authAttrs := webSearchAuthAttrs(auth) + fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + } + + // Step 2: Perform MCP search + _, mcpRequest := kiroclaude.CreateMcpRequest(query) + + authAttrs := webSearchAuthAttrs(auth) + handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest) + + var searchResults *kiroclaude.WebSearchResults + if mcpErr != nil { + log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr) + } else { + searchResults = kiroclaude.ParseSearchResults(mcpResponse) + } + + resultCount := 0 
+ if searchResults != nil { + resultCount = len(searchResults.Results) + } + log.Infof("kiro/websearch: non-stream: got %d search results", resultCount) + + // Step 3: Replace restrictive web_search tool description (align with streaming path) + simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload)) + if simplifyErr != nil { + log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr) + simplifiedPayload = bytes.Clone(req.Payload) + } + + // Step 4: Inject search tool_use + tool_result into Claude payload + currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) + modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults) + if err != nil { + log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err) + return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) + } + + // Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry) + // This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream + // to produce a proper Claude JSON response + modifiedReq := req + modifiedReq.Payload = modifiedPayload + + resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn) + if err != nil { + return resp, err + } + + // Step 6: Inject server_tool_use + web_search_tool_result into response + // so Claude Code can display "Did X searches in Ys" + indicators := []kiroclaude.SearchIndicator{ + { + ToolUseID: currentToolUseId, + Query: query, + Results: searchResults, + }, + } + injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators) + if injErr != nil { + log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr) + } else { + resp.Payload = injectedPayload + } + + return resp, nil +} + // callKiroAndBuffer 
calls the Kiro API and buffers all response chunks. // Returns the buffered chunks for analysis before forwarding to client. +// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter. func (e *KiroExecutor) callKiroAndBuffer( ctx context.Context, auth *cliproxyauth.Auth, @@ -4338,10 +4679,7 @@ func (e *KiroExecutor) callKiroAndBuffer( isAgentic, isChatOnly := determineAgenticMode(req.Model) effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - tokenKey := "" - if auth != nil { - tokenKey = auth.ID - } + tokenKey := getTokenKey(auth) kiroStream, err := e.executeStreamWithRetry( ctx, auth, req, opts, accessToken, effectiveProfileArn, @@ -4367,51 +4705,6 @@ func (e *KiroExecutor) callKiroAndBuffer( return chunks, nil } -// callKiroRawAndBuffer calls the Kiro API with a pre-built Kiro payload (no translation). -// Used in the web search loop where the payload is modified directly in Kiro format. -func (e *KiroExecutor) callKiroRawAndBuffer( - ctx context.Context, - auth *cliproxyauth.Auth, - req cliproxyexecutor.Request, - opts cliproxyexecutor.Options, - accessToken, profileArn string, - kiroBody []byte, -) ([][]byte, error) { - kiroModelID := e.mapModelToKiro(req.Model) - isAgentic, isChatOnly := determineAgenticMode(req.Model) - effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - - tokenKey := "" - if auth != nil { - tokenKey = auth.ID - } - log.Debugf("kiro/websearch GAR raw request: %d bytes", len(kiroBody)) - - kiroFormat := sdktranslator.FromString("kiro") - kiroStream, err := e.executeStreamWithRetry( - ctx, auth, req, opts, accessToken, effectiveProfileArn, - nil, kiroBody, kiroFormat, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey, - ) - if err != nil { - return nil, err - } - - // Buffer all chunks - var chunks [][]byte - for chunk := range kiroStream { - if chunk.Err != nil { - return chunks, chunk.Err - } - if len(chunk.Payload) > 0 { - chunks = append(chunks, 
bytes.Clone(chunk.Payload)) - } - } - - log.Debugf("kiro/websearch GAR raw response: %d chunks buffered", len(chunks)) - - return chunks, nil -} - // callKiroDirectStream creates a direct streaming channel to Kiro API without search. func (e *KiroExecutor) callKiroDirectStream( ctx context.Context, @@ -4428,18 +4721,22 @@ func (e *KiroExecutor) callKiroDirectStream( isAgentic, isChatOnly := determineAgenticMode(req.Model) effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - tokenKey := "" - if auth != nil { - tokenKey = auth.ID - } + tokenKey := getTokenKey(auth) - return e.executeStreamWithRetry( + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + var streamErr error + defer reporter.trackFailure(ctx, &streamErr) + + stream, streamErr := e.executeStreamWithRetry( ctx, auth, req, opts, accessToken, effectiveProfileArn, - nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey, + nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey, ) + return stream, streamErr } // sendFallbackText sends a simple text response when the Kiro API fails during the search loop. +// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment +// with how streamToChannel() uses BuildClaude*Event() functions. 
func (e *KiroExecutor) sendFallbackText( ctx context.Context, out chan<- cliproxyexecutor.StreamChunk, @@ -4447,182 +4744,14 @@ func (e *KiroExecutor) sendFallbackText( query string, searchResults *kiroclaude.WebSearchResults, ) { - // Generate a simple text summary from search results - summary := kiroclaude.FormatSearchContextPrompt(query, searchResults) - - events := []kiroclaude.SseEvent{ - { - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": contentBlockIndex, - "content_block": map[string]interface{}{ - "type": "text", - "text": "", - }, - }, - }, - { - Event: "content_block_delta", - Data: map[string]interface{}{ - "type": "content_block_delta", - "index": contentBlockIndex, - "delta": map[string]interface{}{ - "type": "text_delta", - "text": summary, - }, - }, - }, - { - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": contentBlockIndex, - }, - }, - } - + events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults) for _, event := range events { select { case <-ctx.Done(): return - case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}: + case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}: } } - - // Send message_delta with end_turn and message_stop - msgDelta := kiroclaude.SseEvent{ - Event: "message_delta", - Data: map[string]interface{}{ - "type": "message_delta", - "delta": map[string]interface{}{ - "stop_reason": "end_turn", - "stop_sequence": nil, - }, - "usage": map[string]interface{}{ - "output_tokens": len(summary) / 4, - }, - }, - } - select { - case <-ctx.Done(): - return - case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgDelta.ToSSEString())}: - } - - msgStop := kiroclaude.SseEvent{ - Event: "message_stop", - Data: map[string]interface{}{ - "type": "message_stop", - }, - } - select { - case <-ctx.Done(): - return - case out <- 
cliproxyexecutor.StreamChunk{Payload: []byte(msgStop.ToSSEString())}: - } - -} - -// handleWebSearch handles web_search requests for non-streaming Execute path. -// Performs MCP search synchronously, injects results into the request payload, -// then calls the normal non-streaming Kiro API path which returns a proper -// Claude JSON response (not SSE chunks). -func (e *KiroExecutor) handleWebSearch( - ctx context.Context, - auth *cliproxyauth.Auth, - req cliproxyexecutor.Request, - opts cliproxyexecutor.Options, - accessToken, profileArn string, -) (cliproxyexecutor.Response, error) { - // Extract search query from Claude Code's web_search tool_use - query := kiroclaude.ExtractSearchQuery(req.Payload) - if query == "" { - log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute") - // Fall through to normal non-streaming path - return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) - } - - // Build MCP endpoint based on region - region := kiroDefaultRegion - if auth != nil && auth.Metadata != nil { - if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { - region = r - } - } - mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region) - - // Step 1: Fetch/cache tool description (sync) - { - tokenKey := getTokenKey(auth) - fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) - var authAttrs map[string]string - if auth != nil { - authAttrs = auth.Attributes - } - kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) - } - - // Step 2: Perform MCP search - _, mcpRequest := kiroclaude.CreateMcpRequest(query) - tokenKey := getTokenKey(auth) - fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) - var authAttrs map[string]string - if auth != nil { - authAttrs = auth.Attributes - } - handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, 
e.cfg, auth, 30*time.Second), fp, authAttrs) - mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest) - - var searchResults *kiroclaude.WebSearchResults - if mcpErr != nil { - log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr) - } else { - searchResults = kiroclaude.ParseSearchResults(mcpResponse) - } - - resultCount := 0 - if searchResults != nil { - resultCount = len(searchResults.Results) - } - log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query) - - // Step 3: Inject search tool_use + tool_result into Claude payload - currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) - modifiedPayload, err := kiroclaude.InjectToolResultsClaude(bytes.Clone(req.Payload), currentToolUseId, query, searchResults) - if err != nil { - log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err) - return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) - } - - // Step 4: Call Kiro API via the normal non-streaming path (executeWithRetry) - // This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream - // to produce a proper Claude JSON response - modifiedReq := req - modifiedReq.Payload = modifiedPayload - - resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn) - if err != nil { - return resp, err - } - - // Step 5: Inject server_tool_use + web_search_tool_result into response - // so Claude Code can display "Did X searches in Ys" - indicators := []kiroclaude.SearchIndicator{ - { - ToolUseID: currentToolUseId, - Query: query, - Results: searchResults, - }, - } - injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators) - if injErr != nil { - log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr) - } else { - resp.Payload = injectedPayload - } - - return resp, nil } // 
executeNonStreamFallback runs the standard non-streaming Execute path for a request.

diff --git a/internal/translator/kiro/claude/kiro_claude_stream.go b/internal/translator/kiro/claude/kiro_claude_stream.go
index 84fd6621..c86b6e02 100644
--- a/internal/translator/kiro/claude/kiro_claude_stream.go
+++ b/internal/translator/kiro/claude/kiro_claude_stream.go
@@ -183,4 +183,124 @@ func PendingTagSuffix(buffer, tag string) int {
 	}
 	return 0
-}
\ No newline at end of file
+}

// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
// (server_tool_use + web_search_tool_result) without a text summary or message
// termination. These events trigger Claude Code's search indicator UI.
// The caller is responsible for sending message_start before and
// message_delta/message_stop after.
func GenerateSearchIndicatorEvents(
	query string,
	toolUseID string,
	searchResults *WebSearchResults,
	startIndex int,
) [][]byte {
	events := make([][]byte, 0, 5)

	// emit serializes one SSE event ("event: <name>\ndata: <json>\n\n") and
	// appends it to the result. Marshal errors are ignored, matching the
	// best-effort behavior of the surrounding SSE builders.
	emit := func(name string, payload map[string]interface{}) {
		data, _ := json.Marshal(payload)
		events = append(events, []byte("event: "+name+"\ndata: "+string(data)+"\n\n"))
	}

	// 1. content_block_start (server_tool_use)
	emit("content_block_start", map[string]interface{}{
		"type":  "content_block_start",
		"index": startIndex,
		"content_block": map[string]interface{}{
			"id":    toolUseID,
			"type":  "server_tool_use",
			"name":  "web_search",
			"input": map[string]interface{}{},
		},
	})

	// 2. content_block_delta (input_json_delta) carrying the search query
	queryJSON, _ := json.Marshal(map[string]string{"query": query})
	emit("content_block_delta", map[string]interface{}{
		"type":  "content_block_delta",
		"index": startIndex,
		"delta": map[string]interface{}{
			"type":         "input_json_delta",
			"partial_json": string(queryJSON),
		},
	})

	// 3. content_block_stop (server_tool_use)
	emit("content_block_stop", map[string]interface{}{
		"type":  "content_block_stop",
		"index": startIndex,
	})

	// 4. content_block_start (web_search_tool_result) with one entry per result.
	// The slice is initialized non-nil so an empty result set encodes as [] not null.
	resultEntries := make([]map[string]interface{}, 0)
	if searchResults != nil {
		for _, r := range searchResults.Results {
			var snippet string
			if r.Snippet != nil {
				snippet = *r.Snippet
			}
			resultEntries = append(resultEntries, map[string]interface{}{
				"type":              "web_search_result",
				"title":             r.Title,
				"url":               r.URL,
				"encrypted_content": snippet,
				"page_age":          nil,
			})
		}
	}
	emit("content_block_start", map[string]interface{}{
		"type":  "content_block_start",
		"index": startIndex + 1,
		"content_block": map[string]interface{}{
			"type":        "web_search_tool_result",
			"tool_use_id": toolUseID,
			"content":     resultEntries,
		},
	})

	// 5. content_block_stop (web_search_tool_result)
	emit("content_block_stop", map[string]interface{}{
		"type":  "content_block_stop",
		"index": startIndex + 1,
	})

	return events
}

// BuildFallbackTextEvents generates SSE events for a fallback text response
// when the Kiro API fails during the search loop. Uses BuildClaude*Event()
// functions to align with streamToChannel patterns.
// Returns raw SSE byte slices ready to be sent to the client channel.
+func BuildFallbackTextEvents(contentBlockIndex int, query string, results *WebSearchResults) [][]byte { + summary := FormatSearchContextPrompt(query, results) + outputTokens := len(summary) / 4 + if len(summary) > 0 && outputTokens == 0 { + outputTokens = 1 + } + + var events [][]byte + + // content_block_start (text) + events = append(events, BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")) + + // content_block_delta (text_delta) + events = append(events, BuildClaudeStreamEvent(summary, contentBlockIndex)) + + // content_block_stop + events = append(events, BuildClaudeContentBlockStopEvent(contentBlockIndex)) + + // message_delta with end_turn + events = append(events, BuildClaudeMessageDeltaEvent("end_turn", usage.Detail{ + OutputTokens: int64(outputTokens), + })) + + // message_stop + events = append(events, BuildClaudeMessageStopOnlyEvent()) + + return events +} diff --git a/internal/translator/kiro/claude/kiro_claude_stream_parser.go b/internal/translator/kiro/claude/kiro_claude_stream_parser.go new file mode 100644 index 00000000..275196ac --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude_stream_parser.go @@ -0,0 +1,350 @@ +package claude + +import ( + "encoding/json" + "strings" + + log "github.com/sirupsen/logrus" +) + +// sseEvent represents a Server-Sent Event +type sseEvent struct { + Event string + Data interface{} +} + +// ToSSEString converts the event to SSE wire format +func (e *sseEvent) ToSSEString() string { + dataBytes, _ := json.Marshal(e.Data) + return "event: " + e.Event + "\ndata: " + string(dataBytes) + "\n\n" +} + +// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset. +// It also suppresses duplicate message_start events (returns shouldForward=false). +// This is used to combine search indicator events (indices 0,1) with Kiro model response events. +// +// The data parameter is a single SSE "data:" line payload (JSON). 
// Returns: adjusted data, shouldForward (false = skip this event).
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
	if len(data) == 0 {
		return data, true
	}

	// Parse the JSON payload; non-JSON lines pass through untouched.
	var event map[string]interface{}
	if err := json.Unmarshal(data, &event); err != nil {
		return data, true
	}

	eventType, _ := event["type"].(string)

	// Suppress duplicate message_start events.
	if eventType == "message_start" {
		return data, false
	}

	// Adjust index for content_block events.
	switch eventType {
	case "content_block_start", "content_block_delta", "content_block_stop":
		if idx, ok := event["index"].(float64); ok {
			event["index"] = int(idx) + offset
			adjusted, err := json.Marshal(event)
			if err != nil {
				return data, true
			}
			return adjusted, true
		}
	}

	// Pass through all other events unchanged (message_delta, message_stop, ping, etc.)
	return data, true
}

// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs)
// and adjusts content block indices. Suppresses duplicate message_start events,
// including the blank separator line that follows a suppressed event.
// Returns the adjusted chunk and whether it should be forwarded.
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
	chunkStr := string(chunk)

	// Fast path: if no "data:" prefix, pass through.
	if !strings.Contains(chunkStr, "data: ") {
		return chunk, true
	}

	lines := strings.Split(chunkStr, "\n")
	out := make([]string, 0, len(lines))
	hasContent := false

	for i := 0; i < len(lines); i++ {
		line := lines[i]

		switch {
		case strings.HasPrefix(line, "data: "):
			payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))

			if payload == "[DONE]" {
				out = append(out, line)
				hasContent = true
				continue
			}

			adjusted, forward := AdjustStreamIndices([]byte(payload), offset)
			if !forward {
				// Skip this data line AND the blank separator that follows,
				// so suppression does not leave stray empty SSE dispatches.
				if i+1 < len(lines) && strings.TrimSpace(lines[i+1]) == "" {
					i++
				}
				continue
			}

			out = append(out, "data: "+string(adjusted))
			hasContent = true

		case strings.HasPrefix(line, "event: "):
			// Look ahead: if the next data line is a message_start, suppress
			// the whole event (event line, data line, blank separator).
			if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
				next := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))

				var ev map[string]interface{}
				if err := json.Unmarshal([]byte(next), &ev); err == nil {
					if t, _ := ev["type"].(string); t == "message_start" {
						i++ // skip the data: line
						if i+1 < len(lines) && strings.TrimSpace(lines[i+1]) == "" {
							i++ // skip the blank separator
						}
						continue
					}
				}
			}
			out = append(out, line)
			hasContent = true

		default:
			out = append(out, line)
			if strings.TrimSpace(line) != "" {
				hasContent = true
			}
		}
	}

	if !hasContent {
		return nil, false
	}

	// Join instead of appending "\n" per split element: a trailing "\n" in the
	// input yields a final empty element, and line-by-line "\n" appends were
	// adding one spurious trailing newline to every chunk on every pass.
	return []byte(strings.Join(out, "\n")), true
}

// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response.
+type BufferedStreamResult struct { + // StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use") + StopReason string + // WebSearchQuery is the extracted query if the model requested another web_search + WebSearchQuery string + // WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults) + WebSearchToolUseId string + // HasWebSearchToolUse indicates whether the model requested web_search + HasWebSearchToolUse bool + // WebSearchToolUseIndex is the content_block index of the web_search tool_use + WebSearchToolUseIndex int +} + +// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use. +// This is used in the search loop to determine if the model wants another search round. +func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult { + result := BufferedStreamResult{WebSearchToolUseIndex: -1} + + // Track tool use state across chunks + var currentToolName string + var currentToolIndex int = -1 + var toolInputBuilder strings.Builder + + for _, chunk := range chunks { + chunkStr := string(chunk) + lines := strings.Split(chunkStr, "\n") + for _, line := range lines { + if !strings.HasPrefix(line, "data: ") { + continue + } + dataPayload := strings.TrimPrefix(line, "data: ") + dataPayload = strings.TrimSpace(dataPayload) + if dataPayload == "[DONE]" || dataPayload == "" { + continue + } + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataPayload), &event); err != nil { + continue + } + + eventType, _ := event["type"].(string) + + switch eventType { + case "message_delta": + // Extract stop_reason from message_delta + if delta, ok := event["delta"].(map[string]interface{}); ok { + if sr, ok := delta["stop_reason"].(string); ok && sr != "" { + result.StopReason = sr + } + } + + case "content_block_start": + // Detect tool_use content blocks + if cb, ok := event["content_block"].(map[string]interface{}); ok { + if cbType, ok := 
cb["type"].(string); ok && cbType == "tool_use" { + if name, ok := cb["name"].(string); ok { + currentToolName = strings.ToLower(name) + if idx, ok := event["index"].(float64); ok { + currentToolIndex = int(idx) + } + // Capture tool use ID for toolResults handshake + if id, ok := cb["id"].(string); ok { + result.WebSearchToolUseId = id + } + toolInputBuilder.Reset() + } + } + } + + case "content_block_delta": + // Accumulate tool input JSON + if currentToolName != "" { + if delta, ok := event["delta"].(map[string]interface{}); ok { + if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" { + if partial, ok := delta["partial_json"].(string); ok { + toolInputBuilder.WriteString(partial) + } + } + } + } + + case "content_block_stop": + // Finalize tool use detection + if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" { + result.HasWebSearchToolUse = true + result.WebSearchToolUseIndex = currentToolIndex + // Extract query from accumulated input JSON + inputJSON := toolInputBuilder.String() + var input map[string]string + if err := json.Unmarshal([]byte(inputJSON), &input); err == nil { + if q, ok := input["query"]; ok { + result.WebSearchQuery = q + } + } + log.Debugf("kiro/websearch: detected web_search tool_use") + } + currentToolName = "" + currentToolIndex = -1 + toolInputBuilder.Reset() + } + } + } + + return result +} + +// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use +// content blocks. This prevents the client from seeing "Tool use" prompts for web_search +// when the proxy is handling the search loop internally. +// Also suppresses message_start and message_delta/message_stop events since those +// are managed by the outer handleWebSearchStream. 
+func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte { + var filtered [][]byte + + for _, chunk := range chunks { + chunkStr := string(chunk) + lines := strings.Split(chunkStr, "\n") + + var resultBuilder strings.Builder + hasContent := false + + for i := 0; i < len(lines); i++ { + line := lines[i] + + if strings.HasPrefix(line, "data: ") { + dataPayload := strings.TrimPrefix(line, "data: ") + dataPayload = strings.TrimSpace(dataPayload) + + if dataPayload == "[DONE]" { + // Skip [DONE] — the outer loop manages stream termination + continue + } + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataPayload), &event); err != nil { + resultBuilder.WriteString(line + "\n") + hasContent = true + continue + } + + eventType, _ := event["type"].(string) + + // Skip message_start (outer loop sends its own) + if eventType == "message_start" { + continue + } + + // Skip message_delta and message_stop (outer loop manages these) + if eventType == "message_delta" || eventType == "message_stop" { + continue + } + + // Check if this event belongs to the web_search tool_use block + if wsToolIndex >= 0 { + if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex { + // Skip events for the web_search tool_use block + continue + } + } + + // Apply index offset for remaining events + if indexOffset > 0 { + switch eventType { + case "content_block_start", "content_block_delta", "content_block_stop": + if idx, ok := event["index"].(float64); ok { + event["index"] = int(idx) + indexOffset + adjusted, err := json.Marshal(event) + if err == nil { + resultBuilder.WriteString("data: " + string(adjusted) + "\n") + hasContent = true + continue + } + } + } + } + + resultBuilder.WriteString(line + "\n") + hasContent = true + } else if strings.HasPrefix(line, "event: ") { + // Check if the next data line will be suppressed + if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") { + nextData := 
strings.TrimPrefix(lines[i+1], "data: ") + nextData = strings.TrimSpace(nextData) + + var nextEvent map[string]interface{} + if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil { + nextType, _ := nextEvent["type"].(string) + if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" { + i++ // skip the data line + continue + } + if wsToolIndex >= 0 { + if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex { + i++ // skip the data line + continue + } + } + } + } + resultBuilder.WriteString(line + "\n") + hasContent = true + } else { + resultBuilder.WriteString(line + "\n") + if strings.TrimSpace(line) != "" { + hasContent = true + } + } + } + + if hasContent { + filtered = append(filtered, []byte(resultBuilder.String())) + } + } + + return filtered +} diff --git a/internal/translator/kiro/claude/kiro_websearch.go b/internal/translator/kiro/claude/kiro_websearch.go index 25be730e..b9da3829 100644 --- a/internal/translator/kiro/claude/kiro_websearch.go +++ b/internal/translator/kiro/claude/kiro_websearch.go @@ -1,11 +1,14 @@ // Package claude provides web search functionality for Kiro translator. -// This file implements detection and MCP request/response types for web search. +// This file implements detection, MCP request/response types, and pure data +// transformation utilities for web search. SSE event generation, stream analysis, +// and HTTP I/O logic reside in the executor package (kiro_executor.go). package claude import ( "encoding/json" "fmt" "strings" + "sync/atomic" "time" "github.com/google/uuid" @@ -14,6 +17,26 @@ import ( "github.com/tidwall/sjson" ) +// cachedToolDescription stores the dynamically-fetched web_search tool description. +// Written by the executor via SetWebSearchDescription, read by the translator +// when building the remote_web_search tool for Kiro API requests. 
+var cachedToolDescription atomic.Value // stores string + +// GetWebSearchDescription returns the cached web_search tool description, +// or empty string if not yet fetched. Lock-free via atomic.Value. +func GetWebSearchDescription() string { + if v := cachedToolDescription.Load(); v != nil { + return v.(string) + } + return "" +} + +// SetWebSearchDescription stores the dynamically-fetched web_search tool description. +// Called by the executor after fetching from MCP tools/list. +func SetWebSearchDescription(desc string) { + cachedToolDescription.Store(desc) +} + // McpRequest represents a JSON-RPC 2.0 request to Kiro MCP API type McpRequest struct { ID string `json:"id"` @@ -191,36 +214,11 @@ func CreateMcpRequest(query string) (string, *McpRequest) { return toolUseID, request } -// GenerateMessageID generates a Claude-style message ID -func GenerateMessageID() string { - return "msg_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:24] -} - // GenerateToolUseID generates a Kiro-style tool use ID (base62-like UUID) func GenerateToolUseID() string { return strings.ReplaceAll(uuid.New().String(), "-", "")[:22] } -// ContainsWebSearchTool checks if the request contains a web_search tool (among any tools). -// Unlike HasWebSearchTool, this detects web_search even in mixed-tool arrays. -func ContainsWebSearchTool(body []byte) bool { - tools := gjson.GetBytes(body, "tools") - if !tools.IsArray() { - return false - } - - for _, tool := range tools.Array() { - name := strings.ToLower(tool.Get("name").String()) - toolType := strings.ToLower(tool.Get("type").String()) - - if isWebSearchTool(name, toolType) { - return true - } - } - - return false -} - // ReplaceWebSearchToolDescription replaces the web_search tool description with // a minimal version that allows re-search without the restrictive "do not search // non-coding topics" instruction from the original Kiro tools/list response. 
@@ -275,48 +273,6 @@ func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) { return result, nil } -// StripWebSearchTool removes web_search tool entries from the request's tools array. -// If the tools array becomes empty after removal, it is removed entirely. -func StripWebSearchTool(body []byte) ([]byte, error) { - tools := gjson.GetBytes(body, "tools") - if !tools.IsArray() { - return body, nil - } - - var filtered []json.RawMessage - for _, tool := range tools.Array() { - name := strings.ToLower(tool.Get("name").String()) - toolType := strings.ToLower(tool.Get("type").String()) - - if !isWebSearchTool(name, toolType) { - filtered = append(filtered, json.RawMessage(tool.Raw)) - } - } - - var result []byte - var err error - - if len(filtered) == 0 { - // Remove tools array entirely - result, err = sjson.DeleteBytes(body, "tools") - if err != nil { - return body, fmt.Errorf("failed to delete tools: %w", err) - } - } else { - // Replace with filtered array - filteredJSON, marshalErr := json.Marshal(filtered) - if marshalErr != nil { - return body, fmt.Errorf("failed to marshal filtered tools: %w", marshalErr) - } - result, err = sjson.SetRawBytes(body, "tools", filteredJSON) - if err != nil { - return body, fmt.Errorf("failed to set filtered tools: %w", err) - } - } - - return result, nil -} - // FormatSearchContextPrompt formats search results as a structured text block // for injection into the system prompt. func FormatSearchContextPrompt(query string, results *WebSearchResults) string { @@ -365,7 +321,7 @@ func FormatToolResultText(results *WebSearchResults) string { // // This produces the exact same GAR request format as the Kiro IDE (HAR captures). // IMPORTANT: The web_search tool must remain in the "tools" array for this to work. -// Use ReplaceWebSearchToolDescription (not StripWebSearchTool) to keep the tool available. +// Use ReplaceWebSearchToolDescription to keep the tool available with a minimal description. 
func InjectToolResultsClaude(claudePayload []byte, toolUseId, query string, results *WebSearchResults) ([]byte, error) { var payload map[string]interface{} if err := json.Unmarshal(claudePayload, &payload); err != nil { @@ -432,8 +388,8 @@ Do NOT apologize for bad results without first attempting a re-search. return claudePayload, fmt.Errorf("failed to marshal updated payload: %w", err) } - log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, query=%s, messages=%d)", - toolUseId, query, len(messages)) + log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, messages=%d)", + toolUseId, len(messages)) return result, nil } @@ -512,658 +468,28 @@ type SearchIndicator struct { Results *WebSearchResults } -// ══════════════════════════════════════════════════════════════════════════════ -// SSE Event Generation -// ══════════════════════════════════════════════════════════════════════════════ - -// SseEvent represents a Server-Sent Event -type SseEvent struct { - Event string - Data interface{} +// BuildMcpEndpoint constructs the MCP endpoint URL for the given AWS region. +// Centralizes the URL pattern used by both handleWebSearch and handleWebSearchStream. +func BuildMcpEndpoint(region string) string { + return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region) } -// ToSSEString converts the event to SSE wire format -func (e *SseEvent) ToSSEString() string { - dataBytes, _ := json.Marshal(e.Data) - return fmt.Sprintf("event: %s\ndata: %s\n\n", e.Event, string(dataBytes)) -} - -// GenerateWebSearchEvents generates the 11-event SSE sequence for web search. 
-// Events: message_start, content_block_start(server_tool_use), content_block_delta(input_json), -// content_block_stop, content_block_start(web_search_tool_result), content_block_stop, -// content_block_start(text), content_block_delta(text), content_block_stop, message_delta, message_stop -func GenerateWebSearchEvents( - model string, - query string, - toolUseID string, - searchResults *WebSearchResults, - inputTokens int, -) []SseEvent { - events := make([]SseEvent, 0, 15) - messageID := GenerateMessageID() - - // 1. message_start - events = append(events, SseEvent{ - Event: "message_start", - Data: map[string]interface{}{ - "type": "message_start", - "message": map[string]interface{}{ - "id": messageID, - "type": "message", - "role": "assistant", - "model": model, - "content": []interface{}{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]interface{}{ - "input_tokens": inputTokens, - "output_tokens": 0, - "cache_creation_input_tokens": 0, - "cache_read_input_tokens": 0, - }, - }, - }, - }) - - // 2. content_block_start (server_tool_use) - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": 0, - "content_block": map[string]interface{}{ - "id": toolUseID, - "type": "server_tool_use", - "name": "web_search", - "input": map[string]interface{}{}, - }, - }, - }) - - // 3. content_block_delta (input_json_delta) - inputJSON, _ := json.Marshal(map[string]string{"query": query}) - events = append(events, SseEvent{ - Event: "content_block_delta", - Data: map[string]interface{}{ - "type": "content_block_delta", - "index": 0, - "delta": map[string]interface{}{ - "type": "input_json_delta", - "partial_json": string(inputJSON), - }, - }, - }) - - // 4. content_block_stop (server_tool_use) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": 0, - }, - }) - - // 5. 
content_block_start (web_search_tool_result) - searchContent := make([]map[string]interface{}, 0) - if searchResults != nil { - for _, r := range searchResults.Results { - snippet := "" - if r.Snippet != nil { - snippet = *r.Snippet - } - searchContent = append(searchContent, map[string]interface{}{ - "type": "web_search_result", - "title": r.Title, - "url": r.URL, - "encrypted_content": snippet, - "page_age": nil, - }) - } - } - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": 1, - "content_block": map[string]interface{}{ - "type": "web_search_tool_result", - "tool_use_id": toolUseID, - "content": searchContent, - }, - }, - }) - - // 6. content_block_stop (web_search_tool_result) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": 1, - }, - }) - - // 7. content_block_start (text) - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": 2, - "content_block": map[string]interface{}{ - "type": "text", - "text": "", - }, - }, - }) - - // 8. content_block_delta (text_delta) - generate search summary - summary := generateSearchSummary(query, searchResults) - - // Split text into chunks for streaming effect - chunkSize := 100 - runes := []rune(summary) - for i := 0; i < len(runes); i += chunkSize { - end := i + chunkSize - if end > len(runes) { - end = len(runes) - } - chunk := string(runes[i:end]) - events = append(events, SseEvent{ - Event: "content_block_delta", - Data: map[string]interface{}{ - "type": "content_block_delta", - "index": 2, - "delta": map[string]interface{}{ - "type": "text_delta", - "text": chunk, - }, - }, - }) - } - - // 9. 
content_block_stop (text) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": 2, - }, - }) - - // 10. message_delta - outputTokens := (len(summary) + 3) / 4 // Simple estimation - events = append(events, SseEvent{ - Event: "message_delta", - Data: map[string]interface{}{ - "type": "message_delta", - "delta": map[string]interface{}{ - "stop_reason": "end_turn", - "stop_sequence": nil, - }, - "usage": map[string]interface{}{ - "output_tokens": outputTokens, - }, - }, - }) - - // 11. message_stop - events = append(events, SseEvent{ - Event: "message_stop", - Data: map[string]interface{}{ - "type": "message_stop", - }, - }) - - return events -} - -// generateSearchSummary generates a text summary of search results -func generateSearchSummary(query string, results *WebSearchResults) string { - var sb strings.Builder - sb.WriteString(fmt.Sprintf("Here are the search results for \"%s\":\n\n", query)) - - if results != nil && len(results.Results) > 0 { - for i, r := range results.Results { - sb.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, r.Title)) - if r.Snippet != nil { - snippet := *r.Snippet - if len(snippet) > 200 { - snippet = snippet[:200] + "..." - } - sb.WriteString(fmt.Sprintf(" %s\n", snippet)) - } - sb.WriteString(fmt.Sprintf(" Source: %s\n\n", r.URL)) - } - } else { - sb.WriteString("No results found.\n") - } - - sb.WriteString("\nPlease note that these are web search results and may not be fully accurate or up-to-date.") - - return sb.String() -} - -// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events -// (server_tool_use + web_search_tool_result) without text summary or message termination. -// These events trigger Claude Code's search indicator UI. -// The caller is responsible for sending message_start before and message_delta/stop after. 
-func GenerateSearchIndicatorEvents( - query string, - toolUseID string, - searchResults *WebSearchResults, - startIndex int, -) []SseEvent { - events := make([]SseEvent, 0, 4) - - // 1. content_block_start (server_tool_use) - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": startIndex, - "content_block": map[string]interface{}{ - "id": toolUseID, - "type": "server_tool_use", - "name": "web_search", - "input": map[string]interface{}{}, - }, - }, - }) - - // 2. content_block_delta (input_json_delta) - inputJSON, _ := json.Marshal(map[string]string{"query": query}) - events = append(events, SseEvent{ - Event: "content_block_delta", - Data: map[string]interface{}{ - "type": "content_block_delta", - "index": startIndex, - "delta": map[string]interface{}{ - "type": "input_json_delta", - "partial_json": string(inputJSON), - }, - }, - }) - - // 3. content_block_stop (server_tool_use) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": startIndex, - }, - }) - - // 4. content_block_start (web_search_tool_result) - searchContent := make([]map[string]interface{}, 0) - if searchResults != nil { - for _, r := range searchResults.Results { - snippet := "" - if r.Snippet != nil { - snippet = *r.Snippet - } - searchContent = append(searchContent, map[string]interface{}{ - "type": "web_search_result", - "title": r.Title, - "url": r.URL, - "encrypted_content": snippet, - "page_age": nil, - }) - } - } - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": startIndex + 1, - "content_block": map[string]interface{}{ - "type": "web_search_tool_result", - "tool_use_id": toolUseID, - "content": searchContent, - }, - }, - }) - - // 5. 
content_block_stop (web_search_tool_result) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": startIndex + 1, - }, - }) - - return events -} - -// ══════════════════════════════════════════════════════════════════════════════ -// Stream Analysis & Manipulation -// ══════════════════════════════════════════════════════════════════════════════ - -// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset. -// It also suppresses duplicate message_start events (returns shouldForward=false). -// This is used to combine search indicator events (indices 0,1) with Kiro model response events. -// -// The data parameter is a single SSE "data:" line payload (JSON). -// Returns: adjusted data, shouldForward (false = skip this event). -func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) { - if len(data) == 0 { - return data, true - } - - // Quick check: parse the JSON - var event map[string]interface{} - if err := json.Unmarshal(data, &event); err != nil { - // Not valid JSON, pass through - return data, true - } - - eventType, _ := event["type"].(string) - - // Suppress duplicate message_start events - if eventType == "message_start" { - return data, false - } - - // Adjust index for content_block events - switch eventType { - case "content_block_start", "content_block_delta", "content_block_stop": - if idx, ok := event["index"].(float64); ok { - event["index"] = int(idx) + offset - adjusted, err := json.Marshal(event) - if err != nil { - return data, true - } - return adjusted, true - } - } - - // Pass through all other events unchanged (message_delta, message_stop, ping, etc.) - return data, true -} - -// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs) -// and adjusts content block indices. Suppresses duplicate message_start events. 
// Returns the adjusted chunk and whether it should be forwarded.
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
	chunkStr := string(chunk)

	// Fast path: if no "data:" prefix, pass through
	if !strings.Contains(chunkStr, "data: ") {
		return chunk, true
	}

	var result strings.Builder
	hasContent := false

	// Walk the chunk line by line. "data:" lines are rewritten via
	// AdjustStreamIndices; "event:" lines peek one line ahead so that a
	// suppressed message_start drops both its event: and data: lines.
	lines := strings.Split(chunkStr, "\n")
	for i := 0; i < len(lines); i++ {
		line := lines[i]

		if strings.HasPrefix(line, "data: ") {
			dataPayload := strings.TrimPrefix(line, "data: ")
			dataPayload = strings.TrimSpace(dataPayload)

			// [DONE] is a sentinel, not JSON — forward it verbatim.
			if dataPayload == "[DONE]" {
				result.WriteString(line + "\n")
				hasContent = true
				continue
			}

			adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset)
			if !shouldForward {
				// Skip this event and its preceding "event:" line
				// Also skip the trailing empty line
				// NOTE(review): the preceding "event:" line is actually
				// dropped by the lookahead in the "event: " branch below;
				// this branch only skips the data line itself. The two are
				// consistent because AdjustStreamIndices suppresses exactly
				// message_start, which the lookahead also matches.
				continue
			}

			result.WriteString("data: " + string(adjusted) + "\n")
			hasContent = true
		} else if strings.HasPrefix(line, "event: ") {
			// Check if the next data line will be suppressed
			if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
				dataPayload := strings.TrimPrefix(lines[i+1], "data: ")
				dataPayload = strings.TrimSpace(dataPayload)

				var event map[string]interface{}
				if err := json.Unmarshal([]byte(dataPayload), &event); err == nil {
					if eventType, ok := event["type"].(string); ok && eventType == "message_start" {
						// Skip both the event: and data: lines
						i++ // skip the data: line too
						continue
					}
				}
			}
			result.WriteString(line + "\n")
			hasContent = true
		} else {
			// Blank separators and unknown lines are forwarded as-is; only
			// non-blank lines count toward hasContent.
			result.WriteString(line + "\n")
			if strings.TrimSpace(line) != "" {
				hasContent = true
			}
		}
	}

	// If everything was suppressed, tell the caller to drop the whole chunk.
	if !hasContent {
		return nil, false
	}

	return []byte(result.String()), true
}

// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response.
type BufferedStreamResult struct {
	// StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use")
	StopReason string
	// WebSearchQuery is the extracted query if the model requested another web_search
	WebSearchQuery string
	// WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults)
	WebSearchToolUseId string
	// HasWebSearchToolUse indicates whether the model requested web_search
	HasWebSearchToolUse bool
	// WebSearchToolUseIndex is the content_block index of the web_search tool_use
	WebSearchToolUseIndex int
}

// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
// This is used in the search loop to determine if the model wants another search round.
//
// It runs a small state machine across all chunks: a content_block_start of
// type tool_use opens a tool, input_json_delta fragments are accumulated,
// and content_block_stop finalizes the detection.
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
	result := BufferedStreamResult{WebSearchToolUseIndex: -1}

	// Track tool use state across chunks
	var currentToolName string
	var currentToolIndex int = -1
	var toolInputBuilder strings.Builder

	for _, chunk := range chunks {
		chunkStr := string(chunk)
		lines := strings.Split(chunkStr, "\n")
		for _, line := range lines {
			// Only SSE "data:" payloads carry events; everything else is noise here.
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			dataPayload := strings.TrimPrefix(line, "data: ")
			dataPayload = strings.TrimSpace(dataPayload)
			if dataPayload == "[DONE]" || dataPayload == "" {
				continue
			}

			var event map[string]interface{}
			if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
				// Malformed payloads are skipped silently.
				continue
			}

			eventType, _ := event["type"].(string)

			switch eventType {
			case "message_delta":
				// Extract stop_reason from message_delta
				if delta, ok := event["delta"].(map[string]interface{}); ok {
					if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
						result.StopReason = sr
					}
				}

			case "content_block_start":
				// Detect tool_use content blocks
				if cb, ok := event["content_block"].(map[string]interface{}); ok {
					if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
						if name, ok := cb["name"].(string); ok {
							currentToolName = strings.ToLower(name)
							if idx, ok := event["index"].(float64); ok {
								currentToolIndex = int(idx)
							}
							// Capture tool use ID for toolResults handshake
							// NOTE(review): the ID is recorded for ANY tool_use
							// block, not just web_search variants — presumably
							// only web_search tool_use appears in this stream;
							// confirm against the caller.
							if id, ok := cb["id"].(string); ok {
								result.WebSearchToolUseId = id
							}
							toolInputBuilder.Reset()
						}
					}
				}

			case "content_block_delta":
				// Accumulate tool input JSON
				if currentToolName != "" {
					if delta, ok := event["delta"].(map[string]interface{}); ok {
						if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
							if partial, ok := delta["partial_json"].(string); ok {
								toolInputBuilder.WriteString(partial)
							}
						}
					}
				}

			case "content_block_stop":
				// Finalize tool use detection
				if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
					result.HasWebSearchToolUse = true
					result.WebSearchToolUseIndex = currentToolIndex
					// Extract query from accumulated input JSON
					inputJSON := toolInputBuilder.String()
					var input map[string]string
					if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
						if q, ok := input["query"]; ok {
							result.WebSearchQuery = q
						}
					}
					log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery)
				}
				// Reset tool state on every block stop, tool or not.
				currentToolName = ""
				currentToolIndex = -1
				toolInputBuilder.Reset()
			}
		}
	}

	return result
}

// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use
// content blocks. This prevents the client from seeing "Tool use" prompts for web_search
// when the proxy is handling the search loop internally.
// Also suppresses message_start and message_delta/message_stop events since those
// are managed by the outer handleWebSearchStream.
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
	var filtered [][]byte

	for _, chunk := range chunks {
		chunkStr := string(chunk)
		lines := strings.Split(chunkStr, "\n")

		var resultBuilder strings.Builder
		hasContent := false

		for i := 0; i < len(lines); i++ {
			line := lines[i]

			if strings.HasPrefix(line, "data: ") {
				dataPayload := strings.TrimPrefix(line, "data: ")
				dataPayload = strings.TrimSpace(dataPayload)

				if dataPayload == "[DONE]" {
					// Skip [DONE] — the outer loop manages stream termination
					continue
				}

				var event map[string]interface{}
				if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
					// Non-JSON data lines are forwarded verbatim.
					resultBuilder.WriteString(line + "\n")
					hasContent = true
					continue
				}

				eventType, _ := event["type"].(string)

				// Skip message_start (outer loop sends its own)
				if eventType == "message_start" {
					continue
				}

				// Skip message_delta and message_stop (outer loop manages these)
				if eventType == "message_delta" || eventType == "message_stop" {
					continue
				}

				// Check if this event belongs to the web_search tool_use block
				if wsToolIndex >= 0 {
					if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
						// Skip events for the web_search tool_use block
						continue
					}
				}

				// Apply index offset for remaining events
				// (offset 0 would be an identity rewrite, so it is skipped).
				if indexOffset > 0 {
					switch eventType {
					case "content_block_start", "content_block_delta", "content_block_stop":
						if idx, ok := event["index"].(float64); ok {
							event["index"] = int(idx) + indexOffset
							adjusted, err := json.Marshal(event)
							if err == nil {
								resultBuilder.WriteString("data: " + string(adjusted) + "\n")
								hasContent = true
								continue
							}
						}
					}
				}

				resultBuilder.WriteString(line + "\n")
				hasContent = true
			} else if strings.HasPrefix(line, "event: ") {
				// Check if the next data line will be suppressed; if so, drop
				// this "event:" line together with it so no orphan remains.
				if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
					nextData := strings.TrimPrefix(lines[i+1], "data: ")
					nextData = strings.TrimSpace(nextData)

					var nextEvent map[string]interface{}
					if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil {
						nextType, _ := nextEvent["type"].(string)
						if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
							i++ // skip the data line
							continue
						}
						if wsToolIndex >= 0 {
							if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex {
								i++ // skip the data line
								continue
							}
						}
					}
				}
				resultBuilder.WriteString(line + "\n")
				hasContent = true
			} else {
				// Blank separators and anything unrecognized pass through.
				resultBuilder.WriteString(line + "\n")
				if strings.TrimSpace(line) != "" {
					hasContent = true
				}
			}
		}

		// Drop chunks that became empty after filtering.
		if hasContent {
			filtered = append(filtered, []byte(resultBuilder.String()))
		}
	}

	return filtered
}

// ParseSearchResults extracts WebSearchResults from MCP response
// Returns nil when the response is missing, has no content, the first
// content entry is not of type "text", or its text fails to parse as JSON.
func ParseSearchResults(response *McpResponse) *WebSearchResults {
	if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
		return nil
	}

	// Only the first content entry is consulted.
	content := response.Result.Content[0]
	if content.ContentType != "text" {
		return nil
	}

	var results WebSearchResults
	if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
		log.Warnf("kiro/websearch: failed to parse search results: %v", err)
		return nil
	}

	return &results
}

// Package claude provides web search handler for Kiro translator.
// This file implements the MCP API call and response handling.
-package claude - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "sync" - "sync/atomic" - "time" - - "github.com/google/uuid" - kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - log "github.com/sirupsen/logrus" -) - -// Cached web_search tool description fetched from MCP tools/list. -// Uses atomic.Pointer[sync.Once] for lock-free reads with retry-on-failure: -// - sync.Once prevents race conditions and deduplicates concurrent calls -// - On failure, a fresh sync.Once is swapped in to allow retry on next call -// - On success, sync.Once stays "done" forever — zero overhead for subsequent calls -var ( - cachedToolDescription atomic.Value // stores string - toolDescOnce atomic.Pointer[sync.Once] - fallbackFpOnce sync.Once - fallbackFp *kiroauth.Fingerprint -) - -func init() { - toolDescOnce.Store(&sync.Once{}) -} - -// FetchToolDescription calls MCP tools/list to get the web_search tool description -// and caches it. Safe to call concurrently — only one goroutine fetches at a time. -// If the fetch fails, subsequent calls will retry. On success, no further fetches occur. -// The httpClient parameter allows reusing a shared pooled HTTP client. 
func FetchToolDescription(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) {
	toolDescOnce.Load().Do(func() {
		handler := NewWebSearchHandler(mcpEndpoint, authToken, httpClient, fp, authAttrs)
		reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
		log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))

		req, err := http.NewRequest("POST", mcpEndpoint, bytes.NewReader(reqBody))
		if err != nil {
			log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
			toolDescOnce.Store(&sync.Once{}) // allow retry
			return
		}

		// Reuse same headers as CallMcpAPI
		handler.setMcpHeaders(req)

		resp, err := handler.HTTPClient.Do(req)
		if err != nil {
			log.Warnf("kiro/websearch: tools/list request failed: %v", err)
			toolDescOnce.Store(&sync.Once{}) // allow retry
			return
		}
		defer resp.Body.Close()

		body, err := io.ReadAll(resp.Body)
		// NOTE(review): a body-read error and a non-200 status are conflated
		// here and only the status code is logged; a read failure on a 200
		// response would log "returned status 200" — consider splitting.
		if err != nil || resp.StatusCode != http.StatusOK {
			log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
			toolDescOnce.Store(&sync.Once{}) // allow retry
			return
		}
		log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))

		// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
		var result struct {
			Result *struct {
				Tools []struct {
					Name        string `json:"name"`
					Description string `json:"description"`
				} `json:"tools"`
			} `json:"result"`
		}
		if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
			log.Warnf("kiro/websearch: failed to parse tools/list response")
			toolDescOnce.Store(&sync.Once{}) // allow retry
			return
		}

		for _, tool := range result.Result.Tools {
			if tool.Name == "web_search" && tool.Description != "" {
				cachedToolDescription.Store(tool.Description)
				log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
				return // success — sync.Once stays "done", no more fetches
			}
		}

		// web_search tool not found in response
		toolDescOnce.Store(&sync.Once{}) // allow retry
	})
}

// GetWebSearchDescription returns the cached web_search tool description,
// or empty string if not yet fetched. Lock-free via atomic.Value.
func GetWebSearchDescription() string {
	if v := cachedToolDescription.Load(); v != nil {
		return v.(string)
	}
	return ""
}

// WebSearchHandler handles web search requests via Kiro MCP API
type WebSearchHandler struct {
	// McpEndpoint is the MCP API URL POSTed to by CallMcpAPI.
	McpEndpoint string
	// HTTPClient performs all MCP requests; may be a shared pooled client.
	HTTPClient *http.Client
	// AuthToken is sent as a Bearer token on every request.
	AuthToken   string
	Fingerprint *kiroauth.Fingerprint // optional, for dynamic headers
	AuthAttrs   map[string]string     // optional, for custom headers from auth.Attributes
}

// NewWebSearchHandler creates a new WebSearchHandler.
// If httpClient is nil, a default client with 30s timeout is used.
// If fingerprint is nil, a random one-off fingerprint is generated.
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
func NewWebSearchHandler(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) *WebSearchHandler {
	if httpClient == nil {
		httpClient = &http.Client{
			Timeout: 30 * time.Second,
		}
	}
	if fp == nil {
		// Use a shared fallback fingerprint for callers without token context
		fallbackFpOnce.Do(func() {
			mgr := kiroauth.NewFingerprintManager()
			fallbackFp = mgr.GetFingerprint("mcp-fallback")
		})
		fp = fallbackFp
	}
	return &WebSearchHandler{
		McpEndpoint: mcpEndpoint,
		HTTPClient:  httpClient,
		AuthToken:   authToken,
		Fingerprint: fp,
		AuthAttrs:   authAttrs,
	}
}

// setMcpHeaders sets standard MCP API headers on the request,
// aligned with the GAR request pattern in kiro_executor.go.
func (h *WebSearchHandler) setMcpHeaders(req *http.Request) {
	fp := h.Fingerprint

	// 1. Content-Type & Accept (aligned with GAR)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "*/*")

	// 2. Kiro-specific headers (aligned with GAR)
	req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
	req.Header.Set("x-amzn-codewhisperer-optout", "true")

	// 3. Dynamic fingerprint headers
	req.Header.Set("User-Agent", fp.BuildUserAgent())
	req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())

	// 4. AWS SDK identifiers (casing aligned with GAR)
	req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
	req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())

	// 5. Authentication
	req.Header.Set("Authorization", "Bearer "+h.AuthToken)

	// 6. Custom headers from auth attributes
	util.ApplyCustomHeadersFromAttrs(req, h.AuthAttrs)
}

// mcpMaxRetries is the maximum number of retries for MCP API calls.
const mcpMaxRetries = 2

// CallMcpAPI calls the Kiro MCP API with the given request.
// Includes retry logic with exponential backoff for retryable errors,
// aligned with the GAR request retry pattern.
-func (h *WebSearchHandler) CallMcpAPI(request *McpRequest) (*McpResponse, error) { - requestBody, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal MCP request: %w", err) - } - log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.McpEndpoint, len(requestBody)) - - var lastErr error - for attempt := 0; attempt <= mcpMaxRetries; attempt++ { - if attempt > 0 { - backoff := time.Duration(1< 10*time.Second { - backoff = 10 * time.Second - } - log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr) - time.Sleep(backoff) - } - - req, err := http.NewRequest("POST", h.McpEndpoint, bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP request: %w", err) - } - - h.setMcpHeaders(req) - - resp, err := h.HTTPClient.Do(req) - if err != nil { - lastErr = fmt.Errorf("MCP API request failed: %w", err) - continue // network error → retry - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - lastErr = fmt.Errorf("failed to read MCP response: %w", err) - continue // read error → retry - } - log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body)) - - // Retryable HTTP status codes (aligned with GAR: 502, 503, 504) - if resp.StatusCode >= 502 && resp.StatusCode <= 504 { - lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body)) - continue - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body)) - } - - var mcpResponse McpResponse - if err := json.Unmarshal(body, &mcpResponse); err != nil { - return nil, fmt.Errorf("failed to parse MCP response: %w", err) - } - - if mcpResponse.Error != nil { - code := -1 - if mcpResponse.Error.Code != nil { - code = *mcpResponse.Error.Code - } - msg := "Unknown error" - if mcpResponse.Error.Message != nil { - msg = 
*mcpResponse.Error.Message - } - return nil, fmt.Errorf("MCP error %d: %s", code, msg) - } - - return &mcpResponse, nil - } - - return nil, lastErr -} - -// ParseSearchResults extracts WebSearchResults from MCP response -func ParseSearchResults(response *McpResponse) *WebSearchResults { - if response == nil || response.Result == nil || len(response.Result.Content) == 0 { - return nil - } - - content := response.Result.Content[0] - if content.ContentType != "text" { - return nil - } - - var results WebSearchResults - if err := json.Unmarshal([]byte(content.Text), &results); err != nil { - log.Warnf("kiro/websearch: failed to parse search results: %v", err) - return nil - } - - return &results -}