mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-21 15:11:48 +00:00
refactor: align web search with executor layer patterns
Consolidate web search handler, SSE event generation, stream analysis, and MCP HTTP I/O into the executor layer. Merge the separate kiro_websearch_handler.go back into kiro_executor.go to align with the single-file-per-executor convention. Translator retains only pure data types, detection, and payload transformation. Key changes: - Move SSE construction (search indicators, fallback text, message_start) from translator to executor, consistent with streamToChannel pattern - Move MCP handler (callMcpAPI, setMcpHeaders, fetchToolDescription) from translator to executor alongside other HTTP I/O - Reuse applyDynamicFingerprint for MCP UA headers (eliminate duplication) - Centralize MCP endpoint URL via BuildMcpEndpoint in translator - Add atomic Set/GetWebSearchDescription for cross-layer tool desc cache - Thread context.Context through MCP HTTP calls for cancellation support - Thread usage reporter through all web search API call paths - Add token expiry pre-check before MCP/GAR calls - Clean up dead code (GenerateMessageID, webSearchAuthContext fp logic, ContainsWebSearchTool, StripWebSearchTool)
This commit is contained in:
@@ -183,4 +183,129 @@ func PendingTagSuffix(buffer, tag string) int {
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
|
||||
// (server_tool_use + web_search_tool_result) without text summary or message termination.
|
||||
// These events trigger Claude Code's search indicator UI.
|
||||
// The caller is responsible for sending message_start before and message_delta/stop after.
|
||||
func GenerateSearchIndicatorEvents(
|
||||
query string,
|
||||
toolUseID string,
|
||||
searchResults *WebSearchResults,
|
||||
startIndex int,
|
||||
) []sseEvent {
|
||||
events := make([]sseEvent, 0, 4)
|
||||
|
||||
// 1. content_block_start (server_tool_use)
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"id": toolUseID,
|
||||
"type": "server_tool_use",
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 2. content_block_delta (input_json_delta)
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": startIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 3. content_block_stop (server_tool_use)
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex,
|
||||
},
|
||||
})
|
||||
|
||||
// 4. content_block_start (web_search_tool_result)
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if searchResults != nil {
|
||||
for _, r := range searchResults.Results {
|
||||
snippet := ""
|
||||
if r.Snippet != nil {
|
||||
snippet = *r.Snippet
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": r.Title,
|
||||
"url": r.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex + 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": searchContent,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 5. content_block_stop (web_search_tool_result)
|
||||
events = append(events, sseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex + 1,
|
||||
},
|
||||
})
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// BuildFallbackTextEvents generates SSE events for a fallback text response
|
||||
// when the Kiro API fails during the search loop. Uses BuildClaude*Event()
|
||||
// functions to align with streamToChannel patterns.
|
||||
// Returns raw SSE byte slices ready to be sent to the client channel.
|
||||
func BuildFallbackTextEvents(contentBlockIndex int, query string, results *WebSearchResults) [][]byte {
|
||||
summary := FormatSearchContextPrompt(query, results)
|
||||
outputTokens := len(summary) / 4
|
||||
if len(summary) > 0 && outputTokens == 0 {
|
||||
outputTokens = 1
|
||||
}
|
||||
|
||||
var events [][]byte
|
||||
|
||||
// content_block_start (text)
|
||||
events = append(events, BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", ""))
|
||||
|
||||
// content_block_delta (text_delta)
|
||||
events = append(events, BuildClaudeStreamEvent(summary, contentBlockIndex))
|
||||
|
||||
// content_block_stop
|
||||
events = append(events, BuildClaudeContentBlockStopEvent(contentBlockIndex))
|
||||
|
||||
// message_delta with end_turn
|
||||
events = append(events, BuildClaudeMessageDeltaEvent("end_turn", usage.Detail{
|
||||
OutputTokens: int64(outputTokens),
|
||||
}))
|
||||
|
||||
// message_stop
|
||||
events = append(events, BuildClaudeMessageStopOnlyEvent())
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
350
internal/translator/kiro/claude/kiro_claude_stream_parser.go
Normal file
350
internal/translator/kiro/claude/kiro_claude_stream_parser.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// sseEvent represents a single Server-Sent Event: an event name plus a
// JSON-serializable payload.
type sseEvent struct {
	Event string
	Data  interface{}
}

// ToSSEString renders the event in SSE wire format:
// "event: <name>\ndata: <json>\n\n".
func (e *sseEvent) ToSSEString() string {
	payload, _ := json.Marshal(e.Data)
	var b strings.Builder
	b.WriteString("event: ")
	b.WriteString(e.Event)
	b.WriteString("\ndata: ")
	b.Write(payload)
	b.WriteString("\n\n")
	return b.String()
}
|
||||
|
||||
// AdjustStreamIndices shifts the content block index inside a single SSE
// "data:" JSON payload by offset, and suppresses duplicate message_start
// events (shouldForward == false). Used to splice the search indicator
// events (indices 0,1) in front of the Kiro model response events.
//
// The data parameter is a single SSE "data:" line payload (JSON).
// Returns the (possibly rewritten) payload and whether to forward it.
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
	if len(data) == 0 {
		return data, true
	}

	var payload map[string]interface{}
	if json.Unmarshal(data, &payload) != nil {
		// Non-JSON payloads are forwarded untouched.
		return data, true
	}

	eventType, _ := payload["type"].(string)
	switch eventType {
	case "message_start":
		// The outer stream already sent its own message_start.
		return data, false
	case "content_block_start", "content_block_delta", "content_block_stop":
		if idx, ok := payload["index"].(float64); ok {
			payload["index"] = int(idx) + offset
			rewritten, err := json.Marshal(payload)
			if err != nil {
				return data, true
			}
			return rewritten, true
		}
	}

	// Everything else (message_delta, message_stop, ping, ...) passes through.
	return data, true
}

// AdjustSSEChunk rewrites a raw SSE chunk (which may hold several
// "event:"/"data:" pairs): content block indices are shifted by offset and
// duplicate message_start events are dropped together with their "event:"
// line. Returns the adjusted chunk and whether anything remains to forward.
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
	raw := string(chunk)

	// Fast path: nothing resembling a data line, forward as-is.
	if !strings.Contains(raw, "data: ") {
		return chunk, true
	}

	var out strings.Builder
	forwardable := false

	lines := strings.Split(raw, "\n")
	for i := 0; i < len(lines); i++ {
		line := lines[i]

		switch {
		case strings.HasPrefix(line, "data: "):
			payload := strings.TrimSpace(strings.TrimPrefix(line, "data: "))

			if payload == "[DONE]" {
				out.WriteString(line + "\n")
				forwardable = true
				continue
			}

			adjusted, keep := AdjustStreamIndices([]byte(payload), offset)
			if !keep {
				// Drop the suppressed event; its "event:" line was already
				// skipped by the peek in the branch below.
				continue
			}

			out.WriteString("data: " + string(adjusted) + "\n")
			forwardable = true

		case strings.HasPrefix(line, "event: "):
			// Peek at the following data line: if it is a message_start we
			// drop both the "event:" and "data:" lines.
			if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
				next := strings.TrimSpace(strings.TrimPrefix(lines[i+1], "data: "))
				var evt map[string]interface{}
				if json.Unmarshal([]byte(next), &evt) == nil {
					if t, ok := evt["type"].(string); ok && t == "message_start" {
						i++ // consume the data line as well
						continue
					}
				}
			}
			out.WriteString(line + "\n")
			forwardable = true

		default:
			// Blank lines and anything else pass through; blank lines alone
			// do not make the chunk worth forwarding.
			out.WriteString(line + "\n")
			if strings.TrimSpace(line) != "" {
				forwardable = true
			}
		}
	}

	if !forwardable {
		return nil, false
	}
	return []byte(out.String()), true
}
|
||||
|
||||
// BufferedStreamResult contains the analysis of buffered SSE chunks from a
// Kiro API response, as produced by AnalyzeBufferedStream.
type BufferedStreamResult struct {
	// StopReason is the detected stop_reason from the stream (e.g. "end_turn", "tool_use").
	StopReason string
	// WebSearchQuery is the extracted query if the model requested another web_search.
	WebSearchQuery string
	// WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults).
	WebSearchToolUseId string
	// HasWebSearchToolUse indicates whether the model requested web_search.
	HasWebSearchToolUse bool
	// WebSearchToolUseIndex is the content_block index of the web_search tool_use (-1 when absent).
	WebSearchToolUseIndex int
}
|
||||
|
||||
// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
|
||||
// This is used in the search loop to determine if the model wants another search round.
|
||||
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
|
||||
|
||||
// Track tool use state across chunks
|
||||
var currentToolName string
|
||||
var currentToolIndex int = -1
|
||||
var toolInputBuilder strings.Builder
|
||||
|
||||
for _, chunk := range chunks {
|
||||
chunkStr := string(chunk)
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
if dataPayload == "[DONE]" || dataPayload == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "message_delta":
|
||||
// Extract stop_reason from message_delta
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
|
||||
result.StopReason = sr
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
// Detect tool_use content blocks
|
||||
if cb, ok := event["content_block"].(map[string]interface{}); ok {
|
||||
if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
|
||||
if name, ok := cb["name"].(string); ok {
|
||||
currentToolName = strings.ToLower(name)
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
currentToolIndex = int(idx)
|
||||
}
|
||||
// Capture tool use ID for toolResults handshake
|
||||
if id, ok := cb["id"].(string); ok {
|
||||
result.WebSearchToolUseId = id
|
||||
}
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
// Accumulate tool input JSON
|
||||
if currentToolName != "" {
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
|
||||
if partial, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuilder.WriteString(partial)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// Finalize tool use detection
|
||||
if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
|
||||
result.HasWebSearchToolUse = true
|
||||
result.WebSearchToolUseIndex = currentToolIndex
|
||||
// Extract query from accumulated input JSON
|
||||
inputJSON := toolInputBuilder.String()
|
||||
var input map[string]string
|
||||
if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
|
||||
if q, ok := input["query"]; ok {
|
||||
result.WebSearchQuery = q
|
||||
}
|
||||
}
|
||||
log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery)
|
||||
}
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterChunksForClient processes buffered SSE chunks and removes web_search
// tool_use content blocks. This prevents the client from seeing "Tool use"
// prompts for web_search when the proxy is handling the search loop
// internally. It also suppresses message_start, message_delta and
// message_stop events (the outer handleWebSearchStream owns message framing),
// drops "[DONE]" markers, and shifts the indices of remaining content_block_*
// events by indexOffset when indexOffset > 0.
//
// wsToolIndex is the content_block index of the web_search tool_use
// (-1 disables block filtering). Chunks that end up with no forwardable
// content are dropped from the returned slice entirely.
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
	var filtered [][]byte

	for _, chunk := range chunks {
		chunkStr := string(chunk)
		lines := strings.Split(chunkStr, "\n")

		var resultBuilder strings.Builder
		hasContent := false

		for i := 0; i < len(lines); i++ {
			line := lines[i]

			if strings.HasPrefix(line, "data: ") {
				dataPayload := strings.TrimPrefix(line, "data: ")
				dataPayload = strings.TrimSpace(dataPayload)

				if dataPayload == "[DONE]" {
					// Skip [DONE] — the outer loop manages stream termination
					continue
				}

				var event map[string]interface{}
				if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
					// Non-JSON data payloads are forwarded untouched.
					resultBuilder.WriteString(line + "\n")
					hasContent = true
					continue
				}

				eventType, _ := event["type"].(string)

				// Skip message_start (outer loop sends its own)
				if eventType == "message_start" {
					continue
				}

				// Skip message_delta and message_stop (outer loop manages these)
				if eventType == "message_delta" || eventType == "message_stop" {
					continue
				}

				// Check if this event belongs to the web_search tool_use block
				if wsToolIndex >= 0 {
					if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
						// Skip events for the web_search tool_use block
						continue
					}
				}

				// Apply index offset for remaining content_block events
				if indexOffset > 0 {
					switch eventType {
					case "content_block_start", "content_block_delta", "content_block_stop":
						if idx, ok := event["index"].(float64); ok {
							event["index"] = int(idx) + indexOffset
							adjusted, err := json.Marshal(event)
							if err == nil {
								resultBuilder.WriteString("data: " + string(adjusted) + "\n")
								hasContent = true
								continue
							}
							// NOTE(review): on marshal failure the original
							// (un-shifted) line falls through below — appears
							// to be a deliberate best-effort passthrough.
						}
					}
				}

				resultBuilder.WriteString(line + "\n")
				hasContent = true
			} else if strings.HasPrefix(line, "event: ") {
				// Peek at the next data line: if that event will be
				// suppressed, drop this "event:" line together with it.
				if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
					nextData := strings.TrimPrefix(lines[i+1], "data: ")
					nextData = strings.TrimSpace(nextData)

					var nextEvent map[string]interface{}
					if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil {
						nextType, _ := nextEvent["type"].(string)
						if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
							i++ // skip the data line
							continue
						}
						if wsToolIndex >= 0 {
							if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex {
								i++ // skip the data line
								continue
							}
						}
					}
				}
				resultBuilder.WriteString(line + "\n")
				hasContent = true
			} else {
				// Blank lines and anything else pass through; blank lines
				// alone do not make the chunk worth forwarding.
				resultBuilder.WriteString(line + "\n")
				if strings.TrimSpace(line) != "" {
					hasContent = true
				}
			}
		}

		if hasContent {
			filtered = append(filtered, []byte(resultBuilder.String()))
		}
	}

	return filtered
}
|
||||
@@ -1,11 +1,14 @@
|
||||
// Package claude provides web search functionality for Kiro translator.
|
||||
// This file implements detection and MCP request/response types for web search.
|
||||
// This file implements detection, MCP request/response types, and pure data
|
||||
// transformation utilities for web search. SSE event generation, stream analysis,
|
||||
// and HTTP I/O logic reside in the executor package (kiro_executor.go).
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -14,6 +17,26 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// cachedToolDescription stores the dynamically-fetched web_search tool
// description. Written by the executor via SetWebSearchDescription, read by
// the translator when building the remote_web_search tool for Kiro API
// requests.
var cachedToolDescription atomic.Value // always holds a string

// GetWebSearchDescription returns the cached web_search tool description, or
// "" if nothing has been fetched yet. Lock-free via atomic.Value.
func GetWebSearchDescription() string {
	v := cachedToolDescription.Load()
	if v == nil {
		return ""
	}
	return v.(string)
}

// SetWebSearchDescription stores the dynamically-fetched web_search tool
// description. Called by the executor after fetching from MCP tools/list.
func SetWebSearchDescription(desc string) {
	cachedToolDescription.Store(desc)
}
|
||||
|
||||
// McpRequest represents a JSON-RPC 2.0 request to Kiro MCP API
|
||||
type McpRequest struct {
|
||||
ID string `json:"id"`
|
||||
@@ -191,36 +214,11 @@ func CreateMcpRequest(query string) (string, *McpRequest) {
|
||||
return toolUseID, request
|
||||
}
|
||||
|
||||
// GenerateMessageID generates a Claude-style message ID
|
||||
func GenerateMessageID() string {
|
||||
return "msg_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:24]
|
||||
}
|
||||
|
||||
// GenerateToolUseID generates a Kiro-style tool use ID (base62-like UUID)
|
||||
func GenerateToolUseID() string {
|
||||
return strings.ReplaceAll(uuid.New().String(), "-", "")[:22]
|
||||
}
|
||||
|
||||
// ContainsWebSearchTool checks if the request contains a web_search tool (among any tools).
|
||||
// Unlike HasWebSearchTool, this detects web_search even in mixed-tool arrays.
|
||||
func ContainsWebSearchTool(body []byte) bool {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, tool := range tools.Array() {
|
||||
name := strings.ToLower(tool.Get("name").String())
|
||||
toolType := strings.ToLower(tool.Get("type").String())
|
||||
|
||||
if isWebSearchTool(name, toolType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ReplaceWebSearchToolDescription replaces the web_search tool description with
|
||||
// a minimal version that allows re-search without the restrictive "do not search
|
||||
// non-coding topics" instruction from the original Kiro tools/list response.
|
||||
@@ -275,48 +273,6 @@ func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// StripWebSearchTool removes web_search tool entries from the request's tools array.
|
||||
// If the tools array becomes empty after removal, it is removed entirely.
|
||||
func StripWebSearchTool(body []byte) ([]byte, error) {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var filtered []json.RawMessage
|
||||
for _, tool := range tools.Array() {
|
||||
name := strings.ToLower(tool.Get("name").String())
|
||||
toolType := strings.ToLower(tool.Get("type").String())
|
||||
|
||||
if !isWebSearchTool(name, toolType) {
|
||||
filtered = append(filtered, json.RawMessage(tool.Raw))
|
||||
}
|
||||
}
|
||||
|
||||
var result []byte
|
||||
var err error
|
||||
|
||||
if len(filtered) == 0 {
|
||||
// Remove tools array entirely
|
||||
result, err = sjson.DeleteBytes(body, "tools")
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("failed to delete tools: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Replace with filtered array
|
||||
filteredJSON, marshalErr := json.Marshal(filtered)
|
||||
if marshalErr != nil {
|
||||
return body, fmt.Errorf("failed to marshal filtered tools: %w", marshalErr)
|
||||
}
|
||||
result, err = sjson.SetRawBytes(body, "tools", filteredJSON)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("failed to set filtered tools: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FormatSearchContextPrompt formats search results as a structured text block
|
||||
// for injection into the system prompt.
|
||||
func FormatSearchContextPrompt(query string, results *WebSearchResults) string {
|
||||
@@ -365,7 +321,7 @@ func FormatToolResultText(results *WebSearchResults) string {
|
||||
//
|
||||
// This produces the exact same GAR request format as the Kiro IDE (HAR captures).
|
||||
// IMPORTANT: The web_search tool must remain in the "tools" array for this to work.
|
||||
// Use ReplaceWebSearchToolDescription (not StripWebSearchTool) to keep the tool available.
|
||||
// Use ReplaceWebSearchToolDescription to keep the tool available with a minimal description.
|
||||
func InjectToolResultsClaude(claudePayload []byte, toolUseId, query string, results *WebSearchResults) ([]byte, error) {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(claudePayload, &payload); err != nil {
|
||||
@@ -512,658 +468,28 @@ type SearchIndicator struct {
|
||||
Results *WebSearchResults
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// SSE Event Generation
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// SseEvent represents a Server-Sent Event
|
||||
type SseEvent struct {
|
||||
Event string
|
||||
Data interface{}
|
||||
// BuildMcpEndpoint constructs the MCP endpoint URL for the given AWS region.
// Centralizes the URL pattern used by both handleWebSearch and
// handleWebSearchStream.
func BuildMcpEndpoint(region string) string {
	return "https://q." + region + ".amazonaws.com/mcp"
}
|
||||
|
||||
// ToSSEString converts the event to SSE wire format
|
||||
func (e *SseEvent) ToSSEString() string {
|
||||
dataBytes, _ := json.Marshal(e.Data)
|
||||
return fmt.Sprintf("event: %s\ndata: %s\n\n", e.Event, string(dataBytes))
|
||||
}
|
||||
|
||||
// GenerateWebSearchEvents generates the 11-event SSE sequence for web search.
|
||||
// Events: message_start, content_block_start(server_tool_use), content_block_delta(input_json),
|
||||
// content_block_stop, content_block_start(web_search_tool_result), content_block_stop,
|
||||
// content_block_start(text), content_block_delta(text), content_block_stop, message_delta, message_stop
|
||||
func GenerateWebSearchEvents(
|
||||
model string,
|
||||
query string,
|
||||
toolUseID string,
|
||||
searchResults *WebSearchResults,
|
||||
inputTokens int,
|
||||
) []SseEvent {
|
||||
events := make([]SseEvent, 0, 15)
|
||||
messageID := GenerateMessageID()
|
||||
|
||||
// 1. message_start
|
||||
events = append(events, SseEvent{
|
||||
Event: "message_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": messageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": inputTokens,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 2. content_block_start (server_tool_use)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]interface{}{
|
||||
"id": toolUseID,
|
||||
"type": "server_tool_use",
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 3. content_block_delta (input_json_delta)
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 4. content_block_stop (server_tool_use)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
},
|
||||
})
|
||||
|
||||
// 5. content_block_start (web_search_tool_result)
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if searchResults != nil {
|
||||
for _, r := range searchResults.Results {
|
||||
snippet := ""
|
||||
if r.Snippet != nil {
|
||||
snippet = *r.Snippet
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": r.Title,
|
||||
"url": r.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": searchContent,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 6. content_block_stop (web_search_tool_result)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 1,
|
||||
},
|
||||
})
|
||||
|
||||
// 7. content_block_start (text)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 2,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 8. content_block_delta (text_delta) - generate search summary
|
||||
summary := generateSearchSummary(query, searchResults)
|
||||
|
||||
// Split text into chunks for streaming effect
|
||||
chunkSize := 100
|
||||
runes := []rune(summary)
|
||||
for i := 0; i < len(runes); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(runes) {
|
||||
end = len(runes)
|
||||
}
|
||||
chunk := string(runes[i:end])
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 2,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "text_delta",
|
||||
"text": chunk,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// 9. content_block_stop (text)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 2,
|
||||
},
|
||||
})
|
||||
|
||||
// 10. message_delta
|
||||
outputTokens := (len(summary) + 3) / 4 // Simple estimation
|
||||
events = append(events, SseEvent{
|
||||
Event: "message_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 11. message_stop
|
||||
events = append(events, SseEvent{
|
||||
Event: "message_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_stop",
|
||||
},
|
||||
})
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// generateSearchSummary generates a text summary of search results
|
||||
func generateSearchSummary(query string, results *WebSearchResults) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Here are the search results for \"%s\":\n\n", query))
|
||||
|
||||
if results != nil && len(results.Results) > 0 {
|
||||
for i, r := range results.Results {
|
||||
sb.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, r.Title))
|
||||
if r.Snippet != nil {
|
||||
snippet := *r.Snippet
|
||||
if len(snippet) > 200 {
|
||||
snippet = snippet[:200] + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", snippet))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" Source: %s\n\n", r.URL))
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("No results found.\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\nPlease note that these are web search results and may not be fully accurate or up-to-date.")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
|
||||
// (server_tool_use + web_search_tool_result) without text summary or message termination.
|
||||
// These events trigger Claude Code's search indicator UI.
|
||||
// The caller is responsible for sending message_start before and message_delta/stop after.
|
||||
func GenerateSearchIndicatorEvents(
|
||||
query string,
|
||||
toolUseID string,
|
||||
searchResults *WebSearchResults,
|
||||
startIndex int,
|
||||
) []SseEvent {
|
||||
events := make([]SseEvent, 0, 4)
|
||||
|
||||
// 1. content_block_start (server_tool_use)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"id": toolUseID,
|
||||
"type": "server_tool_use",
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 2. content_block_delta (input_json_delta)
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": startIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 3. content_block_stop (server_tool_use)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex,
|
||||
},
|
||||
})
|
||||
|
||||
// 4. content_block_start (web_search_tool_result)
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if searchResults != nil {
|
||||
for _, r := range searchResults.Results {
|
||||
snippet := ""
|
||||
if r.Snippet != nil {
|
||||
snippet = *r.Snippet
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": r.Title,
|
||||
"url": r.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex + 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": searchContent,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// 5. content_block_stop (web_search_tool_result)
|
||||
events = append(events, SseEvent{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex + 1,
|
||||
},
|
||||
})
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// Stream Analysis & Manipulation
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset.
// It also suppresses duplicate message_start events (returns shouldForward=false).
// This is used to combine search indicator events (indices 0,1) with Kiro model response events.
//
// The data parameter is a single SSE "data:" line payload (JSON).
// Returns: adjusted data, shouldForward (false = skip this event).
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
	if len(data) == 0 {
		return data, true
	}

	// Payloads that are not valid JSON are forwarded untouched.
	var payload map[string]interface{}
	if err := json.Unmarshal(data, &payload); err != nil {
		return data, true
	}

	switch evType, _ := payload["type"].(string); evType {
	case "message_start":
		// The outer stream already emitted its own message_start; drop dupes.
		return data, false
	case "content_block_start", "content_block_delta", "content_block_stop":
		idx, ok := payload["index"].(float64)
		if !ok {
			break
		}
		payload["index"] = int(idx) + offset
		adjusted, err := json.Marshal(payload)
		if err != nil {
			// Marshal failure: fall back to the original bytes.
			return data, true
		}
		return adjusted, true
	}

	// Everything else (message_delta, message_stop, ping, ...) passes through.
	return data, true
}
|
||||
|
||||
// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs)
// and adjusts content block indices. Suppresses duplicate message_start events.
// Returns the adjusted chunk and whether it should be forwarded
// (false means the whole chunk reduced to whitespace and should be dropped).
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
	chunkStr := string(chunk)

	// Fast path: if no "data:" prefix, pass through
	if !strings.Contains(chunkStr, "data: ") {
		return chunk, true
	}

	var result strings.Builder
	hasContent := false

	lines := strings.Split(chunkStr, "\n")
	for i := 0; i < len(lines); i++ {
		line := lines[i]

		if strings.HasPrefix(line, "data: ") {
			dataPayload := strings.TrimPrefix(line, "data: ")
			dataPayload = strings.TrimSpace(dataPayload)

			// [DONE] is a terminator, not JSON — forward verbatim.
			if dataPayload == "[DONE]" {
				result.WriteString(line + "\n")
				hasContent = true
				continue
			}

			adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset)
			if !shouldForward {
				// Skip this event and its preceding "event:" line
				// Also skip the trailing empty line
				continue
			}

			result.WriteString("data: " + string(adjusted) + "\n")
			hasContent = true
		} else if strings.HasPrefix(line, "event: ") {
			// Check if the next data line will be suppressed
			// (only message_start payloads are ever suppressed by
			// AdjustStreamIndices, so testing for message_start suffices).
			if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
				dataPayload := strings.TrimPrefix(lines[i+1], "data: ")
				dataPayload = strings.TrimSpace(dataPayload)

				var event map[string]interface{}
				if err := json.Unmarshal([]byte(dataPayload), &event); err == nil {
					if eventType, ok := event["type"].(string); ok && eventType == "message_start" {
						// Skip both the event: and data: lines
						i++ // skip the data: line too
						continue
					}
				}
			}
			result.WriteString(line + "\n")
			hasContent = true
		} else {
			// Comment lines, blank separators, etc. are forwarded as-is;
			// only non-blank lines count as real content.
			result.WriteString(line + "\n")
			if strings.TrimSpace(line) != "" {
				hasContent = true
			}
		}
	}

	if !hasContent {
		return nil, false
	}

	return []byte(result.String()), true
}
|
||||
|
||||
// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response.
// It is produced by AnalyzeBufferedStream and consumed by the web-search loop
// to decide whether the model requested another search round.
type BufferedStreamResult struct {
	// StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use")
	StopReason string
	// WebSearchQuery is the extracted query if the model requested another web_search
	WebSearchQuery string
	// WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults)
	WebSearchToolUseId string
	// HasWebSearchToolUse indicates whether the model requested web_search
	HasWebSearchToolUse bool
	// WebSearchToolUseIndex is the content_block index of the web_search tool_use,
	// or -1 when no web_search tool_use was detected
	WebSearchToolUseIndex int
}
|
||||
|
||||
// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
|
||||
// This is used in the search loop to determine if the model wants another search round.
|
||||
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
|
||||
|
||||
// Track tool use state across chunks
|
||||
var currentToolName string
|
||||
var currentToolIndex int = -1
|
||||
var toolInputBuilder strings.Builder
|
||||
|
||||
for _, chunk := range chunks {
|
||||
chunkStr := string(chunk)
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
if dataPayload == "[DONE]" || dataPayload == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "message_delta":
|
||||
// Extract stop_reason from message_delta
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
|
||||
result.StopReason = sr
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
// Detect tool_use content blocks
|
||||
if cb, ok := event["content_block"].(map[string]interface{}); ok {
|
||||
if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
|
||||
if name, ok := cb["name"].(string); ok {
|
||||
currentToolName = strings.ToLower(name)
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
currentToolIndex = int(idx)
|
||||
}
|
||||
// Capture tool use ID for toolResults handshake
|
||||
if id, ok := cb["id"].(string); ok {
|
||||
result.WebSearchToolUseId = id
|
||||
}
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
// Accumulate tool input JSON
|
||||
if currentToolName != "" {
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
|
||||
if partial, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuilder.WriteString(partial)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// Finalize tool use detection
|
||||
if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
|
||||
result.HasWebSearchToolUse = true
|
||||
result.WebSearchToolUseIndex = currentToolIndex
|
||||
// Extract query from accumulated input JSON
|
||||
inputJSON := toolInputBuilder.String()
|
||||
var input map[string]string
|
||||
if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
|
||||
if q, ok := input["query"]; ok {
|
||||
result.WebSearchQuery = q
|
||||
}
|
||||
}
|
||||
log.Debugf("kiro/websearch: detected web_search tool_use, query: %s", result.WebSearchQuery)
|
||||
}
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use
// content blocks. This prevents the client from seeing "Tool use" prompts for web_search
// when the proxy is handling the search loop internally.
// Also suppresses message_start and message_delta/message_stop events since those
// are managed by the outer handleWebSearchStream.
//
// Parameters:
//   - chunks: raw SSE chunks, each possibly holding several "event:/data:" line pairs
//   - wsToolIndex: content_block index of the web_search tool_use to strip (-1 = none)
//   - indexOffset: offset added to the remaining content_block indices (0 = unchanged)
//
// Chunks that end up with no forwardable content are dropped entirely.
//
// NOTE: the original body was missing its closing brace (the function ran
// straight into the next declaration); this version restores it.
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
	var filtered [][]byte

	for _, chunk := range chunks {
		chunkStr := string(chunk)
		lines := strings.Split(chunkStr, "\n")

		var resultBuilder strings.Builder
		hasContent := false

		for i := 0; i < len(lines); i++ {
			line := lines[i]

			if strings.HasPrefix(line, "data: ") {
				dataPayload := strings.TrimPrefix(line, "data: ")
				dataPayload = strings.TrimSpace(dataPayload)

				if dataPayload == "[DONE]" {
					// Skip [DONE] — the outer loop manages stream termination
					continue
				}

				var event map[string]interface{}
				if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
					// Unparseable payloads are forwarded untouched.
					resultBuilder.WriteString(line + "\n")
					hasContent = true
					continue
				}

				eventType, _ := event["type"].(string)

				// Skip message_start (outer loop sends its own)
				if eventType == "message_start" {
					continue
				}

				// Skip message_delta and message_stop (outer loop manages these)
				if eventType == "message_delta" || eventType == "message_stop" {
					continue
				}

				// Check if this event belongs to the web_search tool_use block
				if wsToolIndex >= 0 {
					if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
						// Skip events for the web_search tool_use block
						continue
					}
				}

				// Apply index offset for remaining events
				if indexOffset > 0 {
					switch eventType {
					case "content_block_start", "content_block_delta", "content_block_stop":
						if idx, ok := event["index"].(float64); ok {
							event["index"] = int(idx) + indexOffset
							adjusted, err := json.Marshal(event)
							if err == nil {
								resultBuilder.WriteString("data: " + string(adjusted) + "\n")
								hasContent = true
								continue
							}
						}
					}
				}

				resultBuilder.WriteString(line + "\n")
				hasContent = true
			} else if strings.HasPrefix(line, "event: ") {
				// Check if the next data line will be suppressed; if so,
				// drop this "event:" line together with its data line.
				if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
					nextData := strings.TrimPrefix(lines[i+1], "data: ")
					nextData = strings.TrimSpace(nextData)

					var nextEvent map[string]interface{}
					if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil {
						nextType, _ := nextEvent["type"].(string)
						if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
							i++ // skip the data line
							continue
						}
						if wsToolIndex >= 0 {
							if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex {
								i++ // skip the data line
								continue
							}
						}
					}
				}
				resultBuilder.WriteString(line + "\n")
				hasContent = true
			} else {
				// Blank separators and other lines pass through; only
				// non-blank lines count as real content.
				resultBuilder.WriteString(line + "\n")
				if strings.TrimSpace(line) != "" {
					hasContent = true
				}
			}
		}

		if hasContent {
			filtered = append(filtered, []byte(resultBuilder.String()))
		}
	}

	return filtered
}
|
||||
// ParseSearchResults extracts WebSearchResults from MCP response
|
||||
func ParseSearchResults(response *McpResponse) *WebSearchResults {
|
||||
if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
content := response.Result.Content[0]
|
||||
if content.ContentType != "text" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results WebSearchResults
|
||||
if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
|
||||
log.Warnf("kiro/websearch: failed to parse search results: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &results
|
||||
}
|
||||
|
||||
@@ -1,270 +0,0 @@
|
||||
// Package claude provides web search handler for Kiro translator.
|
||||
// This file implements the MCP API call and response handling.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Cached web_search tool description fetched from MCP tools/list.
// Uses atomic.Pointer[sync.Once] for lock-free reads with retry-on-failure:
// - sync.Once prevents race conditions and deduplicates concurrent calls
// - On failure, a fresh sync.Once is swapped in to allow retry on next call
// - On success, sync.Once stays "done" forever — zero overhead for subsequent calls
var (
	cachedToolDescription atomic.Value // stores string
	toolDescOnce          atomic.Pointer[sync.Once]
	// fallbackFpOnce/fallbackFp lazily build one shared fingerprint for
	// callers that have no per-token fingerprint (see NewWebSearchHandler).
	fallbackFpOnce sync.Once
	fallbackFp     *kiroauth.Fingerprint
)

// init seeds toolDescOnce so the first FetchToolDescription call always
// finds a non-nil sync.Once to run.
func init() {
	toolDescOnce.Store(&sync.Once{})
}
|
||||
|
||||
// FetchToolDescription calls MCP tools/list to get the web_search tool description
|
||||
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
|
||||
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
|
||||
// The httpClient parameter allows reusing a shared pooled HTTP client.
|
||||
func FetchToolDescription(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) {
|
||||
toolDescOnce.Load().Do(func() {
|
||||
handler := NewWebSearchHandler(mcpEndpoint, authToken, httpClient, fp, authAttrs)
|
||||
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
|
||||
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
|
||||
|
||||
req, err := http.NewRequest("POST", mcpEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse same headers as CallMcpAPI
|
||||
handler.setMcpHeaders(req)
|
||||
|
||||
resp, err := handler.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
|
||||
|
||||
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
|
||||
var result struct {
|
||||
Result *struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
log.Warnf("kiro/websearch: failed to parse tools/list response")
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
|
||||
for _, tool := range result.Result.Tools {
|
||||
if tool.Name == "web_search" && tool.Description != "" {
|
||||
cachedToolDescription.Store(tool.Description)
|
||||
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
|
||||
return // success — sync.Once stays "done", no more fetches
|
||||
}
|
||||
}
|
||||
|
||||
// web_search tool not found in response
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
})
|
||||
}
|
||||
|
||||
// GetWebSearchDescription returns the cached web_search tool description,
|
||||
// or empty string if not yet fetched. Lock-free via atomic.Value.
|
||||
func GetWebSearchDescription() string {
|
||||
if v := cachedToolDescription.Load(); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// WebSearchHandler handles web search requests via Kiro MCP API.
// Construct with NewWebSearchHandler, which fills in defaults for the
// HTTP client and fingerprint when they are not supplied.
type WebSearchHandler struct {
	// McpEndpoint is the MCP API URL that requests are POSTed to.
	McpEndpoint string
	// HTTPClient performs the MCP HTTP calls.
	HTTPClient *http.Client
	// AuthToken is sent as a Bearer token on every request.
	AuthToken string
	Fingerprint *kiroauth.Fingerprint // optional, for dynamic headers
	AuthAttrs   map[string]string     // optional, for custom headers from auth.Attributes
}
|
||||
|
||||
// NewWebSearchHandler creates a new WebSearchHandler.
|
||||
// If httpClient is nil, a default client with 30s timeout is used.
|
||||
// If fingerprint is nil, a random one-off fingerprint is generated.
|
||||
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
|
||||
func NewWebSearchHandler(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) *WebSearchHandler {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
if fp == nil {
|
||||
// Use a shared fallback fingerprint for callers without token context
|
||||
fallbackFpOnce.Do(func() {
|
||||
mgr := kiroauth.NewFingerprintManager()
|
||||
fallbackFp = mgr.GetFingerprint("mcp-fallback")
|
||||
})
|
||||
fp = fallbackFp
|
||||
}
|
||||
return &WebSearchHandler{
|
||||
McpEndpoint: mcpEndpoint,
|
||||
HTTPClient: httpClient,
|
||||
AuthToken: authToken,
|
||||
Fingerprint: fp,
|
||||
AuthAttrs: authAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
// setMcpHeaders sets standard MCP API headers on the request,
|
||||
// aligned with the GAR request pattern in kiro_executor.go.
|
||||
func (h *WebSearchHandler) setMcpHeaders(req *http.Request) {
|
||||
fp := h.Fingerprint
|
||||
|
||||
// 1. Content-Type & Accept (aligned with GAR)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
|
||||
// 2. Kiro-specific headers (aligned with GAR)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
|
||||
// 3. Dynamic fingerprint headers
|
||||
req.Header.Set("User-Agent", fp.BuildUserAgent())
|
||||
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
|
||||
|
||||
// 4. AWS SDK identifiers (casing aligned with GAR)
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// 5. Authentication
|
||||
req.Header.Set("Authorization", "Bearer "+h.AuthToken)
|
||||
|
||||
// 6. Custom headers from auth attributes
|
||||
util.ApplyCustomHeadersFromAttrs(req, h.AuthAttrs)
|
||||
}
|
||||
|
||||
// mcpMaxRetries is the maximum number of retries for MCP API calls.
const mcpMaxRetries = 2

// CallMcpAPI calls the Kiro MCP API with the given request.
// Includes retry logic with exponential backoff for retryable errors,
// aligned with the GAR request retry pattern.
//
// Retry policy:
//   - network errors, body-read errors, and HTTP 502/503/504 are retried up
//     to mcpMaxRetries times with backoff of 2^attempt seconds (capped at 10s);
//   - any other non-200 status, an undecodable body, or a JSON-RPC error
//     object fails immediately without retrying.
//
// On exhausted retries the last retryable error is returned.
func (h *WebSearchHandler) CallMcpAPI(request *McpRequest) (*McpResponse, error) {
	requestBody, err := json.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
	}
	log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.McpEndpoint, len(requestBody))

	var lastErr error
	for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
		if attempt > 0 {
			// Exponential backoff before each retry attempt.
			backoff := time.Duration(1<<attempt) * time.Second
			if backoff > 10*time.Second {
				backoff = 10 * time.Second
			}
			log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
			time.Sleep(backoff)
		}

		// A fresh request (and body reader) is needed on every attempt.
		req, err := http.NewRequest("POST", h.McpEndpoint, bytes.NewReader(requestBody))
		if err != nil {
			return nil, fmt.Errorf("failed to create HTTP request: %w", err)
		}

		h.setMcpHeaders(req)

		resp, err := h.HTTPClient.Do(req)
		if err != nil {
			lastErr = fmt.Errorf("MCP API request failed: %w", err)
			continue // network error → retry
		}

		// Read the whole body first, then close eagerly (not deferred) so
		// retries inside this loop never accumulate open bodies.
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			lastErr = fmt.Errorf("failed to read MCP response: %w", err)
			continue // read error → retry
		}
		log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))

		// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
		if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
			lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
			continue
		}

		if resp.StatusCode != http.StatusOK {
			return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
		}

		var mcpResponse McpResponse
		if err := json.Unmarshal(body, &mcpResponse); err != nil {
			return nil, fmt.Errorf("failed to parse MCP response: %w", err)
		}

		// A JSON-RPC error object is a definitive failure — no retry.
		if mcpResponse.Error != nil {
			code := -1
			if mcpResponse.Error.Code != nil {
				code = *mcpResponse.Error.Code
			}
			msg := "Unknown error"
			if mcpResponse.Error.Message != nil {
				msg = *mcpResponse.Error.Message
			}
			return nil, fmt.Errorf("MCP error %d: %s", code, msg)
		}

		return &mcpResponse, nil
	}

	return nil, lastErr
}
|
||||
|
||||
// ParseSearchResults extracts WebSearchResults from MCP response
|
||||
func ParseSearchResults(response *McpResponse) *WebSearchResults {
|
||||
if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
content := response.Result.Content[0]
|
||||
if content.ContentType != "text" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results WebSearchResults
|
||||
if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
|
||||
log.Warnf("kiro/websearch: failed to parse search results: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &results
|
||||
}
|
||||
Reference in New Issue
Block a user