mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 15:25:17 +00:00
refactor: align web search with executor layer patterns
Consolidate web search handler, SSE event generation, stream analysis, and MCP HTTP I/O into the executor layer. Merge the separate kiro_websearch_handler.go back into kiro_executor.go to align with the single-file-per-executor convention. The translator retains only pure data types, detection, and payload transformation.

Key changes:
- Move SSE construction (search indicators, fallback text, message_start) from translator to executor, consistent with the streamToChannel pattern
- Move MCP handler (callMcpAPI, setMcpHeaders, fetchToolDescription) from translator to executor alongside other HTTP I/O
- Reuse applyDynamicFingerprint for MCP UA headers (eliminate duplication)
- Centralize MCP endpoint URL via BuildMcpEndpoint in translator
- Add atomic Set/GetWebSearchDescription for cross-layer tool description cache
- Thread context.Context through MCP HTTP calls for cancellation support
- Thread usage reporter through all web search API call paths
- Add token expiry pre-check before MCP/GAR calls
- Clean up dead code (GenerateMessageID, webSearchAuthContext fp logic, ContainsWebSearchTool, StripWebSearchTool)
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -385,6 +386,35 @@ func buildKiroEndpointConfigs(region string) []kiroEndpointConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// resolveKiroAPIRegion determines the AWS region for Kiro API calls.
|
||||
// Region priority:
|
||||
// 1. auth.Metadata["api_region"] - explicit API region override
|
||||
// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource
|
||||
// 3. kiroDefaultRegion (us-east-1) - fallback
|
||||
// Note: OIDC "region" is NOT used - it's for token refresh, not API calls
|
||||
func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
log.Debugf("kiro: using region %s (source: api_region)", r)
|
||||
return r
|
||||
}
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion)
|
||||
return arnRegion
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion)
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
|
||||
// kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region.
|
||||
// Prefer using buildKiroEndpointConfigs(region) for dynamic region support.
|
||||
var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion)
|
||||
@@ -403,30 +433,8 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
||||
return kiroEndpointConfigs
|
||||
}
|
||||
|
||||
// Determine API region with priority: api_region > profile_arn > region > default
|
||||
region := kiroDefaultRegion
|
||||
regionSource := "default"
|
||||
|
||||
if auth.Metadata != nil {
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
regionSource = "api_region"
|
||||
} else {
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
region = arnRegion
|
||||
regionSource = "profile_arn"
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro: using region %s (source: %s)", region, regionSource)
|
||||
// Determine API region using shared resolution logic
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
|
||||
// Build endpoint configs for the specified region
|
||||
endpointConfigs := buildKiroEndpointConfigs(region)
|
||||
@@ -520,7 +528,7 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string,
|
||||
log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
case "kiro":
|
||||
// Body is already in Kiro format — pass through directly (used by callKiroRawAndBuffer)
|
||||
// Body is already in Kiro format — pass through directly
|
||||
log.Debugf("kiro: body already in Kiro format, passing through directly")
|
||||
return body, false
|
||||
default:
|
||||
@@ -640,17 +648,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
|
||||
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery")
|
||||
|
||||
@@ -679,6 +677,16 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
|
||||
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -1068,17 +1076,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery before stream request")
|
||||
|
||||
@@ -1107,6 +1105,16 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -4114,6 +4122,238 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
||||
return isExpired
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// Web Search Handler (MCP API)
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// State backing fetchToolDescription's one-time fetch of the web_search tool description.
//
//   - toolDescMu serializes fetch attempts so only one goroutine talks to the
//     MCP endpoint at a time.
//   - toolDescFetched flips to true only after a successful fetch. It is checked
//     before taking the lock, so callers after a success return without locking;
//     while it is still false (including after a failed attempt) the next caller
//     tries again.
//
// The description itself is not stored here: it is handed to the translator
// package through kiroclaude.SetWebSearchDescription(), which (per the original
// notes) lets convertClaudeToolsToKiro read it when building Kiro requests.
var (
	toolDescMu      sync.Mutex
	toolDescFetched atomic.Bool
)
|
||||
|
||||
// fetchToolDescription calls MCP tools/list to get the web_search tool description
|
||||
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
|
||||
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
|
||||
// The httpClient parameter allows reusing a shared pooled HTTP client.
|
||||
func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) {
|
||||
// Fast path: already fetched successfully, no lock needed
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
toolDescMu.Lock()
|
||||
defer toolDescMu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs)
|
||||
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
|
||||
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse same headers as callMcpAPI
|
||||
handler.setMcpHeaders(req)
|
||||
|
||||
resp, err := handler.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
|
||||
|
||||
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
|
||||
var result struct {
|
||||
Result *struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
log.Warnf("kiro/websearch: failed to parse tools/list response")
|
||||
return
|
||||
}
|
||||
|
||||
for _, tool := range result.Result.Tools {
|
||||
if tool.Name == "web_search" && tool.Description != "" {
|
||||
kiroclaude.SetWebSearchDescription(tool.Description)
|
||||
toolDescFetched.Store(true) // success — no more fetches
|
||||
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// web_search tool not found in response
|
||||
log.Warnf("kiro/websearch: web_search tool not found in tools/list response")
|
||||
}
|
||||
|
||||
// webSearchHandler handles web search requests via Kiro MCP API
|
||||
type webSearchHandler struct {
|
||||
ctx context.Context
|
||||
mcpEndpoint string
|
||||
httpClient *http.Client
|
||||
authToken string
|
||||
auth *cliproxyauth.Auth // for applyDynamicFingerprint
|
||||
authAttrs map[string]string // optional, for custom headers from auth.Attributes
|
||||
}
|
||||
|
||||
// newWebSearchHandler creates a new webSearchHandler.
|
||||
// If httpClient is nil, a default client with 30s timeout is used.
|
||||
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
|
||||
func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
return &webSearchHandler{
|
||||
ctx: ctx,
|
||||
mcpEndpoint: mcpEndpoint,
|
||||
httpClient: httpClient,
|
||||
authToken: authToken,
|
||||
auth: auth,
|
||||
authAttrs: authAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
// setMcpHeaders sets standard MCP API headers on the request,
|
||||
// aligned with the GAR request pattern.
|
||||
func (h *webSearchHandler) setMcpHeaders(req *http.Request) {
|
||||
// 1. Content-Type & Accept (aligned with GAR)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
|
||||
// 2. Kiro-specific headers (aligned with GAR)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
|
||||
// 3. User-Agent: Reuse applyDynamicFingerprint for consistency
|
||||
applyDynamicFingerprint(req, h.auth)
|
||||
|
||||
// 4. AWS SDK identifiers
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// 5. Authentication
|
||||
req.Header.Set("Authorization", "Bearer "+h.authToken)
|
||||
|
||||
// 6. Custom headers from auth attributes
|
||||
util.ApplyCustomHeadersFromAttrs(req, h.authAttrs)
|
||||
}
|
||||
|
||||
// mcpMaxRetries is how many times callMcpAPI retries after the initial attempt,
// i.e. an MCP call is tried at most mcpMaxRetries+1 times.
const mcpMaxRetries = 2
|
||||
|
||||
// callMcpAPI calls the Kiro MCP API with the given request.
|
||||
// Includes retry logic with exponential backoff for retryable errors.
|
||||
func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) {
|
||||
requestBody, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody))
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
if backoff > 10*time.Second {
|
||||
backoff = 10 * time.Second
|
||||
}
|
||||
log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return nil, h.ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
h.setMcpHeaders(req)
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("MCP API request failed: %w", err)
|
||||
continue // network error → retry
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read MCP response: %w", err)
|
||||
continue // read error → retry
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
|
||||
|
||||
// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
|
||||
if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
|
||||
lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var mcpResponse kiroclaude.McpResponse
|
||||
if err := json.Unmarshal(body, &mcpResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse MCP response: %w", err)
|
||||
}
|
||||
|
||||
if mcpResponse.Error != nil {
|
||||
code := -1
|
||||
if mcpResponse.Error.Code != nil {
|
||||
code = *mcpResponse.Error.Code
|
||||
}
|
||||
msg := "Unknown error"
|
||||
if mcpResponse.Error.Message != nil {
|
||||
msg = *mcpResponse.Error.Message
|
||||
}
|
||||
return nil, fmt.Errorf("MCP error %d: %s", code, msg)
|
||||
}
|
||||
|
||||
return &mcpResponse, nil
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// webSearchAuthAttrs extracts auth attributes for MCP calls.
|
||||
// Used by handleWebSearch and handleWebSearchStream to pass custom headers.
|
||||
func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string {
|
||||
if auth != nil {
|
||||
return auth.Attributes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// maxWebSearchIterations presumably bounds the MCP search / Kiro call loop in
// handleWebSearchStream; its use is not visible in this section, so confirm
// against the loop before relying on that.
const maxWebSearchIterations = 5
|
||||
|
||||
// handleWebSearchStream handles web_search requests:
|
||||
@@ -4136,58 +4376,63 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint based on region
|
||||
region := kiroDefaultRegion
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
}
|
||||
}
|
||||
mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// ── Step 1: tools/list (SYNC) — cache tool description ──
|
||||
{
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Create output channel
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
|
||||
// Usage reporting: track web search requests like normal streaming requests
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
|
||||
go func() {
|
||||
var wsErr error
|
||||
defer reporter.trackFailure(ctx, &wsErr)
|
||||
defer close(out)
|
||||
|
||||
// Send message_start event to client
|
||||
messageStartEvent := kiroclaude.SseEvent{
|
||||
Event: "message_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": kiroclaude.GenerateMessageID(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": req.Model,
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": len(req.Payload) / 4,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
// Estimate input tokens using tokenizer (matching streamToChannel pattern)
|
||||
var totalUsage usage.Detail
|
||||
if enc, tokErr := getTokenizer(req.Model); tokErr == nil {
|
||||
if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 {
|
||||
totalUsage.InputTokens = inp
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
if totalUsage.InputTokens == 0 && len(req.Payload) > 0 {
|
||||
totalUsage.InputTokens = 1
|
||||
}
|
||||
var accumulatedOutputLen int
|
||||
defer func() {
|
||||
if wsErr != nil {
|
||||
return // let trackFailure handle failure reporting
|
||||
}
|
||||
totalUsage.OutputTokens = int64(accumulatedOutputLen / 4)
|
||||
if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 {
|
||||
totalUsage.OutputTokens = 1
|
||||
}
|
||||
reporter.publish(ctx, totalUsage)
|
||||
}()
|
||||
|
||||
// Send message_start event to client (aligned with streamToChannel pattern)
|
||||
// Use payloadRequestedModel to return user's original model alias
|
||||
msgStart := kiroclaude.BuildClaudeMessageStartEvent(
|
||||
payloadRequestedModel(opts, req.Model),
|
||||
totalUsage.InputTokens,
|
||||
)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(messageStartEvent.ToSSEString())}:
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}:
|
||||
}
|
||||
|
||||
// ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ──
|
||||
@@ -4216,14 +4461,10 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
|
||||
// MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery)
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
@@ -4255,8 +4496,9 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to inject tool results: %v", err)
|
||||
wsErr = fmt.Errorf("failed to inject tool results: %w", err)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
// Call GAR with modified Claude payload (full translation pipeline)
|
||||
@@ -4265,8 +4507,9 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if kiroErr != nil {
|
||||
log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr)
|
||||
wsErr = fmt.Errorf("Kiro API failed at iteration %d: %w", iteration+1, kiroErr)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
// Analyze response
|
||||
@@ -4297,12 +4540,14 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
if !shouldForward {
|
||||
continue
|
||||
}
|
||||
accumulatedOutputLen += len(adjusted)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}:
|
||||
}
|
||||
} else {
|
||||
accumulatedOutputLen += len(chunk)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -4320,8 +4565,103 @@ func (e *KiroExecutor) handleWebSearchStream(
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// handleWebSearch handles web_search requests for non-streaming Execute path.
|
||||
// Performs MCP search synchronously, injects results into the request payload,
|
||||
// then calls the normal non-streaming Kiro API path which returns a proper
|
||||
// Claude JSON response (not SSE chunks).
|
||||
func (e *KiroExecutor) handleWebSearch(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
|
||||
// Fall through to normal non-streaming path
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// Step 1: Fetch/cache tool description (sync)
|
||||
{
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Step 2: Perform MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query)
|
||||
|
||||
// Step 3: Replace restrictive web_search tool description (align with streaming path)
|
||||
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
|
||||
if simplifyErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr)
|
||||
simplifiedPayload = bytes.Clone(req.Payload)
|
||||
}
|
||||
|
||||
// Step 4: Inject search tool_use + tool_result into Claude payload
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry)
|
||||
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
|
||||
// to produce a proper Claude JSON response
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = modifiedPayload
|
||||
|
||||
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Step 6: Inject server_tool_use + web_search_tool_result into response
|
||||
// so Claude Code can display "Did X searches in Ys"
|
||||
indicators := []kiroclaude.SearchIndicator{
|
||||
{
|
||||
ToolUseID: currentToolUseId,
|
||||
Query: query,
|
||||
Results: searchResults,
|
||||
},
|
||||
}
|
||||
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
|
||||
if injErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
|
||||
} else {
|
||||
resp.Payload = injectedPayload
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// callKiroAndBuffer calls the Kiro API and buffers all response chunks.
|
||||
// Returns the buffered chunks for analysis before forwarding to client.
|
||||
// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter.
|
||||
func (e *KiroExecutor) callKiroAndBuffer(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
@@ -4338,10 +4678,7 @@ func (e *KiroExecutor) callKiroAndBuffer(
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := ""
|
||||
if auth != nil {
|
||||
tokenKey = auth.ID
|
||||
}
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
kiroStream, err := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
@@ -4367,51 +4704,6 @@ func (e *KiroExecutor) callKiroAndBuffer(
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// callKiroRawAndBuffer calls the Kiro API with a pre-built Kiro payload (no translation).
|
||||
// Used in the web search loop where the payload is modified directly in Kiro format.
|
||||
func (e *KiroExecutor) callKiroRawAndBuffer(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
kiroBody []byte,
|
||||
) ([][]byte, error) {
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := ""
|
||||
if auth != nil {
|
||||
tokenKey = auth.ID
|
||||
}
|
||||
log.Debugf("kiro/websearch GAR raw request: %d bytes", len(kiroBody))
|
||||
|
||||
kiroFormat := sdktranslator.FromString("kiro")
|
||||
kiroStream, err := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, kiroBody, kiroFormat, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Buffer all chunks
|
||||
var chunks [][]byte
|
||||
for chunk := range kiroStream {
|
||||
if chunk.Err != nil {
|
||||
return chunks, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
chunks = append(chunks, bytes.Clone(chunk.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro/websearch GAR raw response: %d chunks buffered", len(chunks))
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// callKiroDirectStream creates a direct streaming channel to Kiro API without search.
|
||||
func (e *KiroExecutor) callKiroDirectStream(
|
||||
ctx context.Context,
|
||||
@@ -4428,18 +4720,22 @@ func (e *KiroExecutor) callKiroDirectStream(
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := ""
|
||||
if auth != nil {
|
||||
tokenKey = auth.ID
|
||||
}
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
return e.executeStreamWithRetry(
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
var streamErr error
|
||||
defer reporter.trackFailure(ctx, &streamErr)
|
||||
|
||||
stream, streamErr := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
return stream, streamErr
|
||||
}
|
||||
|
||||
// sendFallbackText sends a simple text response when the Kiro API fails during the search loop.
|
||||
// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment
|
||||
// with how streamToChannel() uses BuildClaude*Event() functions.
|
||||
func (e *KiroExecutor) sendFallbackText(
|
||||
ctx context.Context,
|
||||
out chan<- cliproxyexecutor.StreamChunk,
|
||||
@@ -4447,182 +4743,14 @@ func (e *KiroExecutor) sendFallbackText(
|
||||
query string,
|
||||
searchResults *kiroclaude.WebSearchResults,
|
||||
) {
|
||||
// Generate a simple text summary from search results
|
||||
summary := kiroclaude.FormatSearchContextPrompt(query, searchResults)
|
||||
|
||||
events := []kiroclaude.SseEvent{
|
||||
{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": contentBlockIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": contentBlockIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "text_delta",
|
||||
"text": summary,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": contentBlockIndex,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults)
|
||||
for _, event := range events {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}:
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}:
|
||||
}
|
||||
}
|
||||
|
||||
// Send message_delta with end_turn and message_stop
|
||||
msgDelta := kiroclaude.SseEvent{
|
||||
Event: "message_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"output_tokens": len(summary) / 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgDelta.ToSSEString())}:
|
||||
}
|
||||
|
||||
msgStop := kiroclaude.SseEvent{
|
||||
Event: "message_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_stop",
|
||||
},
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgStop.ToSSEString())}:
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// handleWebSearch handles web_search requests for non-streaming Execute path.
|
||||
// Performs MCP search synchronously, injects results into the request payload,
|
||||
// then calls the normal non-streaming Kiro API path which returns a proper
|
||||
// Claude JSON response (not SSE chunks).
|
||||
func (e *KiroExecutor) handleWebSearch(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
|
||||
// Fall through to normal non-streaming path
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint based on region
|
||||
region := kiroDefaultRegion
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
}
|
||||
}
|
||||
mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
|
||||
// Step 1: Fetch/cache tool description (sync)
|
||||
{
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
}
|
||||
|
||||
// Step 2: Perform MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query)
|
||||
|
||||
// Step 3: Inject search tool_use + tool_result into Claude payload
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(bytes.Clone(req.Payload), currentToolUseId, query, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Step 4: Call Kiro API via the normal non-streaming path (executeWithRetry)
|
||||
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
|
||||
// to produce a proper Claude JSON response
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = modifiedPayload
|
||||
|
||||
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Step 5: Inject server_tool_use + web_search_tool_result into response
|
||||
// so Claude Code can display "Did X searches in Ys"
|
||||
indicators := []kiroclaude.SearchIndicator{
|
||||
{
|
||||
ToolUseID: currentToolUseId,
|
||||
Query: query,
|
||||
Results: searchResults,
|
||||
},
|
||||
}
|
||||
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
|
||||
if injErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
|
||||
} else {
|
||||
resp.Payload = injectedPayload
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// executeNonStreamFallback runs the standard non-streaming Execute path for a request.
|
||||
|
||||
@@ -183,4 +183,129 @@ func PendingTagSuffix(buffer, tag string) int {
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
|
||||
// (server_tool_use + web_search_tool_result) without text summary or message termination.
|
||||
// These events trigger Claude Code's search indicator UI.
|
||||
// The caller is responsible for sending message_start before and message_delta/stop after.
|
||||
func GenerateSearchIndicatorEvents(
|
||||
query string,
|
||||
toolUseID string,
|
||||
searchResults *WebSearchResults,
|
||||
startIndex int,
|
||||
) []sseEvent {
|
||||
events := make([]sseEvent, 0, 4)
|
||||
|
||||
// 1. content_block_start (server_tool_use)
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"id": toolUseID,
|
||||
"type": "server_tool_use",
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 2. content_block_delta (input_json_delta)
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": startIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 3. content_block_stop (server_tool_use)
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex,
|
||||
},
|
||||
})
|
||||
|
||||
// 4. content_block_start (web_search_tool_result)
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if searchResults != nil {
|
||||
for _, r := range searchResults.Results {
|
||||
snippet := ""
|
||||
if r.Snippet != nil {
|
||||
snippet = *r.Snippet
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": r.Title,
|
||||
"url": r.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex + 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": searchContent,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 5. content_block_stop (web_search_tool_result)
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex + 1,
|
||||
},
|
||||
})
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// BuildFallbackTextEvents generates SSE events for a fallback text response
|
||||
// when the Kiro API fails during the search loop. Uses BuildClaude*Event()
|
||||
// functions to align with streamToChannel patterns.
|
||||
// Returns raw SSE byte slices ready to be sent to the client channel.
|
||||
func BuildFallbackTextEvents(contentBlockIndex int, query string, results *WebSearchResults) [][]byte {
|
||||
summary := FormatSearchContextPrompt(query, results)
|
||||
outputTokens := len(summary) / 4
|
||||
if len(summary) > 0 && outputTokens == 0 {
|
||||
outputTokens = 1
|
||||
}
|
||||
|
||||
var events [][]byte
|
||||
|
||||
// content_block_start (text)
|
||||
events = append(events, BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", ""))
|
||||
|
||||
// content_block_delta (text_delta)
|
||||
events = append(events, BuildClaudeStreamEvent(summary, contentBlockIndex))
|
||||
|
||||
// content_block_stop
|
||||
events = append(events, BuildClaudeContentBlockStopEvent(contentBlockIndex))
|
||||
|
||||
// message_delta with end_turn
|
||||
events = append(events, BuildClaudeMessageDeltaEvent("end_turn", usage.Detail{
|
||||
OutputTokens: int64(outputTokens),
|
||||
}))
|
||||
|
||||
// message_stop
|
||||
events = append(events, BuildClaudeMessageStopOnlyEvent())
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
350
internal/translator/kiro/claude/kiro_claude_stream_parser.go
Normal file
350
internal/translator/kiro/claude/kiro_claude_stream_parser.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// sseEvent is a single Server-Sent Event: an event name plus a payload that is
// JSON-encoded when the event is serialized.
type sseEvent struct {
	Event string
	Data  interface{}
}

// ToSSEString renders the event in SSE wire format:
//
//	event: <Event>\ndata: <JSON of Data>\n\n
//
// The marshal error is ignored; if Data cannot be encoded the data field is left empty.
func (e *sseEvent) ToSSEString() string {
	payload, _ := json.Marshal(e.Data)

	frame := make([]byte, 0, len("event: ")+len(e.Event)+len("\ndata: ")+len(payload)+2)
	frame = append(frame, "event: "...)
	frame = append(frame, e.Event...)
	frame = append(frame, "\ndata: "...)
	frame = append(frame, payload...)
	frame = append(frame, '\n', '\n')
	return string(frame)
}
|
||||
|
||||
// AdjustStreamIndices shifts the content block index of one SSE data payload by offset
// and reports whether the event should be forwarded.
//
// data is the JSON payload of a single SSE "data:" line. The result is:
//   - message_start events: the input and false (the caller must drop the event);
//   - content_block_start/delta/stop events carrying a numeric "index": a re-encoded
//     payload whose index is increased by offset, and true;
//   - everything else (empty input, invalid JSON, other event types, content block
//     events without a numeric index, re-encoding failure): the input unchanged, and true.
//
// It is used to place model response blocks after the search indicator blocks.
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
	if len(data) == 0 {
		return data, true
	}

	var payload map[string]interface{}
	if json.Unmarshal(data, &payload) != nil {
		// Not a JSON object: forward untouched.
		return data, true
	}

	kind, _ := payload["type"].(string)
	switch kind {
	case "message_start":
		// Duplicate message_start events are suppressed.
		return data, false
	case "content_block_start", "content_block_delta", "content_block_stop":
		// Index rewrite handled below.
	default:
		// message_delta, message_stop, ping, etc. pass through unchanged.
		return data, true
	}

	// JSON numbers decode as float64; anything else means there is no index to shift.
	current, isNumber := payload["index"].(float64)
	if !isNumber {
		return data, true
	}

	payload["index"] = int(current) + offset
	if shifted, err := json.Marshal(payload); err == nil {
		return shifted, true
	}
	return data, true
}
|
||||
|
||||
// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs)
|
||||
// and adjusts content block indices. Suppresses duplicate message_start events.
|
||||
// Returns the adjusted chunk and whether it should be forwarded.
|
||||
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
|
||||
chunkStr := string(chunk)
|
||||
|
||||
// Fast path: if no "data:" prefix, pass through
|
||||
if !strings.Contains(chunkStr, "data: ") {
|
||||
return chunk, true
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
hasContent := false
|
||||
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
|
||||
if dataPayload == "[DONE]" {
|
||||
result.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset)
|
||||
if !shouldForward {
|
||||
// Skip this event and its preceding "event:" line
|
||||
// Also skip the trailing empty line
|
||||
continue
|
||||
}
|
||||
|
||||
result.WriteString("data: " + string(adjusted) + "\n")
|
||||
hasContent = true
|
||||
} else if strings.HasPrefix(line, "event: ") {
|
||||
// Check if the next data line will be suppressed
|
||||
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
|
||||
dataPayload := strings.TrimPrefix(lines[i+1], "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err == nil {
|
||||
if eventType, ok := event["type"].(string); ok && eventType == "message_start" {
|
||||
// Skip both the event: and data: lines
|
||||
i++ // skip the data: line too
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
result.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
} else {
|
||||
result.WriteString(line + "\n")
|
||||
if strings.TrimSpace(line) != "" {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasContent {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return []byte(result.String()), true
|
||||
}
|
||||
|
||||
// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response.
|
||||
type BufferedStreamResult struct {
|
||||
// StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use")
|
||||
StopReason string
|
||||
// WebSearchQuery is the extracted query if the model requested another web_search
|
||||
WebSearchQuery string
|
||||
// WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults)
|
||||
WebSearchToolUseId string
|
||||
// HasWebSearchToolUse indicates whether the model requested web_search
|
||||
HasWebSearchToolUse bool
|
||||
// WebSearchToolUseIndex is the content_block index of the web_search tool_use
|
||||
WebSearchToolUseIndex int
|
||||
}
|
||||
|
||||
// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
|
||||
// This is used in the search loop to determine if the model wants another search round.
|
||||
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
|
||||
|
||||
// Track tool use state across chunks
|
||||
var currentToolName string
|
||||
var currentToolIndex int = -1
|
||||
var toolInputBuilder strings.Builder
|
||||
|
||||
for _, chunk := range chunks {
|
||||
chunkStr := string(chunk)
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
if dataPayload == "[DONE]" || dataPayload == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "message_delta":
|
||||
// Extract stop_reason from message_delta
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
|
||||
result.StopReason = sr
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
// Detect tool_use content blocks
|
||||
if cb, ok := event["content_block"].(map[string]interface{}); ok {
|
||||
if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
|
||||
if name, ok := cb["name"].(string); ok {
|
||||
currentToolName = strings.ToLower(name)
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
currentToolIndex = int(idx)
|
||||
}
|
||||
// Capture tool use ID for toolResults handshake
|
||||
if id, ok := cb["id"].(string); ok {
|
||||
result.WebSearchToolUseId = id
|
||||
}
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
// Accumulate tool input JSON
|
||||
if currentToolName != "" {
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
|
||||
if partial, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuilder.WriteString(partial)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// Finalize tool use detection
|
||||
if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
|
||||
result.HasWebSearchToolUse = true
|
||||
result.WebSearchToolUseIndex = currentToolIndex
|
||||
// Extract query from accumulated input JSON
|
||||
inputJSON := toolInputBuilder.String()
|
||||
var input map[string]string
|
||||
if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
|
||||
if q, ok := input["query"]; ok {
|
||||
result.WebSearchQuery = q
|
||||
}
|
||||
}
|
||||
log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery)
|
||||
}
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterChunksForClient processes buffered SSE chunks for delivery to the client
// while the proxy runs the search loop internally. From every chunk it removes:
//   - message_start, message_delta and message_stop events and the "[DONE]" marker
//     (the outer handleWebSearchStream manages these), and
//   - every event whose "index" equals wsToolIndex (the web_search tool_use block),
//     when wsToolIndex >= 0, so the client never sees a tool prompt for web_search.
//
// For the remaining content_block_start/delta/stop events, indexOffset (when > 0) is
// added to the "index" field. Chunks left without any non-blank line are dropped.
//
// Line framing is preserved: a chunk ending in a newline does not gain another one,
// and the blank line that terminated a removed event is removed with it.
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
	var filtered [][]byte

	for _, chunk := range chunks {
		if out, ok := filterChunkLinesForClient(string(chunk), wsToolIndex, indexOffset); ok {
			filtered = append(filtered, []byte(out))
		}
	}

	return filtered
}

// filterChunkLinesForClient applies the FilterChunksForClient rules to one chunk.
// The second result is false when nothing but blank lines would remain.
func filterChunkLinesForClient(chunkStr string, wsToolIndex, indexOffset int) (string, bool) {
	lines := strings.Split(chunkStr, "\n")
	// When the chunk ends with "\n", Split yields a final empty element that is not a
	// line of its own; keeping it would append an extra newline to the output.
	if last := len(lines) - 1; last >= 0 && lines[last] == "" {
		lines = lines[:last]
	}

	var out strings.Builder
	out.Grow(len(chunkStr) + 1)
	hasContent := false
	// dropSeparator is set right after an event was removed so that the blank line
	// terminating that event is not emitted on its own.
	dropSeparator := false

	for i := 0; i < len(lines); i++ {
		line := lines[i]

		if dropSeparator {
			dropSeparator = false
			if strings.TrimSpace(line) == "" {
				continue
			}
		}

		switch {
		case strings.HasPrefix(line, "data: "):
			payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
			event, drop := classifyClientEvent(payload, wsToolIndex)
			if drop {
				dropSeparator = true
				continue
			}

			// Apply the index offset to content block events; on any problem the
			// original line is forwarded unchanged.
			rendered := line
			if event != nil && indexOffset > 0 {
				switch eventType, _ := event["type"].(string); eventType {
				case "content_block_start", "content_block_delta", "content_block_stop":
					if idx, ok := event["index"].(float64); ok {
						event["index"] = int(idx) + indexOffset
						if adjusted, err := json.Marshal(event); err == nil {
							rendered = "data: " + string(adjusted)
						}
					}
				}
			}
			out.WriteString(rendered + "\n")
			hasContent = true

		case strings.HasPrefix(line, "event: "):
			// Peek at the following data line: when it will be removed, remove the
			// "event:" line with it.
			if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
				nextPayload := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
				if _, drop := classifyClientEvent(nextPayload, wsToolIndex); drop {
					i++ // skip the data line
					dropSeparator = true
					continue
				}
			}
			out.WriteString(line + "\n")
			hasContent = true

		default:
			// Blank separators and unrecognized lines are copied as they are; only
			// non-blank lines count as content.
			out.WriteString(line + "\n")
			if strings.TrimSpace(line) != "" {
				hasContent = true
			}
		}
	}

	if !hasContent {
		return "", false
	}
	return out.String(), true
}

// classifyClientEvent decodes one SSE data payload and reports whether
// FilterChunksForClient must drop it. The decoded event is returned for reuse and is
// nil when the payload is "[DONE]" or not a JSON object.
func classifyClientEvent(payload string, wsToolIndex int) (map[string]interface{}, bool) {
	if payload == "[DONE]" {
		// The outer loop manages stream termination.
		return nil, true
	}

	var event map[string]interface{}
	if err := json.Unmarshal([]byte(payload), &event); err != nil {
		return nil, false
	}

	switch eventType, _ := event["type"].(string); eventType {
	case "message_start", "message_delta", "message_stop":
		// The outer loop sends its own message envelope.
		return event, true
	}

	if wsToolIndex >= 0 {
		if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
			// Event belongs to the web_search tool_use block.
			return event, true
		}
	}

	return event, false
}
|
||||
@@ -1,11 +1,14 @@
|
||||
// Package claude provides web search functionality for Kiro translator.
|
||||
// This file implements detection and MCP request/response types for web search.
|
||||
// This file implements detection, MCP request/response types, and pure data
|
||||
// transformation utilities for web search. SSE event generation, stream analysis,
|
||||
// and HTTP I/O logic reside in the executor package (kiro_executor.go).
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -14,6 +17,26 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// cachedToolDescription holds the most recently stored web_search tool description.
// It is written through SetWebSearchDescription and read through
// GetWebSearchDescription; the stored value is always a string.
var cachedToolDescription atomic.Value // stores string

// GetWebSearchDescription returns the cached web_search tool description, or the
// empty string when nothing has been stored yet. The read is a single atomic load.
func GetWebSearchDescription() string {
	// The comma-ok assertion yields "" when the value has never been set (nil load).
	desc, _ := cachedToolDescription.Load().(string)
	return desc
}

// SetWebSearchDescription replaces the cached web_search tool description.
// Called by the executor after fetching from MCP tools/list.
func SetWebSearchDescription(desc string) {
	cachedToolDescription.Store(desc)
}
|
||||
|
||||
// McpRequest represents a JSON-RPC 2.0 request to Kiro MCP API
|
||||
type McpRequest struct {
|
||||
ID string `json:"id"`
|
||||
@@ -191,36 +214,11 @@ func CreateMcpRequest(query string) (string, *McpRequest) {
|
||||
return toolUseID, request
|
||||
}
|
||||
|
||||
// GenerateMessageID generates a Claude-style message ID
|
||||
func GenerateMessageID() string {
|
||||
return "msg_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:24]
|
||||
}
|
||||
|
||||
// GenerateToolUseID generates a Kiro-style tool use ID (base62-like UUID)
|
||||
func GenerateToolUseID() string {
|
||||
return strings.ReplaceAll(uuid.New().String(), "-", "")[:22]
|
||||
}
|
||||
|
||||
// ContainsWebSearchTool checks if the request contains a web_search tool (among any tools).
|
||||
// Unlike HasWebSearchTool, this detects web_search even in mixed-tool arrays.
|
||||
func ContainsWebSearchTool(body []byte) bool {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, tool := range tools.Array() {
|
||||
name := strings.ToLower(tool.Get("name").String())
|
||||
toolType := strings.ToLower(tool.Get("type").String())
|
||||
|
||||
if isWebSearchTool(name, toolType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ReplaceWebSearchToolDescription replaces the web_search tool description with
|
||||
// a minimal version that allows re-search without the restrictive "do not search
|
||||
// non-coding topics" instruction from the original Kiro tools/list response.
|
||||
@@ -275,48 +273,6 @@ func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// StripWebSearchTool removes web_search tool entries from the request's tools array.
|
||||
// If the tools array becomes empty after removal, it is removed entirely.
|
||||
func StripWebSearchTool(body []byte) ([]byte, error) {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var filtered []json.RawMessage
|
||||
for _, tool := range tools.Array() {
|
||||
name := strings.ToLower(tool.Get("name").String())
|
||||
toolType := strings.ToLower(tool.Get("type").String())
|
||||
|
||||
if !isWebSearchTool(name, toolType) {
|
||||
filtered = append(filtered, json.RawMessage(tool.Raw))
|
||||
}
|
||||
}
|
||||
|
||||
var result []byte
|
||||
var err error
|
||||
|
||||
if len(filtered) == 0 {
|
||||
// Remove tools array entirely
|
||||
result, err = sjson.DeleteBytes(body, "tools")
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("failed to delete tools: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Replace with filtered array
|
||||
filteredJSON, marshalErr := json.Marshal(filtered)
|
||||
if marshalErr != nil {
|
||||
return body, fmt.Errorf("failed to marshal filtered tools: %w", marshalErr)
|
||||
}
|
||||
result, err = sjson.SetRawBytes(body, "tools", filteredJSON)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("failed to set filtered tools: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FormatSearchContextPrompt formats search results as a structured text block
|
||||
// for injection into the system prompt.
|
||||
func FormatSearchContextPrompt(query string, results *WebSearchResults) string {
|
||||
@@ -365,7 +321,7 @@ func FormatToolResultText(results *WebSearchResults) string {
|
||||
//
|
||||
// This produces the exact same GAR request format as the Kiro IDE (HAR captures).
|
||||
// IMPORTANT: The web_search tool must remain in the "tools" array for this to work.
|
||||
// Use ReplaceWebSearchToolDescription (not StripWebSearchTool) to keep the tool available.
|
||||
// Use ReplaceWebSearchToolDescription to keep the tool available with a minimal description.
|
||||
func InjectToolResultsClaude(claudePayload []byte, toolUseId, query string, results *WebSearchResults) ([]byte, error) {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(claudePayload, &payload); err != nil {
|
||||
@@ -512,658 +468,28 @@ type SearchIndicator struct {
|
||||
Results *WebSearchResults
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// SSE Event Generation
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// SseEvent represents a Server-Sent Event
|
||||
type SseEvent struct {
|
||||
Event string
|
||||
Data interface{}
|
||||
// BuildMcpEndpoint constructs the MCP endpoint URL for the given AWS region.
|
||||
// Centralizes the URL pattern used by both handleWebSearch and handleWebSearchStream.
|
||||
func BuildMcpEndpoint(region string) string {
|
||||
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
}
|
||||
|
||||
// ToSSEString converts the event to SSE wire format
|
||||
func (e *SseEvent) ToSSEString() string {
|
||||
dataBytes, _ := json.Marshal(e.Data)
|
||||
return fmt.Sprintf("event: %s\ndata: %s\n\n", e.Event, string(dataBytes))
|
||||
}
|
||||
|
||||
// GenerateWebSearchEvents generates the 11-event SSE sequence for web search.
|
||||
// Events: message_start, content_block_start(server_tool_use), content_block_delta(input_json),
|
||||
// content_block_stop, content_block_start(web_search_tool_result), content_block_stop,
|
||||
// content_block_start(text), content_block_delta(text), content_block_stop, message_delta, message_stop
|
||||
func GenerateWebSearchEvents(
|
||||
model string,
|
||||
query string,
|
||||
toolUseID string,
|
||||
searchResults *WebSearchResults,
|
||||
inputTokens int,
|
||||
) []SseEvent {
|
||||
events := make([]SseEvent, 0, 15)
|
||||
messageID := GenerateMessageID()
|
||||
|
||||
// 1. message_start
|
||||
events = append(events, SseEvent{
|
||||
Event: "message_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": inputTokens,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 2. content_block_start (server_tool_use)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]interface{}{
|
||||
"id": toolUseID,
|
||||
"type": "server_tool_use",
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 3. content_block_delta (input_json_delta)
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 4. content_block_stop (server_tool_use)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
},
|
||||
})
|
||||
|
||||
// 5. content_block_start (web_search_tool_result)
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if searchResults != nil {
|
||||
for _, r := range searchResults.Results {
|
||||
snippet := ""
|
||||
if r.Snippet != nil {
|
||||
snippet = *r.Snippet
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": r.Title,
|
||||
"url": r.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": searchContent,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 6. content_block_stop (web_search_tool_result)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 1,
|
||||
},
|
||||
})
|
||||
|
||||
// 7. content_block_start (text)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 2,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 8. content_block_delta (text_delta) - generate search summary
|
||||
summary := generateSearchSummary(query, searchResults)
|
||||
|
||||
// Split text into chunks for streaming effect
|
||||
chunkSize := 100
|
||||
runes := []rune(summary)
|
||||
for i := 0; i < len(runes); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(runes) {
|
||||
end = len(runes)
|
||||
}
|
||||
chunk := string(runes[i:end])
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 2,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "text_delta",
|
||||
"text": chunk,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// 9. content_block_stop (text)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 2,
|
||||
},
|
||||
})
|
||||
|
||||
// 10. message_delta
|
||||
outputTokens := (len(summary) + 3) / 4 // Simple estimation
|
||||
events = append(events, SseEvent{
|
||||
Event: "message_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 11. message_stop
|
||||
events = append(events, SseEvent{
|
||||
Event: "message_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_stop",
|
||||
},
|
||||
})
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// generateSearchSummary generates a text summary of search results
|
||||
func generateSearchSummary(query string, results *WebSearchResults) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Here are the search results for \"%s\":\n\n", query))
|
||||
|
||||
if results != nil && len(results.Results) > 0 {
|
||||
for i, r := range results.Results {
|
||||
sb.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, r.Title))
|
||||
if r.Snippet != nil {
|
||||
snippet := *r.Snippet
|
||||
if len(snippet) > 200 {
|
||||
snippet = snippet[:200] + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", snippet))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" Source: %s\n\n", r.URL))
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("No results found.\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\nPlease note that these are web search results and may not be fully accurate or up-to-date.")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
|
||||
// (server_tool_use + web_search_tool_result) without text summary or message termination.
|
||||
// These events trigger Claude Code's search indicator UI.
|
||||
// The caller is responsible for sending message_start before and message_delta/stop after.
|
||||
func GenerateSearchIndicatorEvents(
|
||||
query string,
|
||||
toolUseID string,
|
||||
searchResults *WebSearchResults,
|
||||
startIndex int,
|
||||
) []SseEvent {
|
||||
events := make([]SseEvent, 0, 4)
|
||||
|
||||
// 1. content_block_start (server_tool_use)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"id": toolUseID,
|
||||
"type": "server_tool_use",
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 2. content_block_delta (input_json_delta)
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": startIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 3. content_block_stop (server_tool_use)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex,
|
||||
},
|
||||
})
|
||||
|
||||
// 4. content_block_start (web_search_tool_result)
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if searchResults != nil {
|
||||
for _, r := range searchResults.Results {
|
||||
snippet := ""
|
||||
if r.Snippet != nil {
|
||||
snippet = *r.Snippet
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": r.Title,
|
||||
"url": r.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex + 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": searchContent,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 5. content_block_stop (web_search_tool_result)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex + 1,
|
||||
},
|
||||
})
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// Stream Analysis & Manipulation
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// AdjustStreamIndices shifts the content block index of a single SSE "data:"
// payload (JSON) by offset, so that model response events can follow the
// search indicator blocks that occupy the first indices.
//
// Returns the (possibly rewritten) payload and whether it should be forwarded:
//   - message_start events are reported as not forwardable (duplicate suppression);
//   - content_block_start/delta/stop events with a numeric "index" are
//     re-encoded with index+offset;
//   - everything else, including empty or non-JSON input, is passed through
//     unchanged.
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
	if len(data) == 0 {
		return data, true
	}

	var payload map[string]interface{}
	if json.Unmarshal(data, &payload) != nil {
		// Not JSON: nothing to adjust.
		return data, true
	}

	kind, _ := payload["type"].(string)
	if kind == "message_start" {
		return data, false
	}
	if kind != "content_block_start" && kind != "content_block_delta" && kind != "content_block_stop" {
		// message_delta, message_stop, ping, ... carry no block index.
		return data, true
	}

	position, numeric := payload["index"].(float64)
	if !numeric {
		return data, true
	}

	payload["index"] = int(position) + offset
	if shifted, err := json.Marshal(payload); err == nil {
		return shifted, true
	}
	return data, true
}
|
||||
|
||||
// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs)
|
||||
// and adjusts content block indices. Suppresses duplicate message_start events.
|
||||
// Returns the adjusted chunk and whether it should be forwarded.
|
||||
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
|
||||
chunkStr := string(chunk)
|
||||
|
||||
// Fast path: if no "data:" prefix, pass through
|
||||
if !strings.Contains(chunkStr, "data: ") {
|
||||
return chunk, true
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
hasContent := false
|
||||
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
|
||||
if dataPayload == "[DONE]" {
|
||||
result.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset)
|
||||
if !shouldForward {
|
||||
// Skip this event and its preceding "event:" line
|
||||
// Also skip the trailing empty line
|
||||
continue
|
||||
}
|
||||
|
||||
result.WriteString("data: " + string(adjusted) + "\n")
|
||||
hasContent = true
|
||||
} else if strings.HasPrefix(line, "event: ") {
|
||||
// Check if the next data line will be suppressed
|
||||
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
|
||||
dataPayload := strings.TrimPrefix(lines[i+1], "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err == nil {
|
||||
if eventType, ok := event["type"].(string); ok && eventType == "message_start" {
|
||||
// Skip both the event: and data: lines
|
||||
i++ // skip the data: line too
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
result.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
} else {
|
||||
result.WriteString(line + "\n")
|
||||
if strings.TrimSpace(line) != "" {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasContent {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return []byte(result.String()), true
|
||||
}
|
||||
|
||||
// BufferedStreamResult is the outcome of AnalyzeBufferedStream: what a
// buffered Kiro SSE response says about how the turn ended and whether the
// model asked for another web search.
type BufferedStreamResult struct {
	StopReason            string // stop_reason seen in a message_delta event, e.g. "end_turn" or "tool_use"
	WebSearchQuery        string // "query" extracted from the web_search tool input, if any
	WebSearchToolUseId    string // tool_use ID from the model's response, needed for the toolResults handshake
	HasWebSearchToolUse   bool   // true when the model requested web_search
	WebSearchToolUseIndex int    // content_block index of the web_search tool_use
}
|
||||
|
||||
// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
|
||||
// This is used in the search loop to determine if the model wants another search round.
|
||||
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
|
||||
|
||||
// Track tool use state across chunks
|
||||
var currentToolName string
|
||||
var currentToolIndex int = -1
|
||||
var toolInputBuilder strings.Builder
|
||||
|
||||
for _, chunk := range chunks {
|
||||
chunkStr := string(chunk)
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
if dataPayload == "[DONE]" || dataPayload == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "message_delta":
|
||||
// Extract stop_reason from message_delta
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
|
||||
result.StopReason = sr
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
// Detect tool_use content blocks
|
||||
if cb, ok := event["content_block"].(map[string]interface{}); ok {
|
||||
if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
|
||||
if name, ok := cb["name"].(string); ok {
|
||||
currentToolName = strings.ToLower(name)
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
currentToolIndex = int(idx)
|
||||
}
|
||||
// Capture tool use ID for toolResults handshake
|
||||
if id, ok := cb["id"].(string); ok {
|
||||
result.WebSearchToolUseId = id
|
||||
}
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
// Accumulate tool input JSON
|
||||
if currentToolName != "" {
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
|
||||
if partial, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuilder.WriteString(partial)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// Finalize tool use detection
|
||||
if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
|
||||
result.HasWebSearchToolUse = true
|
||||
result.WebSearchToolUseIndex = currentToolIndex
|
||||
// Extract query from accumulated input JSON
|
||||
inputJSON := toolInputBuilder.String()
|
||||
var input map[string]string
|
||||
if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
|
||||
if q, ok := input["query"]; ok {
|
||||
result.WebSearchQuery = q
|
||||
}
|
||||
}
|
||||
log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery)
|
||||
}
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use
|
||||
// content blocks. This prevents the client from seeing "Tool use" prompts for web_search
|
||||
// when the proxy is handling the search loop internally.
|
||||
// Also suppresses message_start and message_delta/message_stop events since those
|
||||
// are managed by the outer handleWebSearchStream.
|
||||
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
|
||||
var filtered [][]byte
|
||||
|
||||
for _, chunk := range chunks {
|
||||
chunkStr := string(chunk)
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
|
||||
var resultBuilder strings.Builder
|
||||
hasContent := false
|
||||
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
|
||||
if dataPayload == "[DONE]" {
|
||||
// Skip [DONE] — the outer loop manages stream termination
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
|
||||
resultBuilder.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
|
||||
// Skip message_start (outer loop sends its own)
|
||||
if eventType == "message_start" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip message_delta and message_stop (outer loop manages these)
|
||||
if eventType == "message_delta" || eventType == "message_stop" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this event belongs to the web_search tool_use block
|
||||
if wsToolIndex >= 0 {
|
||||
if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
|
||||
// Skip events for the web_search tool_use block
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Apply index offset for remaining events
|
||||
if indexOffset > 0 {
|
||||
switch eventType {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
event["index"] = int(idx) + indexOffset
|
||||
adjusted, err := json.Marshal(event)
|
||||
if err == nil {
|
||||
resultBuilder.WriteString("data: " + string(adjusted) + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resultBuilder.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
} else if strings.HasPrefix(line, "event: ") {
|
||||
// Check if the next data line will be suppressed
|
||||
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
|
||||
nextData := strings.TrimPrefix(lines[i+1], "data: ")
|
||||
nextData = strings.TrimSpace(nextData)
|
||||
|
||||
var nextEvent map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil {
|
||||
nextType, _ := nextEvent["type"].(string)
|
||||
if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
|
||||
i++ // skip the data line
|
||||
continue
|
||||
}
|
||||
if wsToolIndex >= 0 {
|
||||
if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex {
|
||||
i++ // skip the data line
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
resultBuilder.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
} else {
|
||||
resultBuilder.WriteString(line + "\n")
|
||||
if strings.TrimSpace(line) != "" {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasContent {
|
||||
filtered = append(filtered, []byte(resultBuilder.String()))
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
// ParseSearchResults extracts WebSearchResults from MCP response
|
||||
func ParseSearchResults(response *McpResponse) *WebSearchResults {
|
||||
if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
content := response.Result.Content[0]
|
||||
if content.ContentType != "text" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results WebSearchResults
|
||||
if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
|
||||
log.Warnf("kiro/websearch: failed to parse search results: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &results
|
||||
}
|
||||
|
||||
@@ -1,270 +0,0 @@
|
||||
// Package claude provides web search handler for Kiro translator.
|
||||
// This file implements the MCP API call and response handling.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Cached web_search tool description fetched from MCP tools/list.
|
||||
// Uses atomic.Pointer[sync.Once] for lock-free reads with retry-on-failure:
|
||||
// - sync.Once prevents race conditions and deduplicates concurrent calls
|
||||
// - On failure, a fresh sync.Once is swapped in to allow retry on next call
|
||||
// - On success, sync.Once stays "done" forever — zero overhead for subsequent calls
|
||||
var (
|
||||
cachedToolDescription atomic.Value // stores string
|
||||
toolDescOnce atomic.Pointer[sync.Once]
|
||||
fallbackFpOnce sync.Once
|
||||
fallbackFp *kiroauth.Fingerprint
|
||||
)
|
||||
|
||||
func init() {
|
||||
toolDescOnce.Store(&sync.Once{})
|
||||
}
|
||||
|
||||
// FetchToolDescription calls MCP tools/list to get the web_search tool description
|
||||
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
|
||||
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
|
||||
// The httpClient parameter allows reusing a shared pooled HTTP client.
|
||||
func FetchToolDescription(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) {
|
||||
toolDescOnce.Load().Do(func() {
|
||||
handler := NewWebSearchHandler(mcpEndpoint, authToken, httpClient, fp, authAttrs)
|
||||
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
|
||||
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
|
||||
|
||||
req, err := http.NewRequest("POST", mcpEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse same headers as CallMcpAPI
|
||||
handler.setMcpHeaders(req)
|
||||
|
||||
resp, err := handler.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
|
||||
|
||||
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
|
||||
var result struct {
|
||||
Result *struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
log.Warnf("kiro/websearch: failed to parse tools/list response")
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
|
||||
for _, tool := range result.Result.Tools {
|
||||
if tool.Name == "web_search" && tool.Description != "" {
|
||||
cachedToolDescription.Store(tool.Description)
|
||||
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
|
||||
return // success — sync.Once stays "done", no more fetches
|
||||
}
|
||||
}
|
||||
|
||||
// web_search tool not found in response
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
})
|
||||
}
|
||||
|
||||
// GetWebSearchDescription returns the cached web_search tool description,
|
||||
// or empty string if not yet fetched. Lock-free via atomic.Value.
|
||||
func GetWebSearchDescription() string {
|
||||
if v := cachedToolDescription.Load(); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// WebSearchHandler handles web search requests via Kiro MCP API
|
||||
type WebSearchHandler struct {
|
||||
McpEndpoint string
|
||||
HTTPClient *http.Client
|
||||
AuthToken string
|
||||
Fingerprint *kiroauth.Fingerprint // optional, for dynamic headers
|
||||
AuthAttrs map[string]string // optional, for custom headers from auth.Attributes
|
||||
}
|
||||
|
||||
// NewWebSearchHandler creates a new WebSearchHandler.
|
||||
// If httpClient is nil, a default client with 30s timeout is used.
|
||||
// If fingerprint is nil, a random one-off fingerprint is generated.
|
||||
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
|
||||
func NewWebSearchHandler(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) *WebSearchHandler {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
if fp == nil {
|
||||
// Use a shared fallback fingerprint for callers without token context
|
||||
fallbackFpOnce.Do(func() {
|
||||
mgr := kiroauth.NewFingerprintManager()
|
||||
fallbackFp = mgr.GetFingerprint("mcp-fallback")
|
||||
})
|
||||
fp = fallbackFp
|
||||
}
|
||||
return &WebSearchHandler{
|
||||
McpEndpoint: mcpEndpoint,
|
||||
HTTPClient: httpClient,
|
||||
AuthToken: authToken,
|
||||
Fingerprint: fp,
|
||||
AuthAttrs: authAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
// setMcpHeaders sets standard MCP API headers on the request,
|
||||
// aligned with the GAR request pattern in kiro_executor.go.
|
||||
func (h *WebSearchHandler) setMcpHeaders(req *http.Request) {
|
||||
fp := h.Fingerprint
|
||||
|
||||
// 1. Content-Type & Accept (aligned with GAR)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
|
||||
// 2. Kiro-specific headers (aligned with GAR)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
|
||||
// 3. Dynamic fingerprint headers
|
||||
req.Header.Set("User-Agent", fp.BuildUserAgent())
|
||||
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
|
||||
|
||||
// 4. AWS SDK identifiers (casing aligned with GAR)
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// 5. Authentication
|
||||
req.Header.Set("Authorization", "Bearer "+h.AuthToken)
|
||||
|
||||
// 6. Custom headers from auth attributes
|
||||
util.ApplyCustomHeadersFromAttrs(req, h.AuthAttrs)
|
||||
}
|
||||
|
||||
// mcpMaxRetries is the maximum number of retries for MCP API calls.
|
||||
const mcpMaxRetries = 2
|
||||
|
||||
// CallMcpAPI calls the Kiro MCP API with the given request.
|
||||
// Includes retry logic with exponential backoff for retryable errors,
|
||||
// aligned with the GAR request retry pattern.
|
||||
func (h *WebSearchHandler) CallMcpAPI(request *McpRequest) (*McpResponse, error) {
|
||||
requestBody, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.McpEndpoint, len(requestBody))
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
if backoff > 10*time.Second {
|
||||
backoff = 10 * time.Second
|
||||
}
|
||||
log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", h.McpEndpoint, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
h.setMcpHeaders(req)
|
||||
|
||||
resp, err := h.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("MCP API request failed: %w", err)
|
||||
continue // network error → retry
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read MCP response: %w", err)
|
||||
continue // read error → retry
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
|
||||
|
||||
// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
|
||||
if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
|
||||
lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var mcpResponse McpResponse
|
||||
if err := json.Unmarshal(body, &mcpResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse MCP response: %w", err)
|
||||
}
|
||||
|
||||
if mcpResponse.Error != nil {
|
||||
code := -1
|
||||
if mcpResponse.Error.Code != nil {
|
||||
code = *mcpResponse.Error.Code
|
||||
}
|
||||
msg := "Unknown error"
|
||||
if mcpResponse.Error.Message != nil {
|
||||
msg = *mcpResponse.Error.Message
|
||||
}
|
||||
return nil, fmt.Errorf("MCP error %d: %s", code, msg)
|
||||
}
|
||||
|
||||
return &mcpResponse, nil
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// ParseSearchResults extracts WebSearchResults from MCP response
|
||||
func ParseSearchResults(response *McpResponse) *WebSearchResults {
|
||||
if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
content := response.Result.Content[0]
|
||||
if content.ContentType != "text" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results WebSearchResults
|
||||
if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
|
||||
log.Warnf("kiro/websearch: failed to parse search results: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &results
|
||||
}
|
||||
Reference in New Issue
Block a user