From 75818b1e25f1ac8484c59ed92052cbf82bc035d4 Mon Sep 17 00:00:00 2001 From: xiluo Date: Fri, 13 Feb 2026 17:56:57 +0800 Subject: [PATCH 1/8] fix(antigravity): add warn-level logging to silent failure paths in FetchAntigravityModels Add log.Warnf calls to all 7 silent return nil paths so operators can diagnose why specific antigravity accounts fail to fetch models and get unregistered without any log trail. Covers: token errors, request creation failures, context cancellation, network errors (after exhausting fallback URLs), body read errors, unexpected HTTP status codes, and missing models field in response. --- internal/runtime/executor/antigravity_executor.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 24765740..ee20c519 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -1008,6 +1008,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c exec := &AntigravityExecutor{cfg: cfg} token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth) if errToken != nil || token == "" { + log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken) return nil } if updatedAuth != nil { @@ -1021,6 +1022,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c modelsURL := baseURL + antigravityModelsPath httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`))) if errReq != nil { + log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq) return nil } httpReq.Header.Set("Content-Type", "application/json") @@ -1033,12 +1035,14 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { if errors.Is(errDo, 
context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { + log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo) return nil } if idx+1 < len(baseURLs) { log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) continue } + log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo) return nil } @@ -1051,6 +1055,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) continue } + log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead) return nil } if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { @@ -1058,11 +1063,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) continue } + log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes)) return nil } result := gjson.GetBytes(bodyBytes, "models") if !result.Exists() { + log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes)) return nil } From 587371eb14a4b876855dadbde4277a03633f0ee7 Mon Sep 17 00:00:00 2001 From: Skyuno Date: Thu, 12 Feb 2026 11:10:04 +0800 Subject: [PATCH 2/8] refactor: align web search with executor layer patterns Consolidate web search handler, SSE event generation, stream analysis, and MCP HTTP I/O into the executor layer. Merge the separate kiro_websearch_handler.go back into kiro_executor.go to align with the single-file-per-executor convention. 
Translator retains only pure data types, detection, and payload transformation. Key changes: - Move SSE construction (search indicators, fallback text, message_start) from translator to executor, consistent with streamToChannel pattern - Move MCP handler (callMcpAPI, setMcpHeaders, fetchToolDescription) from translator to executor alongside other HTTP I/O - Reuse applyDynamicFingerprint for MCP UA headers (eliminate duplication) - Centralize MCP endpoint URL via BuildMcpEndpoint in translator - Add atomic Set/GetWebSearchDescription for cross-layer tool desc cache - Thread context.Context through MCP HTTP calls for cancellation support - Thread usage reporter through all web search API call paths - Add token expiry pre-check before MCP/GAR calls - Clean up dead code (GenerateMessageID, webSearchAuthContext fp logic, ContainsWebSearchTool, StripWebSearchTool) --- internal/runtime/executor/kiro_executor.go | 766 +++++++++-------- .../kiro/claude/kiro_claude_stream.go | 127 ++- .../kiro/claude/kiro_claude_stream_parser.go | 350 ++++++++ .../translator/kiro/claude/kiro_websearch.go | 768 ++---------------- .../kiro/claude/kiro_websearch_handler.go | 270 ------ 5 files changed, 970 insertions(+), 1311 deletions(-) create mode 100644 internal/translator/kiro/claude/kiro_claude_stream_parser.go delete mode 100644 internal/translator/kiro/claude/kiro_websearch_handler.go diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index c360b2de..7bd00205 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -16,6 +16,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -385,6 +386,35 @@ func buildKiroEndpointConfigs(region string) []kiroEndpointConfig { } } +// resolveKiroAPIRegion determines the AWS region for Kiro API calls. +// Region priority: +// 1. auth.Metadata["api_region"] - explicit API region override +// 2. 
ProfileARN region - extracted from arn:aws:service:REGION:account:resource +// 3. kiroDefaultRegion (us-east-1) - fallback +// Note: OIDC "region" is NOT used - it's for token refresh, not API calls +func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string { + if auth == nil || auth.Metadata == nil { + return kiroDefaultRegion + } + // Priority 1: Explicit api_region override + if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { + log.Debugf("kiro: using region %s (source: api_region)", r) + return r + } + // Priority 2: Extract from ProfileARN + if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" { + if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" { + log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion) + return arnRegion + } + } + // Note: OIDC "region" field is NOT used for API endpoint + // Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2) + // Using OIDC region for API calls causes DNS failures + log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion) + return kiroDefaultRegion +} + // kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region. // Prefer using buildKiroEndpointConfigs(region) for dynamic region support. 
var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion) @@ -403,30 +433,8 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { return kiroEndpointConfigs } - // Determine API region with priority: api_region > profile_arn > region > default - region := kiroDefaultRegion - regionSource := "default" - - if auth.Metadata != nil { - // Priority 1: Explicit api_region override - if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { - region = r - regionSource = "api_region" - } else { - // Priority 2: Extract from ProfileARN - if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" { - if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" { - region = arnRegion - regionSource = "profile_arn" - } - } - // Note: OIDC "region" field is NOT used for API endpoint - // Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2) - // Using OIDC region for API calls causes DNS failures - } - } - - log.Debugf("kiro: using region %s (source: %s)", region, regionSource) + // Determine API region using shared resolution logic + region := resolveKiroAPIRegion(auth) // Build endpoint configs for the specified region endpointConfigs := buildKiroEndpointConfigs(region) @@ -520,7 +528,7 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String()) return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil) case "kiro": - // Body is already in Kiro format — pass through directly (used by callKiroRawAndBuffer) + // Body is already in Kiro format — pass through directly log.Debugf("kiro: body already in Kiro format, passing through directly") return body, false default: @@ -640,17 +648,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req 
rateLimiter.WaitForToken(tokenKey) log.Debugf("kiro: rate limiter cleared for token %s", tokenKey) - // Check for pure web_search request - // Route to MCP endpoint instead of normal Kiro API - if kiroclaude.HasWebSearchTool(req.Payload) { - log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint") - return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn) - } - - reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) - defer reporter.trackFailure(ctx, &err) - - // Check if token is expired before making request + // Check if token is expired before making request (covers both normal and web_search paths) if e.isTokenExpired(accessToken) { log.Infof("kiro: access token expired, attempting recovery") @@ -679,6 +677,16 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req } } + // Check for pure web_search request + // Route to MCP endpoint instead of normal Kiro API + if kiroclaude.HasWebSearchTool(req.Payload) { + log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint") + return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn) + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + from := opts.SourceFormat to := sdktranslator.FromString("kiro") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) @@ -1068,17 +1076,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut rateLimiter.WaitForToken(tokenKey) log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey) - // Check for pure web_search request - // Route to MCP endpoint instead of normal Kiro API - if kiroclaude.HasWebSearchTool(req.Payload) { - log.Infof("kiro: detected pure web_search request, routing to MCP endpoint") - return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn) - } - - reporter := 
newUsageReporter(ctx, e.Identifier(), req.Model, auth) - defer reporter.trackFailure(ctx, &err) - - // Check if token is expired before making request + // Check if token is expired before making request (covers both normal and web_search paths) if e.isTokenExpired(accessToken) { log.Infof("kiro: access token expired, attempting recovery before stream request") @@ -1107,6 +1105,16 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut } } + // Check for pure web_search request + // Route to MCP endpoint instead of normal Kiro API + if kiroclaude.HasWebSearchTool(req.Payload) { + log.Infof("kiro: detected pure web_search request, routing to MCP endpoint") + return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn) + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + from := opts.SourceFormat to := sdktranslator.FromString("kiro") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) @@ -4114,6 +4122,238 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool { return isExpired } +// ══════════════════════════════════════════════════════════════════════════════ +// Web Search Handler (MCP API) +// ══════════════════════════════════════════════════════════════════════════════ + +// fetchToolDescription caching: +// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time, +// with automatic retry on failure: +// - On failure, fetched stays false so subsequent calls will retry +// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path) +// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(), +// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests. 
+var ( + toolDescMu sync.Mutex + toolDescFetched atomic.Bool +) + +// fetchToolDescription calls MCP tools/list to get the web_search tool description +// and caches it. Safe to call concurrently — only one goroutine fetches at a time. +// If the fetch fails, subsequent calls will retry. On success, no further fetches occur. +// The httpClient parameter allows reusing a shared pooled HTTP client. +func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) { + // Fast path: already fetched successfully, no lock needed + if toolDescFetched.Load() { + return + } + + toolDescMu.Lock() + defer toolDescMu.Unlock() + + // Double-check after acquiring lock + if toolDescFetched.Load() { + return + } + + handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs) + reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`) + log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody)) + + req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody)) + if err != nil { + log.Warnf("kiro/websearch: failed to create tools/list request: %v", err) + return + } + + // Reuse same headers as callMcpAPI + handler.setMcpHeaders(req) + + resp, err := handler.httpClient.Do(req) + if err != nil { + log.Warnf("kiro/websearch: tools/list request failed: %v", err) + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil || resp.StatusCode != http.StatusOK { + log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode) + return + } + log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body)) + + // Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}} + var result struct { + Result *struct { + Tools []struct { + Name string `json:"name"` + Description string `json:"description"` + } `json:"tools"` + 
} `json:"result"` + } + if err := json.Unmarshal(body, &result); err != nil || result.Result == nil { + log.Warnf("kiro/websearch: failed to parse tools/list response") + return + } + + for _, tool := range result.Result.Tools { + if tool.Name == "web_search" && tool.Description != "" { + kiroclaude.SetWebSearchDescription(tool.Description) + toolDescFetched.Store(true) // success — no more fetches + log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description)) + return + } + } + + // web_search tool not found in response + log.Warnf("kiro/websearch: web_search tool not found in tools/list response") +} + +// webSearchHandler handles web search requests via Kiro MCP API +type webSearchHandler struct { + ctx context.Context + mcpEndpoint string + httpClient *http.Client + authToken string + auth *cliproxyauth.Auth // for applyDynamicFingerprint + authAttrs map[string]string // optional, for custom headers from auth.Attributes +} + +// newWebSearchHandler creates a new webSearchHandler. +// If httpClient is nil, a default client with 30s timeout is used. +// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse. +func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler { + if httpClient == nil { + httpClient = &http.Client{ + Timeout: 30 * time.Second, + } + } + return &webSearchHandler{ + ctx: ctx, + mcpEndpoint: mcpEndpoint, + httpClient: httpClient, + authToken: authToken, + auth: auth, + authAttrs: authAttrs, + } +} + +// setMcpHeaders sets standard MCP API headers on the request, +// aligned with the GAR request pattern. +func (h *webSearchHandler) setMcpHeaders(req *http.Request) { + // 1. Content-Type & Accept (aligned with GAR) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "*/*") + + // 2. 
Kiro-specific headers (aligned with GAR) + req.Header.Set("x-amzn-kiro-agent-mode", "vibe") + req.Header.Set("x-amzn-codewhisperer-optout", "true") + + // 3. User-Agent: Reuse applyDynamicFingerprint for consistency + applyDynamicFingerprint(req, h.auth) + + // 4. AWS SDK identifiers + req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + + // 5. Authentication + req.Header.Set("Authorization", "Bearer "+h.authToken) + + // 6. Custom headers from auth attributes + util.ApplyCustomHeadersFromAttrs(req, h.authAttrs) +} + +// mcpMaxRetries is the maximum number of retries for MCP API calls. +const mcpMaxRetries = 2 + +// callMcpAPI calls the Kiro MCP API with the given request. +// Includes retry logic with exponential backoff for retryable errors. +func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) { + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal MCP request: %w", err) + } + log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody)) + + var lastErr error + for attempt := 0; attempt <= mcpMaxRetries; attempt++ { + if attempt > 0 { + backoff := time.Duration(1< 10*time.Second { + backoff = 10 * time.Second + } + log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr) + select { + case <-h.ctx.Done(): + return nil, h.ctx.Err() + case <-time.After(backoff): + } + } + + req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + h.setMcpHeaders(req) + + resp, err := h.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("MCP API request failed: %w", err) + continue // network error → retry + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + 
lastErr = fmt.Errorf("failed to read MCP response: %w", err) + continue // read error → retry + } + log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body)) + + // Retryable HTTP status codes (aligned with GAR: 502, 503, 504) + if resp.StatusCode >= 502 && resp.StatusCode <= 504 { + lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body)) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body)) + } + + var mcpResponse kiroclaude.McpResponse + if err := json.Unmarshal(body, &mcpResponse); err != nil { + return nil, fmt.Errorf("failed to parse MCP response: %w", err) + } + + if mcpResponse.Error != nil { + code := -1 + if mcpResponse.Error.Code != nil { + code = *mcpResponse.Error.Code + } + msg := "Unknown error" + if mcpResponse.Error.Message != nil { + msg = *mcpResponse.Error.Message + } + return nil, fmt.Errorf("MCP error %d: %s", code, msg) + } + + return &mcpResponse, nil + } + + return nil, lastErr +} + +// webSearchAuthAttrs extracts auth attributes for MCP calls. +// Used by handleWebSearch and handleWebSearchStream to pass custom headers. 
+func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string { + if auth != nil { + return auth.Attributes + } + return nil +} + const maxWebSearchIterations = 5 // handleWebSearchStream handles web_search requests: @@ -4136,58 +4376,63 @@ func (e *KiroExecutor) handleWebSearchStream( return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn) } - // Build MCP endpoint based on region - region := kiroDefaultRegion - if auth != nil && auth.Metadata != nil { - if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { - region = r - } - } - mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region) + // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback) + region := resolveKiroAPIRegion(auth) + mcpEndpoint := kiroclaude.BuildMcpEndpoint(region) // ── Step 1: tools/list (SYNC) — cache tool description ── { - tokenKey := getTokenKey(auth) - fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) - var authAttrs map[string]string - if auth != nil { - authAttrs = auth.Attributes - } - kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) + authAttrs := webSearchAuthAttrs(auth) + fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) } // Create output channel out := make(chan cliproxyexecutor.StreamChunk) + // Usage reporting: track web search requests like normal streaming requests + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + go func() { + var wsErr error + defer reporter.trackFailure(ctx, &wsErr) defer close(out) - // Send message_start event to client - messageStartEvent := kiroclaude.SseEvent{ - Event: "message_start", - Data: map[string]interface{}{ - "type": "message_start", - "message": map[string]interface{}{ - "id": kiroclaude.GenerateMessageID(), - "type": "message", - "role": "assistant", - 
"model": req.Model, - "content": []interface{}{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]interface{}{ - "input_tokens": len(req.Payload) / 4, - "output_tokens": 0, - "cache_creation_input_tokens": 0, - "cache_read_input_tokens": 0, - }, - }, - }, + // Estimate input tokens using tokenizer (matching streamToChannel pattern) + var totalUsage usage.Detail + if enc, tokErr := getTokenizer(req.Model); tokErr == nil { + if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 { + totalUsage.InputTokens = inp + } else { + totalUsage.InputTokens = int64(len(req.Payload) / 4) + } + } else { + totalUsage.InputTokens = int64(len(req.Payload) / 4) } + if totalUsage.InputTokens == 0 && len(req.Payload) > 0 { + totalUsage.InputTokens = 1 + } + var accumulatedOutputLen int + defer func() { + if wsErr != nil { + return // let trackFailure handle failure reporting + } + totalUsage.OutputTokens = int64(accumulatedOutputLen / 4) + if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + reporter.publish(ctx, totalUsage) + }() + + // Send message_start event to client (aligned with streamToChannel pattern) + // Use payloadRequestedModel to return user's original model alias + msgStart := kiroclaude.BuildClaudeMessageStartEvent( + payloadRequestedModel(opts, req.Model), + totalUsage.InputTokens, + ) select { case <-ctx.Done(): return - case out <- cliproxyexecutor.StreamChunk{Payload: []byte(messageStartEvent.ToSSEString())}: + case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}: } // ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ── @@ -4216,14 +4461,10 @@ func (e *KiroExecutor) handleWebSearchStream( // MCP search _, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery) - tokenKey := getTokenKey(auth) - fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) - var authAttrs map[string]string - if auth != nil { - authAttrs = 
auth.Attributes - } - handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) - mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest) + + authAttrs := webSearchAuthAttrs(auth) + handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest) var searchResults *kiroclaude.WebSearchResults if mcpErr != nil { @@ -4255,8 +4496,9 @@ func (e *KiroExecutor) handleWebSearchStream( currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults) if err != nil { log.Warnf("kiro/websearch: failed to inject tool results: %v", err) + wsErr = fmt.Errorf("failed to inject tool results: %w", err) e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) - break + return } // Call GAR with modified Claude payload (full translation pipeline) @@ -4265,8 +4507,9 @@ func (e *KiroExecutor) handleWebSearchStream( kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn) if kiroErr != nil { log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr) + wsErr = fmt.Errorf("Kiro API failed at iteration %d: %w", iteration+1, kiroErr) e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults) - break + return } // Analyze response @@ -4297,12 +4540,14 @@ func (e *KiroExecutor) handleWebSearchStream( if !shouldForward { continue } + accumulatedOutputLen += len(adjusted) select { case <-ctx.Done(): return case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}: } } else { + accumulatedOutputLen += len(chunk) select { case <-ctx.Done(): return @@ -4320,8 +4565,103 @@ func (e *KiroExecutor) handleWebSearchStream( return out, nil } +// handleWebSearch handles web_search requests for non-streaming Execute 
path. +// Performs MCP search synchronously, injects results into the request payload, +// then calls the normal non-streaming Kiro API path which returns a proper +// Claude JSON response (not SSE chunks). +func (e *KiroExecutor) handleWebSearch( + ctx context.Context, + auth *cliproxyauth.Auth, + req cliproxyexecutor.Request, + opts cliproxyexecutor.Options, + accessToken, profileArn string, +) (cliproxyexecutor.Response, error) { + // Extract search query from Claude Code's web_search tool_use + query := kiroclaude.ExtractSearchQuery(req.Payload) + if query == "" { + log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute") + // Fall through to normal non-streaming path + return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) + } + + // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback) + region := resolveKiroAPIRegion(auth) + mcpEndpoint := kiroclaude.BuildMcpEndpoint(region) + + // Step 1: Fetch/cache tool description (sync) + { + authAttrs := webSearchAuthAttrs(auth) + fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + } + + // Step 2: Perform MCP search + _, mcpRequest := kiroclaude.CreateMcpRequest(query) + + authAttrs := webSearchAuthAttrs(auth) + handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs) + mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest) + + var searchResults *kiroclaude.WebSearchResults + if mcpErr != nil { + log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr) + } else { + searchResults = kiroclaude.ParseSearchResults(mcpResponse) + } + + resultCount := 0 + if searchResults != nil { + resultCount = len(searchResults.Results) + } + log.Infof("kiro/websearch: non-stream: got %d search results for query: 
%s", resultCount, query) + + // Step 3: Replace restrictive web_search tool description (align with streaming path) + simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload)) + if simplifyErr != nil { + log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr) + simplifiedPayload = bytes.Clone(req.Payload) + } + + // Step 4: Inject search tool_use + tool_result into Claude payload + currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) + modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults) + if err != nil { + log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err) + return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) + } + + // Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry) + // This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream + // to produce a proper Claude JSON response + modifiedReq := req + modifiedReq.Payload = modifiedPayload + + resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn) + if err != nil { + return resp, err + } + + // Step 6: Inject server_tool_use + web_search_tool_result into response + // so Claude Code can display "Did X searches in Ys" + indicators := []kiroclaude.SearchIndicator{ + { + ToolUseID: currentToolUseId, + Query: query, + Results: searchResults, + }, + } + injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators) + if injErr != nil { + log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr) + } else { + resp.Payload = injectedPayload + } + + return resp, nil +} + // callKiroAndBuffer calls the Kiro API and buffers all response chunks. // Returns the buffered chunks for analysis before forwarding to client. 
+// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter. func (e *KiroExecutor) callKiroAndBuffer( ctx context.Context, auth *cliproxyauth.Auth, @@ -4338,10 +4678,7 @@ func (e *KiroExecutor) callKiroAndBuffer( isAgentic, isChatOnly := determineAgenticMode(req.Model) effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - tokenKey := "" - if auth != nil { - tokenKey = auth.ID - } + tokenKey := getTokenKey(auth) kiroStream, err := e.executeStreamWithRetry( ctx, auth, req, opts, accessToken, effectiveProfileArn, @@ -4367,51 +4704,6 @@ func (e *KiroExecutor) callKiroAndBuffer( return chunks, nil } -// callKiroRawAndBuffer calls the Kiro API with a pre-built Kiro payload (no translation). -// Used in the web search loop where the payload is modified directly in Kiro format. -func (e *KiroExecutor) callKiroRawAndBuffer( - ctx context.Context, - auth *cliproxyauth.Auth, - req cliproxyexecutor.Request, - opts cliproxyexecutor.Options, - accessToken, profileArn string, - kiroBody []byte, -) ([][]byte, error) { - kiroModelID := e.mapModelToKiro(req.Model) - isAgentic, isChatOnly := determineAgenticMode(req.Model) - effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - - tokenKey := "" - if auth != nil { - tokenKey = auth.ID - } - log.Debugf("kiro/websearch GAR raw request: %d bytes", len(kiroBody)) - - kiroFormat := sdktranslator.FromString("kiro") - kiroStream, err := e.executeStreamWithRetry( - ctx, auth, req, opts, accessToken, effectiveProfileArn, - nil, kiroBody, kiroFormat, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey, - ) - if err != nil { - return nil, err - } - - // Buffer all chunks - var chunks [][]byte - for chunk := range kiroStream { - if chunk.Err != nil { - return chunks, chunk.Err - } - if len(chunk.Payload) > 0 { - chunks = append(chunks, bytes.Clone(chunk.Payload)) - } - } - - log.Debugf("kiro/websearch GAR raw response: %d chunks buffered", len(chunks)) - - 
return chunks, nil -} - // callKiroDirectStream creates a direct streaming channel to Kiro API without search. func (e *KiroExecutor) callKiroDirectStream( ctx context.Context, @@ -4428,18 +4720,22 @@ func (e *KiroExecutor) callKiroDirectStream( isAgentic, isChatOnly := determineAgenticMode(req.Model) effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - tokenKey := "" - if auth != nil { - tokenKey = auth.ID - } + tokenKey := getTokenKey(auth) - return e.executeStreamWithRetry( + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + var streamErr error + defer reporter.trackFailure(ctx, &streamErr) + + stream, streamErr := e.executeStreamWithRetry( ctx, auth, req, opts, accessToken, effectiveProfileArn, - nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey, + nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey, ) + return stream, streamErr } // sendFallbackText sends a simple text response when the Kiro API fails during the search loop. +// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment +// with how streamToChannel() uses BuildClaude*Event() functions. 
func (e *KiroExecutor) sendFallbackText( ctx context.Context, out chan<- cliproxyexecutor.StreamChunk, @@ -4447,182 +4743,14 @@ func (e *KiroExecutor) sendFallbackText( query string, searchResults *kiroclaude.WebSearchResults, ) { - // Generate a simple text summary from search results - summary := kiroclaude.FormatSearchContextPrompt(query, searchResults) - - events := []kiroclaude.SseEvent{ - { - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": contentBlockIndex, - "content_block": map[string]interface{}{ - "type": "text", - "text": "", - }, - }, - }, - { - Event: "content_block_delta", - Data: map[string]interface{}{ - "type": "content_block_delta", - "index": contentBlockIndex, - "delta": map[string]interface{}{ - "type": "text_delta", - "text": summary, - }, - }, - }, - { - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": contentBlockIndex, - }, - }, - } - + events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults) for _, event := range events { select { case <-ctx.Done(): return - case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}: + case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}: } } - - // Send message_delta with end_turn and message_stop - msgDelta := kiroclaude.SseEvent{ - Event: "message_delta", - Data: map[string]interface{}{ - "type": "message_delta", - "delta": map[string]interface{}{ - "stop_reason": "end_turn", - "stop_sequence": nil, - }, - "usage": map[string]interface{}{ - "output_tokens": len(summary) / 4, - }, - }, - } - select { - case <-ctx.Done(): - return - case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgDelta.ToSSEString())}: - } - - msgStop := kiroclaude.SseEvent{ - Event: "message_stop", - Data: map[string]interface{}{ - "type": "message_stop", - }, - } - select { - case <-ctx.Done(): - return - case out <- 
cliproxyexecutor.StreamChunk{Payload: []byte(msgStop.ToSSEString())}: - } - -} - -// handleWebSearch handles web_search requests for non-streaming Execute path. -// Performs MCP search synchronously, injects results into the request payload, -// then calls the normal non-streaming Kiro API path which returns a proper -// Claude JSON response (not SSE chunks). -func (e *KiroExecutor) handleWebSearch( - ctx context.Context, - auth *cliproxyauth.Auth, - req cliproxyexecutor.Request, - opts cliproxyexecutor.Options, - accessToken, profileArn string, -) (cliproxyexecutor.Response, error) { - // Extract search query from Claude Code's web_search tool_use - query := kiroclaude.ExtractSearchQuery(req.Payload) - if query == "" { - log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute") - // Fall through to normal non-streaming path - return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) - } - - // Build MCP endpoint based on region - region := kiroDefaultRegion - if auth != nil && auth.Metadata != nil { - if r, ok := auth.Metadata["api_region"].(string); ok && r != "" { - region = r - } - } - mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region) - - // Step 1: Fetch/cache tool description (sync) - { - tokenKey := getTokenKey(auth) - fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) - var authAttrs map[string]string - if auth != nil { - authAttrs = auth.Attributes - } - kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs) - } - - // Step 2: Perform MCP search - _, mcpRequest := kiroclaude.CreateMcpRequest(query) - tokenKey := getTokenKey(auth) - fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) - var authAttrs map[string]string - if auth != nil { - authAttrs = auth.Attributes - } - handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, 
e.cfg, auth, 30*time.Second), fp, authAttrs) - mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest) - - var searchResults *kiroclaude.WebSearchResults - if mcpErr != nil { - log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr) - } else { - searchResults = kiroclaude.ParseSearchResults(mcpResponse) - } - - resultCount := 0 - if searchResults != nil { - resultCount = len(searchResults.Results) - } - log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query) - - // Step 3: Inject search tool_use + tool_result into Claude payload - currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) - modifiedPayload, err := kiroclaude.InjectToolResultsClaude(bytes.Clone(req.Payload), currentToolUseId, query, searchResults) - if err != nil { - log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err) - return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn) - } - - // Step 4: Call Kiro API via the normal non-streaming path (executeWithRetry) - // This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream - // to produce a proper Claude JSON response - modifiedReq := req - modifiedReq.Payload = modifiedPayload - - resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn) - if err != nil { - return resp, err - } - - // Step 5: Inject server_tool_use + web_search_tool_result into response - // so Claude Code can display "Did X searches in Ys" - indicators := []kiroclaude.SearchIndicator{ - { - ToolUseID: currentToolUseId, - Query: query, - Results: searchResults, - }, - } - injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators) - if injErr != nil { - log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr) - } else { - resp.Payload = injectedPayload - } - - return resp, nil } // 
executeNonStreamFallback runs the standard non-streaming Execute path for a request. diff --git a/internal/translator/kiro/claude/kiro_claude_stream.go b/internal/translator/kiro/claude/kiro_claude_stream.go index 84fd6621..ab6f0fce 100644 --- a/internal/translator/kiro/claude/kiro_claude_stream.go +++ b/internal/translator/kiro/claude/kiro_claude_stream.go @@ -183,4 +183,129 @@ func PendingTagSuffix(buffer, tag string) int { } } return 0 -} \ No newline at end of file +} + +// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events +// (server_tool_use + web_search_tool_result) without text summary or message termination. +// These events trigger Claude Code's search indicator UI. +// The caller is responsible for sending message_start before and message_delta/stop after. +func GenerateSearchIndicatorEvents( + query string, + toolUseID string, + searchResults *WebSearchResults, + startIndex int, +) []sseEvent { + events := make([]sseEvent, 0, 4) + + // 1. content_block_start (server_tool_use) + events = append(events, sseEvent{ + Event: "content_block_start", + Data: map[string]interface{}{ + "type": "content_block_start", + "index": startIndex, + "content_block": map[string]interface{}{ + "id": toolUseID, + "type": "server_tool_use", + "name": "web_search", + "input": map[string]interface{}{}, + }, + }, + }) + + // 2. content_block_delta (input_json_delta) + inputJSON, _ := json.Marshal(map[string]string{"query": query}) + events = append(events, sseEvent{ + Event: "content_block_delta", + Data: map[string]interface{}{ + "type": "content_block_delta", + "index": startIndex, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": string(inputJSON), + }, + }, + }) + + // 3. content_block_stop (server_tool_use) + events = append(events, sseEvent{ + Event: "content_block_stop", + Data: map[string]interface{}{ + "type": "content_block_stop", + "index": startIndex, + }, + }) + + // 4. 
content_block_start (web_search_tool_result) + searchContent := make([]map[string]interface{}, 0) + if searchResults != nil { + for _, r := range searchResults.Results { + snippet := "" + if r.Snippet != nil { + snippet = *r.Snippet + } + searchContent = append(searchContent, map[string]interface{}{ + "type": "web_search_result", + "title": r.Title, + "url": r.URL, + "encrypted_content": snippet, + "page_age": nil, + }) + } + } + events = append(events, sseEvent{ + Event: "content_block_start", + Data: map[string]interface{}{ + "type": "content_block_start", + "index": startIndex + 1, + "content_block": map[string]interface{}{ + "type": "web_search_tool_result", + "tool_use_id": toolUseID, + "content": searchContent, + }, + }, + }) + + // 5. content_block_stop (web_search_tool_result) + events = append(events, sseEvent{ + Event: "content_block_stop", + Data: map[string]interface{}{ + "type": "content_block_stop", + "index": startIndex + 1, + }, + }) + + return events +} + +// BuildFallbackTextEvents generates SSE events for a fallback text response +// when the Kiro API fails during the search loop. Uses BuildClaude*Event() +// functions to align with streamToChannel patterns. +// Returns raw SSE byte slices ready to be sent to the client channel. 
+func BuildFallbackTextEvents(contentBlockIndex int, query string, results *WebSearchResults) [][]byte { + summary := FormatSearchContextPrompt(query, results) + outputTokens := len(summary) / 4 + if len(summary) > 0 && outputTokens == 0 { + outputTokens = 1 + } + + var events [][]byte + + // content_block_start (text) + events = append(events, BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")) + + // content_block_delta (text_delta) + events = append(events, BuildClaudeStreamEvent(summary, contentBlockIndex)) + + // content_block_stop + events = append(events, BuildClaudeContentBlockStopEvent(contentBlockIndex)) + + // message_delta with end_turn + events = append(events, BuildClaudeMessageDeltaEvent("end_turn", usage.Detail{ + OutputTokens: int64(outputTokens), + })) + + // message_stop + events = append(events, BuildClaudeMessageStopOnlyEvent()) + + return events +} diff --git a/internal/translator/kiro/claude/kiro_claude_stream_parser.go b/internal/translator/kiro/claude/kiro_claude_stream_parser.go new file mode 100644 index 00000000..35ae945b --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude_stream_parser.go @@ -0,0 +1,350 @@ +package claude + +import ( + "encoding/json" + "strings" + + log "github.com/sirupsen/logrus" +) + +// sseEvent represents a Server-Sent Event +type sseEvent struct { + Event string + Data interface{} +} + +// ToSSEString converts the event to SSE wire format +func (e *sseEvent) ToSSEString() string { + dataBytes, _ := json.Marshal(e.Data) + return "event: " + e.Event + "\ndata: " + string(dataBytes) + "\n\n" +} + +// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset. +// It also suppresses duplicate message_start events (returns shouldForward=false). +// This is used to combine search indicator events (indices 0,1) with Kiro model response events. +// +// The data parameter is a single SSE "data:" line payload (JSON). 
// AdjustStreamIndices shifts the content block index inside a single SSE
// "data:" JSON payload by offset. Duplicate message_start events are
// suppressed by returning shouldForward=false. This is used when splicing
// search indicator events (indices 0,1) in front of Kiro model response
// events.
//
// Returns the (possibly rewritten) payload and whether the caller should
// forward it downstream.
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
	if len(data) == 0 {
		return data, true
	}

	var payload map[string]interface{}
	if json.Unmarshal(data, &payload) != nil {
		// Non-JSON payloads are forwarded untouched.
		return data, true
	}

	kind, _ := payload["type"].(string)
	if kind == "message_start" {
		// The outer stream already emitted message_start; drop duplicates.
		return data, false
	}

	if kind != "content_block_start" && kind != "content_block_delta" && kind != "content_block_stop" {
		// message_delta, message_stop, ping, etc. pass through unchanged.
		return data, true
	}

	idx, ok := payload["index"].(float64)
	if !ok {
		return data, true
	}
	payload["index"] = int(idx) + offset
	rewritten, err := json.Marshal(payload)
	if err != nil {
		return data, true
	}
	return rewritten, true
}
+func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) { + chunkStr := string(chunk) + + // Fast path: if no "data:" prefix, pass through + if !strings.Contains(chunkStr, "data: ") { + return chunk, true + } + + var result strings.Builder + hasContent := false + + lines := strings.Split(chunkStr, "\n") + for i := 0; i < len(lines); i++ { + line := lines[i] + + if strings.HasPrefix(line, "data: ") { + dataPayload := strings.TrimPrefix(line, "data: ") + dataPayload = strings.TrimSpace(dataPayload) + + if dataPayload == "[DONE]" { + result.WriteString(line + "\n") + hasContent = true + continue + } + + adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset) + if !shouldForward { + // Skip this event and its preceding "event:" line + // Also skip the trailing empty line + continue + } + + result.WriteString("data: " + string(adjusted) + "\n") + hasContent = true + } else if strings.HasPrefix(line, "event: ") { + // Check if the next data line will be suppressed + if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") { + dataPayload := strings.TrimPrefix(lines[i+1], "data: ") + dataPayload = strings.TrimSpace(dataPayload) + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataPayload), &event); err == nil { + if eventType, ok := event["type"].(string); ok && eventType == "message_start" { + // Skip both the event: and data: lines + i++ // skip the data: line too + continue + } + } + } + result.WriteString(line + "\n") + hasContent = true + } else { + result.WriteString(line + "\n") + if strings.TrimSpace(line) != "" { + hasContent = true + } + } + } + + if !hasContent { + return nil, false + } + + return []byte(result.String()), true +} + +// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response. 
+type BufferedStreamResult struct { + // StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use") + StopReason string + // WebSearchQuery is the extracted query if the model requested another web_search + WebSearchQuery string + // WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults) + WebSearchToolUseId string + // HasWebSearchToolUse indicates whether the model requested web_search + HasWebSearchToolUse bool + // WebSearchToolUseIndex is the content_block index of the web_search tool_use + WebSearchToolUseIndex int +} + +// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use. +// This is used in the search loop to determine if the model wants another search round. +func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult { + result := BufferedStreamResult{WebSearchToolUseIndex: -1} + + // Track tool use state across chunks + var currentToolName string + var currentToolIndex int = -1 + var toolInputBuilder strings.Builder + + for _, chunk := range chunks { + chunkStr := string(chunk) + lines := strings.Split(chunkStr, "\n") + for _, line := range lines { + if !strings.HasPrefix(line, "data: ") { + continue + } + dataPayload := strings.TrimPrefix(line, "data: ") + dataPayload = strings.TrimSpace(dataPayload) + if dataPayload == "[DONE]" || dataPayload == "" { + continue + } + + var event map[string]interface{} + if err := json.Unmarshal([]byte(dataPayload), &event); err != nil { + continue + } + + eventType, _ := event["type"].(string) + + switch eventType { + case "message_delta": + // Extract stop_reason from message_delta + if delta, ok := event["delta"].(map[string]interface{}); ok { + if sr, ok := delta["stop_reason"].(string); ok && sr != "" { + result.StopReason = sr + } + } + + case "content_block_start": + // Detect tool_use content blocks + if cb, ok := event["content_block"].(map[string]interface{}); ok { + if cbType, ok := 
cb["type"].(string); ok && cbType == "tool_use" { + if name, ok := cb["name"].(string); ok { + currentToolName = strings.ToLower(name) + if idx, ok := event["index"].(float64); ok { + currentToolIndex = int(idx) + } + // Capture tool use ID for toolResults handshake + if id, ok := cb["id"].(string); ok { + result.WebSearchToolUseId = id + } + toolInputBuilder.Reset() + } + } + } + + case "content_block_delta": + // Accumulate tool input JSON + if currentToolName != "" { + if delta, ok := event["delta"].(map[string]interface{}); ok { + if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" { + if partial, ok := delta["partial_json"].(string); ok { + toolInputBuilder.WriteString(partial) + } + } + } + } + + case "content_block_stop": + // Finalize tool use detection + if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" { + result.HasWebSearchToolUse = true + result.WebSearchToolUseIndex = currentToolIndex + // Extract query from accumulated input JSON + inputJSON := toolInputBuilder.String() + var input map[string]string + if err := json.Unmarshal([]byte(inputJSON), &input); err == nil { + if q, ok := input["query"]; ok { + result.WebSearchQuery = q + } + } + log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery) + } + currentToolName = "" + currentToolIndex = -1 + toolInputBuilder.Reset() + } + } + } + + return result +} + +// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use +// content blocks. This prevents the client from seeing "Tool use" prompts for web_search +// when the proxy is handling the search loop internally. +// Also suppresses message_start and message_delta/message_stop events since those +// are managed by the outer handleWebSearchStream. 
// FilterChunksForClient rewrites buffered SSE chunks before they are
// forwarded to the client: events belonging to the web_search tool_use
// content block (index wsToolIndex) are removed so the client never sees a
// "Tool use" prompt for a search the proxy handles internally, and
// message_start/message_delta/message_stop plus [DONE] are dropped because
// the outer handleWebSearchStream manages message framing itself. Remaining
// content_block events have their index shifted by indexOffset (when > 0).
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
	var kept [][]byte

	for _, chunk := range chunks {
		lines := strings.Split(string(chunk), "\n")

		var rebuilt strings.Builder
		wroteSomething := false

		for i := 0; i < len(lines); i++ {
			line := lines[i]

			switch {
			case strings.HasPrefix(line, "data: "):
				payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))

				// The outer loop manages stream termination itself.
				if payload == "[DONE]" {
					continue
				}

				var event map[string]interface{}
				if err := json.Unmarshal([]byte(payload), &event); err != nil {
					rebuilt.WriteString(line + "\n")
					wroteSomething = true
					continue
				}

				eventType, _ := event["type"].(string)

				// Message framing is owned by the outer loop.
				if eventType == "message_start" || eventType == "message_delta" || eventType == "message_stop" {
					continue
				}

				// Drop everything belonging to the web_search tool_use block.
				if wsToolIndex >= 0 {
					if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
						continue
					}
				}

				// Shift indices on the surviving content_block events.
				if indexOffset > 0 {
					switch eventType {
					case "content_block_start", "content_block_delta", "content_block_stop":
						if idx, ok := event["index"].(float64); ok {
							event["index"] = int(idx) + indexOffset
							if adjusted, err := json.Marshal(event); err == nil {
								rebuilt.WriteString("data: " + string(adjusted) + "\n")
								wroteSomething = true
								continue
							}
						}
					}
				}

				rebuilt.WriteString(line + "\n")
				wroteSomething = true

			case strings.HasPrefix(line, "event: "):
				// Peek at the paired data line so a suppressed event does
				// not leave an orphan "event:" line behind.
				if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
					peek := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))

					var next map[string]interface{}
					if err := json.Unmarshal([]byte(peek), &next); err == nil {
						nextType, _ := next["type"].(string)
						if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
							i++ // skip the data line too
							continue
						}
						if wsToolIndex >= 0 {
							if idx, ok := next["index"].(float64); ok && int(idx) == wsToolIndex {
								i++ // skip the data line too
								continue
							}
						}
					}
				}
				rebuilt.WriteString(line + "\n")
				wroteSomething = true

			default:
				rebuilt.WriteString(line + "\n")
				if strings.TrimSpace(line) != "" {
					wroteSomething = true
				}
			}
		}

		if wroteSomething {
			kept = append(kept, []byte(rebuilt.String()))
		}
	}

	return kept
}
// cachedToolDescription holds the dynamically-fetched web_search tool
// description. The executor writes it via SetWebSearchDescription; the
// translator reads it when building the remote_web_search tool for Kiro
// API requests. Lock-free via atomic.Value (always stores a string).
var cachedToolDescription atomic.Value

// GetWebSearchDescription returns the cached web_search tool description,
// or the empty string when nothing has been fetched yet.
func GetWebSearchDescription() string {
	// A failed type assertion (nothing stored yet) yields "".
	desc, _ := cachedToolDescription.Load().(string)
	return desc
}

// SetWebSearchDescription records the web_search tool description fetched
// from the MCP tools/list endpoint. Called by the executor.
func SetWebSearchDescription(desc string) {
	cachedToolDescription.Store(desc)
}
@@ -275,48 +273,6 @@ func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) { return result, nil } -// StripWebSearchTool removes web_search tool entries from the request's tools array. -// If the tools array becomes empty after removal, it is removed entirely. -func StripWebSearchTool(body []byte) ([]byte, error) { - tools := gjson.GetBytes(body, "tools") - if !tools.IsArray() { - return body, nil - } - - var filtered []json.RawMessage - for _, tool := range tools.Array() { - name := strings.ToLower(tool.Get("name").String()) - toolType := strings.ToLower(tool.Get("type").String()) - - if !isWebSearchTool(name, toolType) { - filtered = append(filtered, json.RawMessage(tool.Raw)) - } - } - - var result []byte - var err error - - if len(filtered) == 0 { - // Remove tools array entirely - result, err = sjson.DeleteBytes(body, "tools") - if err != nil { - return body, fmt.Errorf("failed to delete tools: %w", err) - } - } else { - // Replace with filtered array - filteredJSON, marshalErr := json.Marshal(filtered) - if marshalErr != nil { - return body, fmt.Errorf("failed to marshal filtered tools: %w", marshalErr) - } - result, err = sjson.SetRawBytes(body, "tools", filteredJSON) - if err != nil { - return body, fmt.Errorf("failed to set filtered tools: %w", err) - } - } - - return result, nil -} - // FormatSearchContextPrompt formats search results as a structured text block // for injection into the system prompt. func FormatSearchContextPrompt(query string, results *WebSearchResults) string { @@ -365,7 +321,7 @@ func FormatToolResultText(results *WebSearchResults) string { // // This produces the exact same GAR request format as the Kiro IDE (HAR captures). // IMPORTANT: The web_search tool must remain in the "tools" array for this to work. -// Use ReplaceWebSearchToolDescription (not StripWebSearchTool) to keep the tool available. +// Use ReplaceWebSearchToolDescription to keep the tool available with a minimal description. 
// BuildMcpEndpoint constructs the MCP endpoint URL for the given AWS
// region. Centralizes the URL pattern shared by handleWebSearch and
// handleWebSearchStream.
func BuildMcpEndpoint(region string) string {
	return "https://q." + region + ".amazonaws.com/mcp"
}
message_start - events = append(events, SseEvent{ - Event: "message_start", - Data: map[string]interface{}{ - "type": "message_start", - "message": map[string]interface{}{ - "id": messageID, - "type": "message", - "role": "assistant", - "model": model, - "content": []interface{}{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]interface{}{ - "input_tokens": inputTokens, - "output_tokens": 0, - "cache_creation_input_tokens": 0, - "cache_read_input_tokens": 0, - }, - }, - }, - }) - - // 2. content_block_start (server_tool_use) - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": 0, - "content_block": map[string]interface{}{ - "id": toolUseID, - "type": "server_tool_use", - "name": "web_search", - "input": map[string]interface{}{}, - }, - }, - }) - - // 3. content_block_delta (input_json_delta) - inputJSON, _ := json.Marshal(map[string]string{"query": query}) - events = append(events, SseEvent{ - Event: "content_block_delta", - Data: map[string]interface{}{ - "type": "content_block_delta", - "index": 0, - "delta": map[string]interface{}{ - "type": "input_json_delta", - "partial_json": string(inputJSON), - }, - }, - }) - - // 4. content_block_stop (server_tool_use) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": 0, - }, - }) - - // 5. 
content_block_start (web_search_tool_result) - searchContent := make([]map[string]interface{}, 0) - if searchResults != nil { - for _, r := range searchResults.Results { - snippet := "" - if r.Snippet != nil { - snippet = *r.Snippet - } - searchContent = append(searchContent, map[string]interface{}{ - "type": "web_search_result", - "title": r.Title, - "url": r.URL, - "encrypted_content": snippet, - "page_age": nil, - }) - } - } - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": 1, - "content_block": map[string]interface{}{ - "type": "web_search_tool_result", - "tool_use_id": toolUseID, - "content": searchContent, - }, - }, - }) - - // 6. content_block_stop (web_search_tool_result) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": 1, - }, - }) - - // 7. content_block_start (text) - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": 2, - "content_block": map[string]interface{}{ - "type": "text", - "text": "", - }, - }, - }) - - // 8. content_block_delta (text_delta) - generate search summary - summary := generateSearchSummary(query, searchResults) - - // Split text into chunks for streaming effect - chunkSize := 100 - runes := []rune(summary) - for i := 0; i < len(runes); i += chunkSize { - end := i + chunkSize - if end > len(runes) { - end = len(runes) - } - chunk := string(runes[i:end]) - events = append(events, SseEvent{ - Event: "content_block_delta", - Data: map[string]interface{}{ - "type": "content_block_delta", - "index": 2, - "delta": map[string]interface{}{ - "type": "text_delta", - "text": chunk, - }, - }, - }) - } - - // 9. 
content_block_stop (text) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": 2, - }, - }) - - // 10. message_delta - outputTokens := (len(summary) + 3) / 4 // Simple estimation - events = append(events, SseEvent{ - Event: "message_delta", - Data: map[string]interface{}{ - "type": "message_delta", - "delta": map[string]interface{}{ - "stop_reason": "end_turn", - "stop_sequence": nil, - }, - "usage": map[string]interface{}{ - "output_tokens": outputTokens, - }, - }, - }) - - // 11. message_stop - events = append(events, SseEvent{ - Event: "message_stop", - Data: map[string]interface{}{ - "type": "message_stop", - }, - }) - - return events -} - -// generateSearchSummary generates a text summary of search results -func generateSearchSummary(query string, results *WebSearchResults) string { - var sb strings.Builder - sb.WriteString(fmt.Sprintf("Here are the search results for \"%s\":\n\n", query)) - - if results != nil && len(results.Results) > 0 { - for i, r := range results.Results { - sb.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, r.Title)) - if r.Snippet != nil { - snippet := *r.Snippet - if len(snippet) > 200 { - snippet = snippet[:200] + "..." - } - sb.WriteString(fmt.Sprintf(" %s\n", snippet)) - } - sb.WriteString(fmt.Sprintf(" Source: %s\n\n", r.URL)) - } - } else { - sb.WriteString("No results found.\n") - } - - sb.WriteString("\nPlease note that these are web search results and may not be fully accurate or up-to-date.") - - return sb.String() -} - -// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events -// (server_tool_use + web_search_tool_result) without text summary or message termination. -// These events trigger Claude Code's search indicator UI. -// The caller is responsible for sending message_start before and message_delta/stop after. 
-func GenerateSearchIndicatorEvents( - query string, - toolUseID string, - searchResults *WebSearchResults, - startIndex int, -) []SseEvent { - events := make([]SseEvent, 0, 4) - - // 1. content_block_start (server_tool_use) - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": startIndex, - "content_block": map[string]interface{}{ - "id": toolUseID, - "type": "server_tool_use", - "name": "web_search", - "input": map[string]interface{}{}, - }, - }, - }) - - // 2. content_block_delta (input_json_delta) - inputJSON, _ := json.Marshal(map[string]string{"query": query}) - events = append(events, SseEvent{ - Event: "content_block_delta", - Data: map[string]interface{}{ - "type": "content_block_delta", - "index": startIndex, - "delta": map[string]interface{}{ - "type": "input_json_delta", - "partial_json": string(inputJSON), - }, - }, - }) - - // 3. content_block_stop (server_tool_use) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": startIndex, - }, - }) - - // 4. content_block_start (web_search_tool_result) - searchContent := make([]map[string]interface{}, 0) - if searchResults != nil { - for _, r := range searchResults.Results { - snippet := "" - if r.Snippet != nil { - snippet = *r.Snippet - } - searchContent = append(searchContent, map[string]interface{}{ - "type": "web_search_result", - "title": r.Title, - "url": r.URL, - "encrypted_content": snippet, - "page_age": nil, - }) - } - } - events = append(events, SseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": startIndex + 1, - "content_block": map[string]interface{}{ - "type": "web_search_tool_result", - "tool_use_id": toolUseID, - "content": searchContent, - }, - }, - }) - - // 5. 
content_block_stop (web_search_tool_result) - events = append(events, SseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": startIndex + 1, - }, - }) - - return events -} - -// ══════════════════════════════════════════════════════════════════════════════ -// Stream Analysis & Manipulation -// ══════════════════════════════════════════════════════════════════════════════ - -// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset. -// It also suppresses duplicate message_start events (returns shouldForward=false). -// This is used to combine search indicator events (indices 0,1) with Kiro model response events. -// -// The data parameter is a single SSE "data:" line payload (JSON). -// Returns: adjusted data, shouldForward (false = skip this event). -func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) { - if len(data) == 0 { - return data, true - } - - // Quick check: parse the JSON - var event map[string]interface{} - if err := json.Unmarshal(data, &event); err != nil { - // Not valid JSON, pass through - return data, true - } - - eventType, _ := event["type"].(string) - - // Suppress duplicate message_start events - if eventType == "message_start" { - return data, false - } - - // Adjust index for content_block events - switch eventType { - case "content_block_start", "content_block_delta", "content_block_stop": - if idx, ok := event["index"].(float64); ok { - event["index"] = int(idx) + offset - adjusted, err := json.Marshal(event) - if err != nil { - return data, true - } - return adjusted, true - } - } - - // Pass through all other events unchanged (message_delta, message_stop, ping, etc.) - return data, true -} - -// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs) -// and adjusts content block indices. Suppresses duplicate message_start events. 
-// Returns the adjusted chunk and whether it should be forwarded. -func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) { - chunkStr := string(chunk) - - // Fast path: if no "data:" prefix, pass through - if !strings.Contains(chunkStr, "data: ") { - return chunk, true - } - - var result strings.Builder - hasContent := false - - lines := strings.Split(chunkStr, "\n") - for i := 0; i < len(lines); i++ { - line := lines[i] - - if strings.HasPrefix(line, "data: ") { - dataPayload := strings.TrimPrefix(line, "data: ") - dataPayload = strings.TrimSpace(dataPayload) - - if dataPayload == "[DONE]" { - result.WriteString(line + "\n") - hasContent = true - continue - } - - adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset) - if !shouldForward { - // Skip this event and its preceding "event:" line - // Also skip the trailing empty line - continue - } - - result.WriteString("data: " + string(adjusted) + "\n") - hasContent = true - } else if strings.HasPrefix(line, "event: ") { - // Check if the next data line will be suppressed - if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") { - dataPayload := strings.TrimPrefix(lines[i+1], "data: ") - dataPayload = strings.TrimSpace(dataPayload) - - var event map[string]interface{} - if err := json.Unmarshal([]byte(dataPayload), &event); err == nil { - if eventType, ok := event["type"].(string); ok && eventType == "message_start" { - // Skip both the event: and data: lines - i++ // skip the data: line too - continue - } - } - } - result.WriteString(line + "\n") - hasContent = true - } else { - result.WriteString(line + "\n") - if strings.TrimSpace(line) != "" { - hasContent = true - } - } - } - - if !hasContent { - return nil, false - } - - return []byte(result.String()), true -} - -// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response. 
-type BufferedStreamResult struct { - // StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use") - StopReason string - // WebSearchQuery is the extracted query if the model requested another web_search - WebSearchQuery string - // WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults) - WebSearchToolUseId string - // HasWebSearchToolUse indicates whether the model requested web_search - HasWebSearchToolUse bool - // WebSearchToolUseIndex is the content_block index of the web_search tool_use - WebSearchToolUseIndex int -} - -// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use. -// This is used in the search loop to determine if the model wants another search round. -func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult { - result := BufferedStreamResult{WebSearchToolUseIndex: -1} - - // Track tool use state across chunks - var currentToolName string - var currentToolIndex int = -1 - var toolInputBuilder strings.Builder - - for _, chunk := range chunks { - chunkStr := string(chunk) - lines := strings.Split(chunkStr, "\n") - for _, line := range lines { - if !strings.HasPrefix(line, "data: ") { - continue - } - dataPayload := strings.TrimPrefix(line, "data: ") - dataPayload = strings.TrimSpace(dataPayload) - if dataPayload == "[DONE]" || dataPayload == "" { - continue - } - - var event map[string]interface{} - if err := json.Unmarshal([]byte(dataPayload), &event); err != nil { - continue - } - - eventType, _ := event["type"].(string) - - switch eventType { - case "message_delta": - // Extract stop_reason from message_delta - if delta, ok := event["delta"].(map[string]interface{}); ok { - if sr, ok := delta["stop_reason"].(string); ok && sr != "" { - result.StopReason = sr - } - } - - case "content_block_start": - // Detect tool_use content blocks - if cb, ok := event["content_block"].(map[string]interface{}); ok { - if cbType, ok := 
cb["type"].(string); ok && cbType == "tool_use" { - if name, ok := cb["name"].(string); ok { - currentToolName = strings.ToLower(name) - if idx, ok := event["index"].(float64); ok { - currentToolIndex = int(idx) - } - // Capture tool use ID for toolResults handshake - if id, ok := cb["id"].(string); ok { - result.WebSearchToolUseId = id - } - toolInputBuilder.Reset() - } - } - } - - case "content_block_delta": - // Accumulate tool input JSON - if currentToolName != "" { - if delta, ok := event["delta"].(map[string]interface{}); ok { - if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" { - if partial, ok := delta["partial_json"].(string); ok { - toolInputBuilder.WriteString(partial) - } - } - } - } - - case "content_block_stop": - // Finalize tool use detection - if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" { - result.HasWebSearchToolUse = true - result.WebSearchToolUseIndex = currentToolIndex - // Extract query from accumulated input JSON - inputJSON := toolInputBuilder.String() - var input map[string]string - if err := json.Unmarshal([]byte(inputJSON), &input); err == nil { - if q, ok := input["query"]; ok { - result.WebSearchQuery = q - } - } - log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery) - } - currentToolName = "" - currentToolIndex = -1 - toolInputBuilder.Reset() - } - } - } - - return result -} - -// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use -// content blocks. This prevents the client from seeing "Tool use" prompts for web_search -// when the proxy is handling the search loop internally. -// Also suppresses message_start and message_delta/message_stop events since those -// are managed by the outer handleWebSearchStream. 
-func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte { - var filtered [][]byte - - for _, chunk := range chunks { - chunkStr := string(chunk) - lines := strings.Split(chunkStr, "\n") - - var resultBuilder strings.Builder - hasContent := false - - for i := 0; i < len(lines); i++ { - line := lines[i] - - if strings.HasPrefix(line, "data: ") { - dataPayload := strings.TrimPrefix(line, "data: ") - dataPayload = strings.TrimSpace(dataPayload) - - if dataPayload == "[DONE]" { - // Skip [DONE] — the outer loop manages stream termination - continue - } - - var event map[string]interface{} - if err := json.Unmarshal([]byte(dataPayload), &event); err != nil { - resultBuilder.WriteString(line + "\n") - hasContent = true - continue - } - - eventType, _ := event["type"].(string) - - // Skip message_start (outer loop sends its own) - if eventType == "message_start" { - continue - } - - // Skip message_delta and message_stop (outer loop manages these) - if eventType == "message_delta" || eventType == "message_stop" { - continue - } - - // Check if this event belongs to the web_search tool_use block - if wsToolIndex >= 0 { - if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex { - // Skip events for the web_search tool_use block - continue - } - } - - // Apply index offset for remaining events - if indexOffset > 0 { - switch eventType { - case "content_block_start", "content_block_delta", "content_block_stop": - if idx, ok := event["index"].(float64); ok { - event["index"] = int(idx) + indexOffset - adjusted, err := json.Marshal(event) - if err == nil { - resultBuilder.WriteString("data: " + string(adjusted) + "\n") - hasContent = true - continue - } - } - } - } - - resultBuilder.WriteString(line + "\n") - hasContent = true - } else if strings.HasPrefix(line, "event: ") { - // Check if the next data line will be suppressed - if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") { - nextData := 
strings.TrimPrefix(lines[i+1], "data: ") - nextData = strings.TrimSpace(nextData) - - var nextEvent map[string]interface{} - if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil { - nextType, _ := nextEvent["type"].(string) - if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" { - i++ // skip the data line - continue - } - if wsToolIndex >= 0 { - if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex { - i++ // skip the data line - continue - } - } - } - } - resultBuilder.WriteString(line + "\n") - hasContent = true - } else { - resultBuilder.WriteString(line + "\n") - if strings.TrimSpace(line) != "" { - hasContent = true - } - } - } - - if hasContent { - filtered = append(filtered, []byte(resultBuilder.String())) - } - } - - return filtered +// ParseSearchResults extracts WebSearchResults from MCP response +func ParseSearchResults(response *McpResponse) *WebSearchResults { + if response == nil || response.Result == nil || len(response.Result.Content) == 0 { + return nil + } + + content := response.Result.Content[0] + if content.ContentType != "text" { + return nil + } + + var results WebSearchResults + if err := json.Unmarshal([]byte(content.Text), &results); err != nil { + log.Warnf("kiro/websearch: failed to parse search results: %v", err) + return nil + } + + return &results } diff --git a/internal/translator/kiro/claude/kiro_websearch_handler.go b/internal/translator/kiro/claude/kiro_websearch_handler.go deleted file mode 100644 index c64d8eb9..00000000 --- a/internal/translator/kiro/claude/kiro_websearch_handler.go +++ /dev/null @@ -1,270 +0,0 @@ -// Package claude provides web search handler for Kiro translator. -// This file implements the MCP API call and response handling. 
-package claude - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "sync" - "sync/atomic" - "time" - - "github.com/google/uuid" - kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - log "github.com/sirupsen/logrus" -) - -// Cached web_search tool description fetched from MCP tools/list. -// Uses atomic.Pointer[sync.Once] for lock-free reads with retry-on-failure: -// - sync.Once prevents race conditions and deduplicates concurrent calls -// - On failure, a fresh sync.Once is swapped in to allow retry on next call -// - On success, sync.Once stays "done" forever — zero overhead for subsequent calls -var ( - cachedToolDescription atomic.Value // stores string - toolDescOnce atomic.Pointer[sync.Once] - fallbackFpOnce sync.Once - fallbackFp *kiroauth.Fingerprint -) - -func init() { - toolDescOnce.Store(&sync.Once{}) -} - -// FetchToolDescription calls MCP tools/list to get the web_search tool description -// and caches it. Safe to call concurrently — only one goroutine fetches at a time. -// If the fetch fails, subsequent calls will retry. On success, no further fetches occur. -// The httpClient parameter allows reusing a shared pooled HTTP client. 
-func FetchToolDescription(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) { - toolDescOnce.Load().Do(func() { - handler := NewWebSearchHandler(mcpEndpoint, authToken, httpClient, fp, authAttrs) - reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`) - log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody)) - - req, err := http.NewRequest("POST", mcpEndpoint, bytes.NewReader(reqBody)) - if err != nil { - log.Warnf("kiro/websearch: failed to create tools/list request: %v", err) - toolDescOnce.Store(&sync.Once{}) // allow retry - return - } - - // Reuse same headers as CallMcpAPI - handler.setMcpHeaders(req) - - resp, err := handler.HTTPClient.Do(req) - if err != nil { - log.Warnf("kiro/websearch: tools/list request failed: %v", err) - toolDescOnce.Store(&sync.Once{}) // allow retry - return - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil || resp.StatusCode != http.StatusOK { - log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode) - toolDescOnce.Store(&sync.Once{}) // allow retry - return - } - log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body)) - - // Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}} - var result struct { - Result *struct { - Tools []struct { - Name string `json:"name"` - Description string `json:"description"` - } `json:"tools"` - } `json:"result"` - } - if err := json.Unmarshal(body, &result); err != nil || result.Result == nil { - log.Warnf("kiro/websearch: failed to parse tools/list response") - toolDescOnce.Store(&sync.Once{}) // allow retry - return - } - - for _, tool := range result.Result.Tools { - if tool.Name == "web_search" && tool.Description != "" { - cachedToolDescription.Store(tool.Description) - log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description)) - 
return // success — sync.Once stays "done", no more fetches - } - } - - // web_search tool not found in response - toolDescOnce.Store(&sync.Once{}) // allow retry - }) -} - -// GetWebSearchDescription returns the cached web_search tool description, -// or empty string if not yet fetched. Lock-free via atomic.Value. -func GetWebSearchDescription() string { - if v := cachedToolDescription.Load(); v != nil { - return v.(string) - } - return "" -} - -// WebSearchHandler handles web search requests via Kiro MCP API -type WebSearchHandler struct { - McpEndpoint string - HTTPClient *http.Client - AuthToken string - Fingerprint *kiroauth.Fingerprint // optional, for dynamic headers - AuthAttrs map[string]string // optional, for custom headers from auth.Attributes -} - -// NewWebSearchHandler creates a new WebSearchHandler. -// If httpClient is nil, a default client with 30s timeout is used. -// If fingerprint is nil, a random one-off fingerprint is generated. -// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse. -func NewWebSearchHandler(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) *WebSearchHandler { - if httpClient == nil { - httpClient = &http.Client{ - Timeout: 30 * time.Second, - } - } - if fp == nil { - // Use a shared fallback fingerprint for callers without token context - fallbackFpOnce.Do(func() { - mgr := kiroauth.NewFingerprintManager() - fallbackFp = mgr.GetFingerprint("mcp-fallback") - }) - fp = fallbackFp - } - return &WebSearchHandler{ - McpEndpoint: mcpEndpoint, - HTTPClient: httpClient, - AuthToken: authToken, - Fingerprint: fp, - AuthAttrs: authAttrs, - } -} - -// setMcpHeaders sets standard MCP API headers on the request, -// aligned with the GAR request pattern in kiro_executor.go. -func (h *WebSearchHandler) setMcpHeaders(req *http.Request) { - fp := h.Fingerprint - - // 1. 
Content-Type & Accept (aligned with GAR) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "*/*") - - // 2. Kiro-specific headers (aligned with GAR) - req.Header.Set("x-amzn-kiro-agent-mode", "vibe") - req.Header.Set("x-amzn-codewhisperer-optout", "true") - - // 3. Dynamic fingerprint headers - req.Header.Set("User-Agent", fp.BuildUserAgent()) - req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent()) - - // 4. AWS SDK identifiers (casing aligned with GAR) - req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") - req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - - // 5. Authentication - req.Header.Set("Authorization", "Bearer "+h.AuthToken) - - // 6. Custom headers from auth attributes - util.ApplyCustomHeadersFromAttrs(req, h.AuthAttrs) -} - -// mcpMaxRetries is the maximum number of retries for MCP API calls. -const mcpMaxRetries = 2 - -// CallMcpAPI calls the Kiro MCP API with the given request. -// Includes retry logic with exponential backoff for retryable errors, -// aligned with the GAR request retry pattern. 
-func (h *WebSearchHandler) CallMcpAPI(request *McpRequest) (*McpResponse, error) { - requestBody, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal MCP request: %w", err) - } - log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.McpEndpoint, len(requestBody)) - - var lastErr error - for attempt := 0; attempt <= mcpMaxRetries; attempt++ { - if attempt > 0 { - backoff := time.Duration(1<<attempt) * time.Second - if backoff > 10*time.Second { - backoff = 10 * time.Second - } - log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr) - time.Sleep(backoff) - } - - req, err := http.NewRequest("POST", h.McpEndpoint, bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP request: %w", err) - } - - h.setMcpHeaders(req) - - resp, err := h.HTTPClient.Do(req) - if err != nil { - lastErr = fmt.Errorf("MCP API request failed: %w", err) - continue // network error → retry - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - lastErr = fmt.Errorf("failed to read MCP response: %w", err) - continue // read error → retry - } - log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body)) - - // Retryable HTTP status codes (aligned with GAR: 502, 503, 504) - if resp.StatusCode >= 502 && resp.StatusCode <= 504 { - lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body)) - continue - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body)) - } - - var mcpResponse McpResponse - if err := json.Unmarshal(body, &mcpResponse); err != nil { - return nil, fmt.Errorf("failed to parse MCP response: %w", err) - } - - if mcpResponse.Error != nil { - code := -1 - if mcpResponse.Error.Code != nil { - code = *mcpResponse.Error.Code - } - msg := "Unknown error" - if mcpResponse.Error.Message != nil { - msg = 
*mcpResponse.Error.Message - } - return nil, fmt.Errorf("MCP error %d: %s", code, msg) - } - - return &mcpResponse, nil - } - - return nil, lastErr -} - -// ParseSearchResults extracts WebSearchResults from MCP response -func ParseSearchResults(response *McpResponse) *WebSearchResults { - if response == nil || response.Result == nil || len(response.Result.Content) == 0 { - return nil - } - - content := response.Result.Content[0] - if content.ContentType != "text" { - return nil - } - - var results WebSearchResults - if err := json.Unmarshal([]byte(content.Text), &results); err != nil { - log.Warnf("kiro/websearch: failed to parse search results: %v", err) - return nil - } - - return &results -} From 2db89211a9b414369dfd383a5460c5103e3debf8 Mon Sep 17 00:00:00 2001 From: Skyuno Date: Thu, 12 Feb 2026 16:10:35 +0800 Subject: [PATCH 3/8] kiro: use payloadRequestedModel for response model name Align Kiro executor with all other executors (Claude, Gemini, OpenAI, etc.) by using payloadRequestedModel(opts, req.Model) instead of req.Model when constructing response model names. This ensures model aliases are correctly reflected in responses: - Execute: BuildClaudeResponse + TranslateNonStream - ExecuteStream: streamToChannel - handleWebSearchStream: BuildClaudeMessageStartEvent - handleWebSearch: via executeNonStreamFallback (automatic) Previously Kiro was the only executor using req.Model directly, which exposed internal routed names instead of the user's alias. --- internal/runtime/executor/kiro_executor.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 7bd00205..c9792903 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -1033,8 +1033,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. 
// Build response in Claude format for Kiro translator // stopReason is extracted from upstream response by parseEventStream - kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason) - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) + requestedModel := payloadRequestedModel(opts, req.Model) + kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, requestedModel, usageInfo, stopReason) + out := sdktranslator.TranslateNonStream(ctx, to, from, requestedModel, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) resp = cliproxyexecutor.Response{Payload: []byte(out)} return resp, nil } @@ -1431,7 +1432,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox // So we always enable thinking parsing for Kiro responses log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled) - e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled) + e.streamToChannel(ctx, resp.Body, out, from, payloadRequestedModel(opts, req.Model), opts.OriginalRequest, body, reporter, thinkingEnabled) }(httpResp, thinkingEnabled) return out, nil From 5626637fbd1a6f6b1c841cb5002c6c34df960b65 Mon Sep 17 00:00:00 2001 From: Skyuno Date: Fri, 13 Feb 2026 02:25:55 +0800 Subject: [PATCH 4/8] security: remove query content from web search logs to prevent PII leakage - Remove search query from iteration logs (Info level) - Remove query and toolUseId from analysis logs (Info level) - Remove query from non-stream result logs (Info level) - Remove query from tool injection logs (Info level) - Remove query from tool_use detection logs (Debug level) This addresses the security concern raised in PR #226 review about potential PII exposure in search query logs. 
--- internal/runtime/executor/kiro_executor.go | 10 +++++----- .../kiro/claude/kiro_claude_stream_parser.go | 2 +- internal/translator/kiro/claude/kiro_websearch.go | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index c9792903..9d197769 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -4457,8 +4457,8 @@ func (e *KiroExecutor) handleWebSearchStream( currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID()) for iteration := 0; iteration < maxWebSearchIterations; iteration++ { - log.Infof("kiro/websearch: search iteration %d/%d — query: %s", - iteration+1, maxWebSearchIterations, currentQuery) + log.Infof("kiro/websearch: search iteration %d/%d", + iteration+1, maxWebSearchIterations) // MCP search _, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery) @@ -4515,8 +4515,8 @@ func (e *KiroExecutor) handleWebSearchStream( // Analyze response analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks) - log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v, query: %s, toolUseId: %s", - iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse, analysis.WebSearchQuery, analysis.WebSearchToolUseId) + log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v", + iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse) if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations { // Model wants another search @@ -4613,7 +4613,7 @@ func (e *KiroExecutor) handleWebSearch( if searchResults != nil { resultCount = len(searchResults.Results) } - log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query) + log.Infof("kiro/websearch: non-stream: got %d search results", resultCount) // Step 3: Replace restrictive web_search tool description (align with streaming 
path) simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload)) diff --git a/internal/translator/kiro/claude/kiro_claude_stream_parser.go b/internal/translator/kiro/claude/kiro_claude_stream_parser.go index 35ae945b..275196ac 100644 --- a/internal/translator/kiro/claude/kiro_claude_stream_parser.go +++ b/internal/translator/kiro/claude/kiro_claude_stream_parser.go @@ -226,7 +226,7 @@ func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult { result.WebSearchQuery = q } } - log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery) + log.Debugf("kiro/websearch: detected web_search tool_use") } currentToolName = "" currentToolIndex = -1 diff --git a/internal/translator/kiro/claude/kiro_websearch.go b/internal/translator/kiro/claude/kiro_websearch.go index aaf4d375..b9da3829 100644 --- a/internal/translator/kiro/claude/kiro_websearch.go +++ b/internal/translator/kiro/claude/kiro_websearch.go @@ -388,8 +388,8 @@ Do NOT apologize for bad results without first attempting a re-search. return claudePayload, fmt.Errorf("failed to marshal updated payload: %w", err) } - log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, query=%s, messages=%d)", - toolUseId, query, len(messages)) + log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, messages=%d)", + toolUseId, len(messages)) return result, nil } From 632a2fd2f2c9bb0439356dafb63e07e52841d06f Mon Sep 17 00:00:00 2001 From: Skyuno Date: Fri, 13 Feb 2026 02:36:11 +0800 Subject: [PATCH 5/8] refactor: align GenerateSearchIndicatorEvents return type with other event builders Change GenerateSearchIndicatorEvents to return [][]byte instead of []sseEvent for consistency with BuildFallbackTextEvents and other event building functions. 
Benefits: - Consistent API across all event generation functions - Eliminates intermediate sseEvent type conversion in caller - Simplifies usage by returning ready-to-send SSE byte slices This addresses the code quality feedback from PR #226 review. --- internal/runtime/executor/kiro_executor.go | 2 +- .../kiro/claude/kiro_claude_stream.go | 93 +++++++++---------- 2 files changed, 45 insertions(+), 50 deletions(-) diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 9d197769..41a5830c 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -4487,7 +4487,7 @@ func (e *KiroExecutor) handleWebSearchStream( select { case <-ctx.Done(): return - case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}: + case out <- cliproxyexecutor.StreamChunk{Payload: event}: } } contentBlockIndex += 2 diff --git a/internal/translator/kiro/claude/kiro_claude_stream.go b/internal/translator/kiro/claude/kiro_claude_stream.go index ab6f0fce..c86b6e02 100644 --- a/internal/translator/kiro/claude/kiro_claude_stream.go +++ b/internal/translator/kiro/claude/kiro_claude_stream.go @@ -194,46 +194,43 @@ func GenerateSearchIndicatorEvents( toolUseID string, searchResults *WebSearchResults, startIndex int, -) []sseEvent { - events := make([]sseEvent, 0, 4) +) [][]byte { + events := make([][]byte, 0, 5) // 1. 
content_block_start (server_tool_use) - events = append(events, sseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": startIndex, - "content_block": map[string]interface{}{ - "id": toolUseID, - "type": "server_tool_use", - "name": "web_search", - "input": map[string]interface{}{}, - }, + event1 := map[string]interface{}{ + "type": "content_block_start", + "index": startIndex, + "content_block": map[string]interface{}{ + "id": toolUseID, + "type": "server_tool_use", + "name": "web_search", + "input": map[string]interface{}{}, }, - }) + } + data1, _ := json.Marshal(event1) + events = append(events, []byte("event: content_block_start\ndata: "+string(data1)+"\n\n")) // 2. content_block_delta (input_json_delta) inputJSON, _ := json.Marshal(map[string]string{"query": query}) - events = append(events, sseEvent{ - Event: "content_block_delta", - Data: map[string]interface{}{ - "type": "content_block_delta", - "index": startIndex, - "delta": map[string]interface{}{ - "type": "input_json_delta", - "partial_json": string(inputJSON), - }, + event2 := map[string]interface{}{ + "type": "content_block_delta", + "index": startIndex, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": string(inputJSON), }, - }) + } + data2, _ := json.Marshal(event2) + events = append(events, []byte("event: content_block_delta\ndata: "+string(data2)+"\n\n")) // 3. content_block_stop (server_tool_use) - events = append(events, sseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": startIndex, - }, - }) + event3 := map[string]interface{}{ + "type": "content_block_stop", + "index": startIndex, + } + data3, _ := json.Marshal(event3) + events = append(events, []byte("event: content_block_stop\ndata: "+string(data3)+"\n\n")) // 4. 
content_block_start (web_search_tool_result) searchContent := make([]map[string]interface{}, 0) @@ -252,27 +249,25 @@ func GenerateSearchIndicatorEvents( }) } } - events = append(events, sseEvent{ - Event: "content_block_start", - Data: map[string]interface{}{ - "type": "content_block_start", - "index": startIndex + 1, - "content_block": map[string]interface{}{ - "type": "web_search_tool_result", - "tool_use_id": toolUseID, - "content": searchContent, - }, + event4 := map[string]interface{}{ + "type": "content_block_start", + "index": startIndex + 1, + "content_block": map[string]interface{}{ + "type": "web_search_tool_result", + "tool_use_id": toolUseID, + "content": searchContent, }, - }) + } + data4, _ := json.Marshal(event4) + events = append(events, []byte("event: content_block_start\ndata: "+string(data4)+"\n\n")) // 5. content_block_stop (web_search_tool_result) - events = append(events, sseEvent{ - Event: "content_block_stop", - Data: map[string]interface{}{ - "type": "content_block_stop", - "index": startIndex + 1, - }, - }) + event5 := map[string]interface{}{ + "type": "content_block_stop", + "index": startIndex + 1, + } + data5, _ := json.Marshal(event5) + events = append(events, []byte("event: content_block_stop\ndata: "+string(data5)+"\n\n")) return events } From 6df16bedbafa05ecc74b69a9d9a88f90fb668e5e Mon Sep 17 00:00:00 2001 From: y Date: Sat, 14 Feb 2026 09:40:05 +0800 Subject: [PATCH 6/8] fix: preserve explicitly deleted kiro aliases across config reload (#222) The delete handler now sets the channel value to nil instead of removing the map key, and the sanitization loop preserves nil/empty channel entries as 'disabled' markers. This prevents SanitizeOAuthModelAlias from re-injecting default kiro aliases after a user explicitly deletes them through the management API. 
--- .../api/handlers/management/config_lists.go | 8 ++-- internal/config/config.go | 8 +++- internal/config/oauth_model_alias_test.go | 44 +++++++++++++++++++ 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 5cca03ba..0153a381 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -796,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) { c.JSON(404, gin.H{"error": "channel not found"}) return } - delete(h.cfg.OAuthModelAlias, channel) - if len(h.cfg.OAuthModelAlias) == 0 { - h.cfg.OAuthModelAlias = nil - } + // Set to nil instead of deleting the key so that the "explicitly disabled" + // marker survives config reload and prevents SanitizeOAuthModelAlias from + // re-injecting default aliases (fixes #222). + h.cfg.OAuthModelAlias[channel] = nil h.persist(c) } diff --git a/internal/config/config.go b/internal/config/config.go index 50b3cbd5..88e1c605 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -767,7 +767,13 @@ func (cfg *Config) SanitizeOAuthModelAlias() { out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias)) for rawChannel, aliases := range cfg.OAuthModelAlias { channel := strings.ToLower(strings.TrimSpace(rawChannel)) - if channel == "" || len(aliases) == 0 { + if channel == "" { + continue + } + // Preserve channels that were explicitly set to empty/nil – they act + // as "disabled" markers so default injection won't re-add them (#222). 
+ if len(aliases) == 0 { + out[channel] = nil continue } seenAlias := make(map[string]struct{}, len(aliases)) diff --git a/internal/config/oauth_model_alias_test.go b/internal/config/oauth_model_alias_test.go index 7497eec8..5cf05502 100644 --- a/internal/config/oauth_model_alias_test.go +++ b/internal/config/oauth_model_alias_test.go @@ -128,6 +128,50 @@ func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) { } } +func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) { + // When user explicitly deletes kiro aliases (key exists with nil value), + // defaults should NOT be re-injected on subsequent sanitize calls (#222). + cfg := &Config{ + OAuthModelAlias: map[string][]OAuthModelAlias{ + "kiro": nil, // explicitly deleted + "codex": {{Name: "gpt-5", Alias: "g5"}}, + }, + } + + cfg.SanitizeOAuthModelAlias() + + kiroAliases := cfg.OAuthModelAlias["kiro"] + if len(kiroAliases) != 0 { + t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases)) + } + // The key itself must still be present to prevent re-injection on next reload + if _, exists := cfg.OAuthModelAlias["kiro"]; !exists { + t.Fatal("expected kiro key to be preserved as nil marker after sanitization") + } + // Other channels should be unaffected + if len(cfg.OAuthModelAlias["codex"]) != 1 { + t.Fatal("expected codex aliases to be preserved") + } +} + +func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) { + // Same as above but with empty slice instead of nil (PUT with empty body). 
+ cfg := &Config{ + OAuthModelAlias: map[string][]OAuthModelAlias{ + "kiro": {}, // explicitly set to empty + }, + } + + cfg.SanitizeOAuthModelAlias() + + if len(cfg.OAuthModelAlias["kiro"]) != 0 { + t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"])) + } + if _, exists := cfg.OAuthModelAlias["kiro"]; !exists { + t.Fatal("expected kiro key to be preserved") + } +} + func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) { // When OAuthModelAlias is nil, kiro defaults should still be injected cfg := &Config{} From f9a991365f59a7ce28e2202cc8372247435b8778 Mon Sep 17 00:00:00 2001 From: Dave Date: Sat, 14 Feb 2026 10:56:36 +0800 Subject: [PATCH 7/8] Update internal/runtime/executor/antigravity_executor.go Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- internal/runtime/executor/antigravity_executor.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index ee20c519..da82b8d0 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -1007,10 +1007,14 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo { exec := &AntigravityExecutor{cfg: cfg} token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth) - if errToken != nil || token == "" { + if errToken != nil { log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken) return nil } + if token == "" { + log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID) + return nil + } if updatedAuth != nil { auth = updatedAuth } From c4722e42b1518eadb6c51b1f6088f589899c6259 Mon Sep 17 00:00:00 2001 From: 
ultraplan-bit <248279703+ultraplan-bit@users.noreply.github.com> Date: Sat, 14 Feb 2026 21:58:15 +0800 Subject: [PATCH 8/8] fix(copilot): forward Claude-format tools to Copilot Responses API The normalizeGitHubCopilotResponsesTools filter required type="function", which dropped Claude-format tools (no type field, uses input_schema). Relax the filter to accept tools without a type field and map input_schema to parameters so tools are correctly sent to the upstream API. Co-Authored-By: Claude Opus 4.6 --- .gitignore | 2 +- .../executor/github_copilot_executor.go | 443 +++++++++++++++++- .../executor/github_copilot_executor_test.go | 188 ++++++++ 3 files changed, 626 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 2b9c215a..02493d24 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ cliproxy # Configuration config.yaml .env - +.mcp.json # Generated content bin/* logs/* diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go index 3681faf8..af83ad0c 100644 --- a/internal/runtime/executor/github_copilot_executor.go +++ b/internal/runtime/executor/github_copilot_executor.go @@ -110,7 +110,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. defer reporter.trackFailure(ctx, &err) from := opts.SourceFormat - useResponses := useGitHubCopilotResponsesEndpoint(from) + useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model) to := sdktranslator.FromString("openai") if useResponses { to = sdktranslator.FromString("openai-response") @@ -123,6 +123,12 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. 
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body = e.normalizeModel(req.Model, body) body = flattenAssistantContent(body) + if useResponses { + body = normalizeGitHubCopilotResponsesInput(body) + body = normalizeGitHubCopilotResponsesTools(body) + } else { + body = normalizeGitHubCopilotChatTools(body) + } requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel) body, _ = sjson.SetBytes(body, "stream", false) @@ -199,7 +205,12 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth. } var param any - converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + converted := "" + if useResponses && from.String() == "claude" { + converted = translateGitHubCopilotResponsesNonStreamToClaude(data) + } else { + converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + } resp = cliproxyexecutor.Response{Payload: []byte(converted)} reporter.ensurePublished(ctx) return resp, nil @@ -216,7 +227,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox defer reporter.trackFailure(ctx, &err) from := opts.SourceFormat - useResponses := useGitHubCopilotResponsesEndpoint(from) + useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model) to := sdktranslator.FromString("openai") if useResponses { to = sdktranslator.FromString("openai-response") @@ -229,6 +240,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body = e.normalizeModel(req.Model, body) body = flattenAssistantContent(body) + if useResponses { + body = normalizeGitHubCopilotResponsesInput(body) + body = normalizeGitHubCopilotResponsesTools(body) + } 
else { + body = normalizeGitHubCopilotChatTools(body) + } requestedModel := payloadRequestedModel(opts, req.Model) body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel) body, _ = sjson.SetBytes(body, "stream", true) @@ -329,7 +346,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox } } - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + var chunks []string + if useResponses && from.String() == "claude" { + chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m) + } else { + chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + } for i := range chunks { out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} } @@ -483,8 +505,12 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte return body } -func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool { - return sourceFormat.String() == "openai-response" +func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool { + if sourceFormat.String() == "openai-response" { + return true + } + baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName) + return strings.Contains(baseModel, "codex") } // flattenAssistantContent converts assistant message content from array format @@ -519,6 +545,411 @@ func flattenAssistantContent(body []byte) []byte { return result } +func normalizeGitHubCopilotChatTools(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() { + filtered := "[]" + if tools.IsArray() { + for _, tool := range tools.Array() { + if tool.Get("type").String() != "function" { + continue + } + filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw) + } + } + body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered)) + } + + 
toolChoice := gjson.GetBytes(body, "tool_choice") + if !toolChoice.Exists() { + return body + } + if toolChoice.Type == gjson.String { + switch toolChoice.String() { + case "auto", "none", "required": + return body + } + } + body, _ = sjson.SetBytes(body, "tool_choice", "auto") + return body +} + +func normalizeGitHubCopilotResponsesInput(body []byte) []byte { + input := gjson.GetBytes(body, "input") + if input.Exists() { + if input.Type == gjson.String { + return body + } + inputString := input.Raw + if input.Type != gjson.JSON { + inputString = input.String() + } + body, _ = sjson.SetBytes(body, "input", inputString) + return body + } + + var parts []string + if system := gjson.GetBytes(body, "system"); system.Exists() { + if text := strings.TrimSpace(collectTextFromNode(system)); text != "" { + parts = append(parts, text) + } + } + if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() { + for _, msg := range messages.Array() { + if text := strings.TrimSpace(collectTextFromNode(msg.Get("content"))); text != "" { + parts = append(parts, text) + } + } + } + body, _ = sjson.SetBytes(body, "input", strings.Join(parts, "\n")) + return body +} + +func normalizeGitHubCopilotResponsesTools(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() { + filtered := "[]" + if tools.IsArray() { + for _, tool := range tools.Array() { + toolType := tool.Get("type").String() + // Accept OpenAI format (type="function") and Claude format + // (no type field, but has top-level name + input_schema). 
+ if toolType != "" && toolType != "function" { + continue + } + name := tool.Get("name").String() + if name == "" { + name = tool.Get("function.name").String() + } + if name == "" { + continue + } + normalized := `{"type":"function","name":""}` + normalized, _ = sjson.Set(normalized, "name", name) + if desc := tool.Get("description").String(); desc != "" { + normalized, _ = sjson.Set(normalized, "description", desc) + } else if desc = tool.Get("function.description").String(); desc != "" { + normalized, _ = sjson.Set(normalized, "description", desc) + } + if params := tool.Get("parameters"); params.Exists() { + normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw) + } else if params = tool.Get("function.parameters"); params.Exists() { + normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw) + } else if params = tool.Get("input_schema"); params.Exists() { + normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw) + } + filtered, _ = sjson.SetRaw(filtered, "-1", normalized) + } + } + body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered)) + } + + toolChoice := gjson.GetBytes(body, "tool_choice") + if !toolChoice.Exists() { + return body + } + if toolChoice.Type == gjson.String { + switch toolChoice.String() { + case "auto", "none", "required": + return body + default: + body, _ = sjson.SetBytes(body, "tool_choice", "auto") + return body + } + } + if toolChoice.Type == gjson.JSON { + choiceType := toolChoice.Get("type").String() + if choiceType == "function" { + name := toolChoice.Get("name").String() + if name == "" { + name = toolChoice.Get("function.name").String() + } + if name != "" { + normalized := `{"type":"function","name":""}` + normalized, _ = sjson.Set(normalized, "name", name) + body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized)) + return body + } + } + } + body, _ = sjson.SetBytes(body, "tool_choice", "auto") + return body +} + +func collectTextFromNode(node gjson.Result) string { + if 
!node.Exists() { + return "" + } + if node.Type == gjson.String { + return node.String() + } + if node.IsArray() { + var parts []string + for _, item := range node.Array() { + if item.Type == gjson.String { + if text := item.String(); text != "" { + parts = append(parts, text) + } + continue + } + if text := item.Get("text").String(); text != "" { + parts = append(parts, text) + continue + } + if nested := collectTextFromNode(item.Get("content")); nested != "" { + parts = append(parts, nested) + } + } + return strings.Join(parts, "\n") + } + if node.Type == gjson.JSON { + if text := node.Get("text").String(); text != "" { + return text + } + if nested := collectTextFromNode(node.Get("content")); nested != "" { + return nested + } + return node.Raw + } + return node.String() +} + +type githubCopilotResponsesStreamToolState struct { + Index int + ID string + Name string +} + +type githubCopilotResponsesStreamState struct { + MessageStarted bool + MessageStopSent bool + TextBlockStarted bool + TextBlockIndex int + NextContentIndex int + HasToolUse bool + OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState + ItemIDToTool map[string]*githubCopilotResponsesStreamToolState +} + +func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string { + root := gjson.ParseBytes(data) + out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` + out, _ = sjson.Set(out, "id", root.Get("id").String()) + out, _ = sjson.Set(out, "model", root.Get("model").String()) + + hasToolUse := false + if output := root.Get("output"); output.Exists() && output.IsArray() { + for _, item := range output.Array() { + switch item.Get("type").String() { + case "message": + if content := item.Get("content"); content.Exists() && content.IsArray() { + for _, part := range content.Array() { + if part.Get("type").String() != "output_text" { + continue + } + text := 
part.Get("text").String() + if text == "" { + continue + } + block := `{"type":"text","text":""}` + block, _ = sjson.Set(block, "text", text) + out, _ = sjson.SetRaw(out, "content.-1", block) + } + } + case "function_call": + hasToolUse = true + toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` + toolID := item.Get("call_id").String() + if toolID == "" { + toolID = item.Get("id").String() + } + toolUse, _ = sjson.Set(toolUse, "id", toolID) + toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String()) + if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) { + argObj := gjson.Parse(args) + if argObj.IsObject() { + toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw) + } + } + out, _ = sjson.SetRaw(out, "content.-1", toolUse) + } + } + } + + inputTokens := root.Get("usage.input_tokens").Int() + outputTokens := root.Get("usage.output_tokens").Int() + out, _ = sjson.Set(out, "usage.input_tokens", inputTokens) + out, _ = sjson.Set(out, "usage.output_tokens", outputTokens) + if hasToolUse { + out, _ = sjson.Set(out, "stop_reason", "tool_use") + } else { + out, _ = sjson.Set(out, "stop_reason", "end_turn") + } + return out +} + +func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string { + if *param == nil { + *param = &githubCopilotResponsesStreamState{ + TextBlockIndex: -1, + OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState), + ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState), + } + } + state := (*param).(*githubCopilotResponsesStreamState) + + if !bytes.HasPrefix(line, dataTag) { + return nil + } + payload := bytes.TrimSpace(line[5:]) + if bytes.Equal(payload, []byte("[DONE]")) { + return nil + } + if !gjson.ValidBytes(payload) { + return nil + } + + event := gjson.GetBytes(payload, "type").String() + results := make([]string, 0, 4) + ensureMessageStart := func() { + if state.MessageStarted { + return + } + messageStart := 
`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}` + messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String()) + messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String()) + results = append(results, "event: message_start\ndata: "+messageStart+"\n\n") + state.MessageStarted = true + } + startTextBlockIfNeeded := func() { + if state.TextBlockStarted { + return + } + if state.TextBlockIndex < 0 { + state.TextBlockIndex = state.NextContentIndex + state.NextContentIndex++ + } + contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}` + contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex) + results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n") + state.TextBlockStarted = true + } + stopTextBlockIfNeeded := func() { + if !state.TextBlockStarted { + return + } + contentBlockStop := `{"type":"content_block_stop","index":0}` + contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex) + results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n") + state.TextBlockStarted = false + state.TextBlockIndex = -1 + } + resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState { + if itemID != "" { + if tool, ok := state.ItemIDToTool[itemID]; ok { + return tool + } + } + if tool, ok := state.OutputIndexToTool[outputIndex]; ok { + if itemID != "" { + state.ItemIDToTool[itemID] = tool + } + return tool + } + return nil + } + + switch event { + case "response.created": + ensureMessageStart() + case "response.output_text.delta": + ensureMessageStart() + startTextBlockIfNeeded() + delta := gjson.GetBytes(payload, "delta").String() + if delta != "" { + 
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}` + contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex) + contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta) + results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n") + } + case "response.output_item.added": + if gjson.GetBytes(payload, "item.type").String() != "function_call" { + break + } + ensureMessageStart() + stopTextBlockIfNeeded() + state.HasToolUse = true + tool := &githubCopilotResponsesStreamToolState{ + Index: state.NextContentIndex, + ID: gjson.GetBytes(payload, "item.call_id").String(), + Name: gjson.GetBytes(payload, "item.name").String(), + } + if tool.ID == "" { + tool.ID = gjson.GetBytes(payload, "item.id").String() + } + state.NextContentIndex++ + outputIndex := int(gjson.GetBytes(payload, "output_index").Int()) + state.OutputIndexToTool[outputIndex] = tool + if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" { + state.ItemIDToTool[itemID] = tool + } + contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}` + contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index) + contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID) + contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name) + results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n") + case "response.output_item.delta": + item := gjson.GetBytes(payload, "item") + if item.Get("type").String() != "function_call" { + break + } + tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int())) + if tool == nil { + break + } + partial := gjson.GetBytes(payload, "delta").String() + if partial == "" { + partial = item.Get("arguments").String() + } + if partial == "" { + break + } + inputDelta := 
`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}` + inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index) + inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial) + results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n") + case "response.output_item.done": + if gjson.GetBytes(payload, "item.type").String() != "function_call" { + break + } + tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int())) + if tool == nil { + break + } + contentBlockStop := `{"type":"content_block_stop","index":0}` + contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index) + results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n") + case "response.completed": + ensureMessageStart() + stopTextBlockIfNeeded() + if !state.MessageStopSent { + stopReason := "end_turn" + if state.HasToolUse { + stopReason = "tool_use" + } + messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` + messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason) + messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", gjson.GetBytes(payload, "response.usage.input_tokens").Int()) + messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", gjson.GetBytes(payload, "response.usage.output_tokens").Int()) + results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n") + results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n") + state.MessageStopSent = true + } + } + + return results +} + // isHTTPSuccess checks if the status code indicates success (2xx). 
func isHTTPSuccess(statusCode int) bool { return statusCode >= 200 && statusCode < 300 diff --git a/internal/runtime/executor/github_copilot_executor_test.go b/internal/runtime/executor/github_copilot_executor_test.go index ef077fd6..2895c8a7 100644 --- a/internal/runtime/executor/github_copilot_executor_test.go +++ b/internal/runtime/executor/github_copilot_executor_test.go @@ -1,8 +1,10 @@ package executor import ( + "strings" "testing" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" "github.com/tidwall/gjson" ) @@ -52,3 +54,189 @@ func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) { }) } } + +func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) { + t.Parallel() + if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") { + t.Fatal("expected openai-response source to use /responses") + } +} + +func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) { + t.Parallel() + if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") { + t.Fatal("expected codex model to use /responses") + } +} + +func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) { + t.Parallel() + if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") { + t.Fatal("expected default openai source with non-codex model to use /chat/completions") + } +} + +func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`) + got := normalizeGitHubCopilotChatTools(body) + tools := gjson.GetBytes(got, "tools").Array() + if len(tools) != 1 { + t.Fatalf("tools len = %d, want 1", len(tools)) + } + if tools[0].Get("type").String() != "function" { + t.Fatalf("tool type = %q, want function", tools[0].Get("type").String()) + } +} + +func 
TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`) + got := normalizeGitHubCopilotChatTools(body) + if gjson.GetBytes(got, "tool_choice").String() != "auto" { + t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw) + } +} + +func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) { + t.Parallel() + body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`) + got := normalizeGitHubCopilotResponsesInput(body) + in := gjson.GetBytes(got, "input") + if in.Type != gjson.String { + t.Fatalf("input type = %v, want string", in.Type) + } + if !strings.Contains(in.String(), "sys text") || !strings.Contains(in.String(), "user text") || !strings.Contains(in.String(), "assistant text") { + t.Fatalf("input = %q, want merged text", in.String()) + } +} + +func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) { + t.Parallel() + body := []byte(`{"input":{"foo":"bar"}}`) + got := normalizeGitHubCopilotResponsesInput(body) + in := gjson.GetBytes(got, "input") + if in.Type != gjson.String { + t.Fatalf("input type = %v, want string", in.Type) + } + if !strings.Contains(in.String(), "foo") { + t.Fatalf("input = %q, want stringified object", in.String()) + } +} + +func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`) + got := normalizeGitHubCopilotResponsesTools(body) + tools := gjson.GetBytes(got, "tools").Array() + if len(tools) != 1 { + t.Fatalf("tools len = %d, want 1", len(tools)) + } + if tools[0].Get("name").String() != "sum" { + 
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String()) + } + if !tools[0].Get("parameters").Exists() { + t.Fatal("expected parameters to be preserved") + } +} + +func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) { + t.Parallel() + body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`) + got := normalizeGitHubCopilotResponsesTools(body) + tools := gjson.GetBytes(got, "tools").Array() + if len(tools) != 2 { + t.Fatalf("tools len = %d, want 2", len(tools)) + } + if tools[0].Get("type").String() != "function" { + t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String()) + } + if tools[0].Get("name").String() != "Bash" { + t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String()) + } + if tools[0].Get("description").String() != "Run commands" { + t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String()) + } + if !tools[0].Get("parameters").Exists() { + t.Fatal("expected parameters to be set from input_schema") + } + if tools[0].Get("parameters.properties.command").Exists() != true { + t.Fatal("expected parameters.properties.command to exist") + } + if tools[1].Get("name").String() != "Read" { + t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String()) + } +} + +func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) { + t.Parallel() + body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`) + got := normalizeGitHubCopilotResponsesTools(body) + if gjson.GetBytes(got, "tool_choice.type").String() != "function" { + t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String()) + } + if gjson.GetBytes(got, 
"tool_choice.name").String() != "sum" { + t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String()) + } +} + +func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) { + t.Parallel() + body := []byte(`{"tool_choice":{"type":"function"}}`) + got := normalizeGitHubCopilotResponsesTools(body) + if gjson.GetBytes(got, "tool_choice").String() != "auto" { + t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw) + } +} + +func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) { + t.Parallel() + resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`) + out := translateGitHubCopilotResponsesNonStreamToClaude(resp) + if gjson.Get(out, "type").String() != "message" { + t.Fatalf("type = %q, want message", gjson.Get(out, "type").String()) + } + if gjson.Get(out, "content.0.type").String() != "text" { + t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String()) + } + if gjson.Get(out, "content.0.text").String() != "hello" { + t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String()) + } +} + +func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) { + t.Parallel() + resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`) + out := translateGitHubCopilotResponsesNonStreamToClaude(resp) + if gjson.Get(out, "content.0.type").String() != "tool_use" { + t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String()) + } + if gjson.Get(out, "content.0.name").String() != "sum" { + t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String()) + } + if 
gjson.Get(out, "stop_reason").String() != "tool_use" { + t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String()) + } +} + +func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) { + t.Parallel() + var param any + + created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m) + if len(created) == 0 || !strings.Contains(created[0], "message_start") { + t.Fatalf("created events = %#v, want message_start", created) + } + + delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m) + joinedDelta := strings.Join(delta, "") + if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") { + t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta) + } + + completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m) + joinedCompleted := strings.Join(completed, "") + if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") { + t.Fatalf("completed events = %#v, want message_delta + message_stop", completed) + } +}