mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-08 06:43:41 +00:00
feat: optimize connection pooling and improve Kiro executor reliability
## 中文说明 ### 连接池优化 - 为 AMP 代理、SOCKS5 代理和 HTTP 代理配置优化的连接池参数 - MaxIdleConnsPerHost 从默认的 2 增加到 20,支持更多并发用户 - MaxConnsPerHost 设为 0(无限制),避免连接瓶颈 - 添加 IdleConnTimeout (90s) 和其他超时配置 ### Kiro 执行器增强 - 添加 Event Stream 消息解析的边界保护,防止越界访问 - 实现实时使用量估算(每 5000 字符或 15 秒发送 ping 事件) - 正确从上游事件中提取并传递 stop_reason - 改进输入 token 计算,优先使用 Claude 格式解析 - 添加 max_tokens 截断警告日志 ### Token 计算改进 - 添加 tokenizer 缓存(sync.Map)避免重复创建 - 为 Claude/Kiro/AmazonQ 模型添加 1.1 调整因子 - 新增 countClaudeChatTokens 函数支持 Claude API 格式 - 支持图像 token 估算(基于尺寸计算) ### 认证刷新优化 - RefreshLead 从 30 分钟改为 5 分钟,与 Antigravity 保持一致 - 修复 NextRefreshAfter 设置,防止频繁刷新检查 - refreshFailureBackoff 从 5 分钟改为 1 分钟,加快失败恢复 --- ## English Description ### Connection Pool Optimization - Configure optimized connection pool parameters for AMP proxy, SOCKS5 proxy, and HTTP proxy - Increase MaxIdleConnsPerHost from default 2 to 20 to support more concurrent users - Set MaxConnsPerHost to 0 (unlimited) to avoid connection bottlenecks - Add IdleConnTimeout (90s) and other timeout configurations ### Kiro Executor Enhancements - Add boundary protection for Event Stream message parsing to prevent out-of-bounds access - Implement real-time usage estimation (send ping events every 5000 chars or 15 seconds) - Correctly extract and pass stop_reason from upstream events - Improve input token calculation, prioritize Claude format parsing - Add max_tokens truncation warning logs ### Token Calculation Improvements - Add tokenizer cache (sync.Map) to avoid repeated creation - Add 1.1 adjustment factor for Claude/Kiro/AmazonQ models - Add countClaudeChatTokens function to support Claude API format - Support image token estimation (calculated based on dimensions) ### Authentication Refresh Optimization - Change RefreshLead from 30 minutes to 5 minutes, consistent with Antigravity - Fix NextRefreshAfter setting to prevent frequent refresh checks - Change refreshFailureBackoff from 5 minutes to 1 minute for faster failure recovery
This commit is contained in:
@@ -7,11 +7,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -36,6 +38,22 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(parsed)
|
||||
|
||||
// Configure custom Transport with optimized connection pooling for high concurrency
|
||||
proxy.Transport = &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
|
||||
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: 60 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
originalDirector := proxy.Director
|
||||
|
||||
// Modify outgoing requests to inject API key and fix routing
|
||||
@@ -64,7 +82,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Modify incoming responses to handle gzip without Content-Encoding
|
||||
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Only process successful responses
|
||||
// Log upstream error responses for diagnostics (502, 503, etc.)
|
||||
// These are NOT proxy connection errors - the upstream responded with an error status
|
||||
if resp.StatusCode >= 500 {
|
||||
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||
} else if resp.StatusCode >= 400 {
|
||||
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||
}
|
||||
|
||||
// Only process successful responses for gzip decompression
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil
|
||||
}
|
||||
@@ -148,15 +174,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error handler for proxy failures
|
||||
// Error handler for proxy failures with detailed error classification for diagnostics
|
||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
// Check if this is a client-side cancellation (normal behavior)
|
||||
// Classify the error type for better diagnostics
|
||||
var errType string
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
errType = "timeout"
|
||||
} else if errors.Is(err, context.Canceled) {
|
||||
errType = "canceled"
|
||||
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
errType = "dial_timeout"
|
||||
} else if _, ok := err.(net.Error); ok {
|
||||
errType = "network_error"
|
||||
} else {
|
||||
errType = "connection_error"
|
||||
}
|
||||
|
||||
// Don't log as error for context canceled - it's usually client closing connection
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("amp upstream proxy: client canceled request for %s %s", req.Method, req.URL.Path)
|
||||
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
||||
} else {
|
||||
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
|
||||
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
||||
|
||||
@@ -36,6 +36,16 @@ const (
|
||||
kiroAcceptStream = "*/*"
|
||||
kiroMaxMessageSize = 10 * 1024 * 1024 // 10MB max message size for event stream
|
||||
kiroMaxToolDescLen = 10237 // Kiro API limit is 10240 bytes, leave room for "..."
|
||||
|
||||
// Event Stream frame size constants for boundary protection
|
||||
// AWS Event Stream binary format: prelude (12 bytes) + headers + payload + message_crc (4 bytes)
|
||||
// Prelude consists of: total_length (4) + headers_length (4) + prelude_crc (4)
|
||||
minEventStreamFrameSize = 16 // Minimum: 4(total_len) + 4(headers_len) + 4(prelude_crc) + 4(message_crc)
|
||||
maxEventStreamMsgSize = 10 << 20 // Maximum message length: 10MB
|
||||
|
||||
// Event Stream error type constants
|
||||
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
|
||||
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
|
||||
// kiroUserAgent matches amq2api format for User-Agent header
|
||||
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
|
||||
// kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api
|
||||
@@ -102,6 +112,13 @@ You MUST follow these rules for ALL file operations. Violation causes server tim
|
||||
REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.`
|
||||
)
|
||||
|
||||
// Real-time usage estimation configuration
|
||||
// These control how often usage updates are sent during streaming
|
||||
var (
|
||||
usageUpdateCharThreshold = 5000 // Send usage update every 5000 characters
|
||||
usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first
|
||||
)
|
||||
|
||||
// kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values.
|
||||
// This solves the "triple mismatch" problem where different endpoints require matching
|
||||
// Origin and X-Amz-Target header values.
|
||||
@@ -495,7 +512,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
}()
|
||||
|
||||
content, toolUses, usageInfo, err := e.parseEventStream(httpResp.Body)
|
||||
content, toolUses, usageInfo, stopReason, err := e.parseEventStream(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
@@ -503,14 +520,14 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
|
||||
// Fallback for usage if missing from upstream
|
||||
if usageInfo.TotalTokens == 0 {
|
||||
if enc, encErr := tokenizerForModel(req.Model); encErr == nil {
|
||||
if enc, encErr := getTokenizer(req.Model); encErr == nil {
|
||||
if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil {
|
||||
usageInfo.InputTokens = inp
|
||||
}
|
||||
}
|
||||
if len(content) > 0 {
|
||||
// Use tiktoken for more accurate output token calculation
|
||||
if enc, encErr := tokenizerForModel(req.Model); encErr == nil {
|
||||
if enc, encErr := getTokenizer(req.Model); encErr == nil {
|
||||
if tokenCount, countErr := enc.Count(content); countErr == nil {
|
||||
usageInfo.OutputTokens = int64(tokenCount)
|
||||
}
|
||||
@@ -530,7 +547,8 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
reporter.publish(ctx, usageInfo)
|
||||
|
||||
// Build response in Claude format for Kiro translator
|
||||
kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo)
|
||||
// stopReason is extracted from upstream response by parseEventStream
|
||||
kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
@@ -970,11 +988,40 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
return "claude-sonnet-4.5"
|
||||
}
|
||||
|
||||
// EventStreamError represents an Event Stream processing error
|
||||
type EventStreamError struct {
|
||||
Type string // "fatal", "malformed"
|
||||
Message string
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *EventStreamError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("event stream %s: %s: %v", e.Type, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("event stream %s: %s", e.Type, e.Message)
|
||||
}
|
||||
|
||||
// eventStreamMessage is one decoded AWS Event Stream frame: the event type
// taken from the ":event-type" header plus the raw JSON payload bytes.
type eventStreamMessage struct {
	EventType string // Event type from headers (e.g. "assistantResponseEvent")
	Payload   []byte // Raw JSON payload of the message (nil for empty frames)
}
|
||||
|
||||
// Kiro API request structs - field order determines JSON key order
|
||||
|
||||
type kiroPayload struct {
|
||||
ConversationState kiroConversationState `json:"conversationState"`
|
||||
ProfileArn string `json:"profileArn,omitempty"`
|
||||
InferenceConfig *kiroInferenceConfig `json:"inferenceConfig,omitempty"`
|
||||
}
|
||||
|
||||
// kiroInferenceConfig contains inference parameters for the Kiro API.
|
||||
// NOTE: This is an experimental addition - Kiro/Amazon Q API may not support these parameters.
|
||||
// If the API ignores or rejects these fields, response length is controlled internally by the model.
|
||||
type kiroInferenceConfig struct {
|
||||
MaxTokens int `json:"maxTokens,omitempty"` // Maximum output tokens (may be ignored by API)
|
||||
Temperature float64 `json:"temperature,omitempty"` // Sampling temperature (may be ignored by API)
|
||||
}
|
||||
|
||||
type kiroConversationState struct {
|
||||
@@ -1058,7 +1105,25 @@ type kiroToolUse struct {
|
||||
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
|
||||
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
|
||||
// Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint.
|
||||
//
|
||||
// max_tokens support: Kiro/Amazon Q API may not officially support max_tokens parameter.
|
||||
// We attempt to pass it via inferenceConfig.maxTokens, but the API may ignore it.
|
||||
// Response truncation can be detected via stop_reason == "max_tokens" in the response.
|
||||
func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte {
|
||||
// Extract max_tokens for potential use in inferenceConfig
|
||||
var maxTokens int64
|
||||
if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() {
|
||||
maxTokens = mt.Int()
|
||||
}
|
||||
|
||||
// Extract temperature if specified
|
||||
var temperature float64
|
||||
var hasTemperature bool
|
||||
if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() {
|
||||
temperature = temp.Float()
|
||||
hasTemperature = true
|
||||
}
|
||||
|
||||
// Normalize origin value for Kiro API compatibility
|
||||
// Kiro API only accepts "CLI" or "AI_EDITOR" as valid origin values
|
||||
switch origin {
|
||||
@@ -1325,6 +1390,18 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn,
|
||||
}}
|
||||
}
|
||||
|
||||
// Build inferenceConfig if we have any inference parameters
|
||||
var inferenceConfig *kiroInferenceConfig
|
||||
if maxTokens > 0 || hasTemperature {
|
||||
inferenceConfig = &kiroInferenceConfig{}
|
||||
if maxTokens > 0 {
|
||||
inferenceConfig.MaxTokens = int(maxTokens)
|
||||
}
|
||||
if hasTemperature {
|
||||
inferenceConfig.Temperature = temperature
|
||||
}
|
||||
}
|
||||
|
||||
// Build payload with correct field order (matches struct definition)
|
||||
payload := kiroPayload{
|
||||
ConversationState: kiroConversationState{
|
||||
@@ -1333,7 +1410,8 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn,
|
||||
CurrentMessage: currentMessage,
|
||||
History: history, // Now always included (non-nil slice)
|
||||
},
|
||||
ProfileArn: profileArn,
|
||||
ProfileArn: profileArn,
|
||||
InferenceConfig: inferenceConfig,
|
||||
}
|
||||
|
||||
result, err := json.Marshal(payload)
|
||||
@@ -1493,12 +1571,14 @@ func (e *KiroExecutor) buildAssistantMessageStruct(msg gjson.Result) kiroAssista
|
||||
// NOTE: Tool calling is now supported via userInputMessageContext.tools and toolResults
|
||||
|
||||
// parseEventStream parses AWS Event Stream binary format.
|
||||
// Extracts text content and tool uses from the response.
|
||||
// Extracts text content, tool uses, and stop_reason from the response.
|
||||
// Supports embedded [Called ...] tool calls and input buffering for toolUseEvent.
|
||||
func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, error) {
|
||||
// Returns: content, toolUses, usageInfo, stopReason, error
|
||||
func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, string, error) {
|
||||
var content strings.Builder
|
||||
var toolUses []kiroToolUse
|
||||
var usageInfo usage.Detail
|
||||
var stopReason string // Extracted from upstream response
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
// Tool use state tracking for input buffering and deduplication
|
||||
@@ -1506,59 +1586,28 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
|
||||
var currentToolUse *toolUseState
|
||||
|
||||
for {
|
||||
prelude := make([]byte, 8)
|
||||
_, err := io.ReadFull(reader, prelude)
|
||||
if err == io.EOF {
|
||||
msg, eventErr := e.readEventStreamMessage(reader)
|
||||
if eventErr != nil {
|
||||
log.Errorf("kiro: parseEventStream error: %v", eventErr)
|
||||
return content.String(), toolUses, usageInfo, stopReason, eventErr
|
||||
}
|
||||
if msg == nil {
|
||||
// Normal end of stream (EOF)
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read prelude: %w", err)
|
||||
}
|
||||
|
||||
totalLen := binary.BigEndian.Uint32(prelude[0:4])
|
||||
if totalLen < 8 {
|
||||
return content.String(), toolUses, usageInfo, fmt.Errorf("invalid message length: %d", totalLen)
|
||||
}
|
||||
if totalLen > kiroMaxMessageSize {
|
||||
return content.String(), toolUses, usageInfo, fmt.Errorf("message too large: %d bytes", totalLen)
|
||||
}
|
||||
headersLen := binary.BigEndian.Uint32(prelude[4:8])
|
||||
|
||||
remaining := make([]byte, totalLen-8)
|
||||
_, err = io.ReadFull(reader, remaining)
|
||||
if err != nil {
|
||||
return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read message: %w", err)
|
||||
}
|
||||
|
||||
// Validate headersLen to prevent slice out of bounds
|
||||
if headersLen+4 > uint32(len(remaining)) {
|
||||
log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining))
|
||||
eventType := msg.EventType
|
||||
payload := msg.Payload
|
||||
if len(payload) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract event type from headers
|
||||
eventType := e.extractEventType(remaining[:headersLen+4])
|
||||
|
||||
payloadStart := 4 + headersLen
|
||||
payloadEnd := uint32(len(remaining)) - 4
|
||||
if payloadStart >= payloadEnd {
|
||||
continue
|
||||
}
|
||||
|
||||
payload := remaining[payloadStart:payloadEnd]
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &event); err != nil {
|
||||
log.Debugf("kiro: skipping malformed event: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// DIAGNOSTIC: Log all received event types for debugging
|
||||
log.Debugf("kiro: parseEventStream received event type: %s", eventType)
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("kiro: parseEventStream event payload: %s", string(payload))
|
||||
}
|
||||
|
||||
// Check for error/exception events in the payload (Kiro API may return errors with HTTP 200)
|
||||
// These can appear as top-level fields or nested within the event
|
||||
if errType, hasErrType := event["_type"].(string); hasErrType {
|
||||
@@ -1568,7 +1617,7 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
|
||||
errMsg = msg
|
||||
}
|
||||
log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg)
|
||||
return "", nil, usageInfo, fmt.Errorf("kiro API error: %s - %s", errType, errMsg)
|
||||
return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s - %s", errType, errMsg)
|
||||
}
|
||||
if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") {
|
||||
// Generic error event
|
||||
@@ -1581,7 +1630,18 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
|
||||
}
|
||||
}
|
||||
log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg)
|
||||
return "", nil, usageInfo, fmt.Errorf("kiro API error: %s", errMsg)
|
||||
return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s", errMsg)
|
||||
}
|
||||
|
||||
// Extract stop_reason from various event formats
|
||||
// Kiro/Amazon Q API may include stop_reason in different locations
|
||||
if sr := getString(event, "stop_reason"); sr != "" {
|
||||
stopReason = sr
|
||||
log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason)
|
||||
}
|
||||
if sr := getString(event, "stopReason"); sr != "" {
|
||||
stopReason = sr
|
||||
log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason)
|
||||
}
|
||||
|
||||
// Handle different event types
|
||||
@@ -1596,6 +1656,15 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
|
||||
if contentText, ok := assistantResp["content"].(string); ok {
|
||||
content.WriteString(contentText)
|
||||
}
|
||||
// Extract stop_reason from assistantResponseEvent
|
||||
if sr := getString(assistantResp, "stop_reason"); sr != "" {
|
||||
stopReason = sr
|
||||
log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason)
|
||||
}
|
||||
if sr := getString(assistantResp, "stopReason"); sr != "" {
|
||||
stopReason = sr
|
||||
log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason)
|
||||
}
|
||||
// Extract tool uses from response
|
||||
if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok {
|
||||
for _, tuRaw := range toolUsesRaw {
|
||||
@@ -1661,6 +1730,17 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
|
||||
if outputTokens, ok := event["outputTokens"].(float64); ok {
|
||||
usageInfo.OutputTokens = int64(outputTokens)
|
||||
}
|
||||
|
||||
case "messageStopEvent", "message_stop":
|
||||
// Handle message stop events which may contain stop_reason
|
||||
if sr := getString(event, "stop_reason"); sr != "" {
|
||||
stopReason = sr
|
||||
log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason)
|
||||
}
|
||||
if sr := getString(event, "stopReason"); sr != "" {
|
||||
stopReason = sr
|
||||
log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason)
|
||||
}
|
||||
}
|
||||
|
||||
// Also check nested supplementaryWebLinksEvent
|
||||
@@ -1682,10 +1762,166 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
|
||||
// Deduplicate all tool uses
|
||||
toolUses = deduplicateToolUses(toolUses)
|
||||
|
||||
return cleanedContent, toolUses, usageInfo, nil
|
||||
// Apply fallback logic for stop_reason if not provided by upstream
|
||||
// Priority: upstream stopReason > tool_use detection > end_turn default
|
||||
if stopReason == "" {
|
||||
if len(toolUses) > 0 {
|
||||
stopReason = "tool_use"
|
||||
log.Debugf("kiro: parseEventStream using fallback stop_reason: tool_use (detected %d tool uses)", len(toolUses))
|
||||
} else {
|
||||
stopReason = "end_turn"
|
||||
log.Debugf("kiro: parseEventStream using fallback stop_reason: end_turn")
|
||||
}
|
||||
}
|
||||
|
||||
// Log warning if response was truncated due to max_tokens
|
||||
if stopReason == "max_tokens" {
|
||||
log.Warnf("kiro: response truncated due to max_tokens limit")
|
||||
}
|
||||
|
||||
return cleanedContent, toolUses, usageInfo, stopReason, nil
|
||||
}
|
||||
|
||||
// readEventStreamMessage reads and validates a single AWS Event Stream message.
// Returns the parsed message or a structured error for different failure modes.
// This function implements boundary protection and detailed error classification.
//
// AWS Event Stream binary format:
//   - Prelude (12 bytes): total_length (4) + headers_length (4) + prelude_crc (4)
//   - Headers (variable): header entries
//   - Payload (variable): JSON data
//   - Message CRC (4 bytes): CRC32C of entire message (not validated, just skipped)
//
// A nil message together with a nil error signals a clean end of stream (EOF
// on the very first prelude byte); callers must check msg == nil before use.
func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, *EventStreamError) {
	// Read prelude (first 12 bytes: total_len + headers_len + prelude_crc)
	prelude := make([]byte, 12)
	_, err := io.ReadFull(reader, prelude)
	if err == io.EOF {
		return nil, nil // Normal end of stream
	}
	if err != nil {
		// Includes io.ErrUnexpectedEOF when the prelude is truncated mid-read.
		return nil, &EventStreamError{
			Type:    ErrStreamFatal,
			Message: "failed to read prelude",
			Cause:   err,
		}
	}

	totalLength := binary.BigEndian.Uint32(prelude[0:4])
	headersLength := binary.BigEndian.Uint32(prelude[4:8])
	// Note: prelude[8:12] is prelude_crc - we read it but don't validate (no CRC check per requirements)

	// Boundary check: minimum frame size
	if totalLength < minEventStreamFrameSize {
		return nil, &EventStreamError{
			Type: ErrStreamMalformed,
			Message: fmt.Sprintf("invalid message length: %d (minimum is %d)", totalLength, minEventStreamFrameSize),
		}
	}

	// Boundary check: maximum message size
	if totalLength > maxEventStreamMsgSize {
		return nil, &EventStreamError{
			Type: ErrStreamMalformed,
			Message: fmt.Sprintf("message too large: %d bytes (maximum is %d)", totalLength, maxEventStreamMsgSize),
		}
	}

	// Boundary check: headers length within message bounds
	// Message structure: prelude(12) + headers(headersLength) + payload + message_crc(4)
	// So: headersLength must be <= totalLength - 16 (12 for prelude + 4 for message_crc)
	// The earlier minimum-size check guarantees totalLength >= 16, so the
	// unsigned subtraction below cannot underflow.
	if headersLength > totalLength-16 {
		return nil, &EventStreamError{
			Type: ErrStreamMalformed,
			Message: fmt.Sprintf("headers length %d exceeds message bounds (total: %d)", headersLength, totalLength),
		}
	}

	// Read the rest of the message (total - 12 bytes already read)
	remaining := make([]byte, totalLength-12)
	_, err = io.ReadFull(reader, remaining)
	if err != nil {
		return nil, &EventStreamError{
			Type:    ErrStreamFatal,
			Message: "failed to read message body",
			Cause:   err,
		}
	}

	// Extract event type from headers
	// Headers start at beginning of 'remaining', length is headersLength
	var eventType string
	if headersLength > 0 && headersLength <= uint32(len(remaining)) {
		eventType = e.extractEventTypeFromBytes(remaining[:headersLength])
	}

	// Calculate payload boundaries
	// Payload starts after headers, ends before message_crc (last 4 bytes)
	payloadStart := headersLength
	payloadEnd := uint32(len(remaining)) - 4 // Skip message_crc at end

	// Validate payload boundaries
	if payloadStart >= payloadEnd {
		// No payload, return empty message
		return &eventStreamMessage{
			EventType: eventType,
			Payload:   nil,
		}, nil
	}

	payload := remaining[payloadStart:payloadEnd]

	return &eventStreamMessage{
		EventType: eventType,
		Payload:   payload,
	}, nil
}
|
||||
|
||||
// extractEventTypeFromBytes extracts the event type from raw header bytes (without prelude CRC prefix)
|
||||
func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string {
|
||||
offset := 0
|
||||
for offset < len(headers) {
|
||||
if offset >= len(headers) {
|
||||
break
|
||||
}
|
||||
nameLen := int(headers[offset])
|
||||
offset++
|
||||
if offset+nameLen > len(headers) {
|
||||
break
|
||||
}
|
||||
name := string(headers[offset : offset+nameLen])
|
||||
offset += nameLen
|
||||
|
||||
if offset >= len(headers) {
|
||||
break
|
||||
}
|
||||
valueType := headers[offset]
|
||||
offset++
|
||||
|
||||
if valueType == 7 { // String type
|
||||
if offset+2 > len(headers) {
|
||||
break
|
||||
}
|
||||
valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2]))
|
||||
offset += 2
|
||||
if offset+valueLen > len(headers) {
|
||||
break
|
||||
}
|
||||
value := string(headers[offset : offset+valueLen])
|
||||
offset += valueLen
|
||||
|
||||
if name == ":event-type" {
|
||||
return value
|
||||
}
|
||||
} else {
|
||||
// Skip other types
|
||||
break
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractEventType extracts the event type from AWS Event Stream headers
|
||||
// Note: This is the legacy version that expects headerBytes to include prelude CRC prefix
|
||||
func (e *KiroExecutor) extractEventType(headerBytes []byte) string {
|
||||
// Skip prelude CRC (4 bytes)
|
||||
if len(headerBytes) < 4 {
|
||||
@@ -1746,7 +1982,8 @@ func getString(m map[string]interface{}, key string) string {
|
||||
// buildClaudeResponse constructs a Claude-compatible response.
|
||||
// Supports tool_use blocks when tools are present in the response.
|
||||
// Supports thinking blocks - parses <thinking> tags and converts to Claude thinking content blocks.
|
||||
func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail) []byte {
|
||||
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||
func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||
var contentBlocks []map[string]interface{}
|
||||
|
||||
// Extract thinking blocks and text from content
|
||||
@@ -1782,10 +2019,18 @@ func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUs
|
||||
})
|
||||
}
|
||||
|
||||
// Determine stop reason
|
||||
stopReason := "end_turn"
|
||||
if len(toolUses) > 0 {
|
||||
stopReason = "tool_use"
|
||||
// Use upstream stopReason; apply fallback logic if not provided
|
||||
if stopReason == "" {
|
||||
stopReason = "end_turn"
|
||||
if len(toolUses) > 0 {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason)
|
||||
}
|
||||
|
||||
// Log warning if response was truncated due to max_tokens
|
||||
if stopReason == "max_tokens" {
|
||||
log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)")
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
@@ -1906,10 +2151,12 @@ func (e *KiroExecutor) extractThinkingFromContent(content string) []map[string]i
|
||||
// Supports tool calling - emits tool_use content blocks when tools are used.
|
||||
// Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent.
|
||||
// Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API).
|
||||
// Extracts stop_reason from upstream events when available.
|
||||
func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) {
|
||||
reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers
|
||||
var totalUsage usage.Detail
|
||||
var hasToolUses bool // Track if any tool uses were emitted
|
||||
var hasToolUses bool // Track if any tool uses were emitted
|
||||
var upstreamStopReason string // Track stop_reason from upstream events
|
||||
|
||||
// Tool use state tracking for input buffering and deduplication
|
||||
processedIDs := make(map[string]bool)
|
||||
@@ -1925,6 +2172,12 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
var accumulatedContent strings.Builder
|
||||
accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations
|
||||
|
||||
// Real-time usage estimation state
|
||||
// These track when to send periodic usage updates during streaming
|
||||
var lastUsageUpdateLen int // Last accumulated content length when usage was sent
|
||||
var lastUsageUpdateTime = time.Now() // Last time usage update was sent
|
||||
var lastReportedOutputTokens int64 // Last reported output token count
|
||||
|
||||
// Translator param for maintaining tool call state across streaming events
|
||||
// IMPORTANT: This must persist across all TranslateStream calls
|
||||
var translatorParam any
|
||||
@@ -1932,24 +2185,37 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
// Thinking mode state tracking - based on amq2api implementation
|
||||
// Tracks whether we're inside a <thinking> block and handles partial tags
|
||||
inThinkBlock := false
|
||||
pendingStartTagChars := 0 // Number of chars that might be start of <thinking>
|
||||
pendingEndTagChars := 0 // Number of chars that might be start of </thinking>
|
||||
isThinkingBlockOpen := false // Track if thinking content block is open
|
||||
thinkingBlockIndex := -1 // Index of the thinking content block
|
||||
pendingStartTagChars := 0 // Number of chars that might be start of <thinking>
|
||||
pendingEndTagChars := 0 // Number of chars that might be start of </thinking>
|
||||
isThinkingBlockOpen := false // Track if thinking content block is open
|
||||
thinkingBlockIndex := -1 // Index of the thinking content block
|
||||
|
||||
// Pre-calculate input tokens from request if possible
|
||||
if enc, err := tokenizerForModel(model); err == nil {
|
||||
// Try OpenAI format first, then fall back to raw byte count estimation
|
||||
if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 {
|
||||
totalUsage.InputTokens = inp
|
||||
// Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback
|
||||
if enc, err := getTokenizer(model); err == nil {
|
||||
var inputTokens int64
|
||||
var countMethod string
|
||||
|
||||
// Try Claude format first (Kiro uses Claude API format)
|
||||
if inp, err := countClaudeChatTokens(enc, claudeBody); err == nil && inp > 0 {
|
||||
inputTokens = inp
|
||||
countMethod = "claude"
|
||||
} else if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 {
|
||||
// Fallback to OpenAI format (for OpenAI-compatible requests)
|
||||
inputTokens = inp
|
||||
countMethod = "openai"
|
||||
} else {
|
||||
// Fallback: estimate from raw request size (roughly 4 chars per token)
|
||||
totalUsage.InputTokens = int64(len(originalReq) / 4)
|
||||
if totalUsage.InputTokens == 0 && len(originalReq) > 0 {
|
||||
totalUsage.InputTokens = 1
|
||||
// Final fallback: estimate from raw request size (roughly 4 chars per token)
|
||||
inputTokens = int64(len(claudeBody) / 4)
|
||||
if inputTokens == 0 && len(claudeBody) > 0 {
|
||||
inputTokens = 1
|
||||
}
|
||||
countMethod = "estimate"
|
||||
}
|
||||
log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (request size: %d bytes)", totalUsage.InputTokens, len(originalReq))
|
||||
|
||||
totalUsage.InputTokens = inputTokens
|
||||
log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)",
|
||||
totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq))
|
||||
}
|
||||
|
||||
contentBlockIndex := -1
|
||||
@@ -1969,9 +2235,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
default:
|
||||
}
|
||||
|
||||
prelude := make([]byte, 8)
|
||||
_, err := io.ReadFull(reader, prelude)
|
||||
if err == io.EOF {
|
||||
msg, eventErr := e.readEventStreamMessage(reader)
|
||||
if eventErr != nil {
|
||||
// Log the error
|
||||
log.Errorf("kiro: streamToChannel error: %v", eventErr)
|
||||
|
||||
// Send error to channel for client notification
|
||||
out <- cliproxyexecutor.StreamChunk{Err: eventErr}
|
||||
return
|
||||
}
|
||||
if msg == nil {
|
||||
// Normal end of stream (EOF)
|
||||
// Flush any incomplete tool use before ending stream
|
||||
if currentToolUse != nil && !processedIDs[currentToolUse.toolUseID] {
|
||||
log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.name, currentToolUse.toolUseID)
|
||||
@@ -2069,44 +2343,12 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read prelude: %w", err)}
|
||||
return
|
||||
}
|
||||
|
||||
totalLen := binary.BigEndian.Uint32(prelude[0:4])
|
||||
if totalLen < 8 {
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("invalid message length: %d", totalLen)}
|
||||
return
|
||||
}
|
||||
if totalLen > kiroMaxMessageSize {
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("message too large: %d bytes", totalLen)}
|
||||
return
|
||||
}
|
||||
headersLen := binary.BigEndian.Uint32(prelude[4:8])
|
||||
|
||||
remaining := make([]byte, totalLen-8)
|
||||
_, err = io.ReadFull(reader, remaining)
|
||||
if err != nil {
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read message: %w", err)}
|
||||
return
|
||||
}
|
||||
|
||||
// Validate headersLen to prevent slice out of bounds
|
||||
if headersLen+4 > uint32(len(remaining)) {
|
||||
log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining))
|
||||
eventType := msg.EventType
|
||||
payload := msg.Payload
|
||||
if len(payload) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType := e.extractEventType(remaining[:headersLen+4])
|
||||
|
||||
payloadStart := 4 + headersLen
|
||||
payloadEnd := uint32(len(remaining)) - 4
|
||||
if payloadStart >= payloadEnd {
|
||||
continue
|
||||
}
|
||||
|
||||
payload := remaining[payloadStart:payloadEnd]
|
||||
appendAPIResponseChunk(ctx, e.cfg, payload)
|
||||
|
||||
var event map[string]interface{}
|
||||
@@ -2115,12 +2357,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
continue
|
||||
}
|
||||
|
||||
// DIAGNOSTIC: Log all received event types for debugging
|
||||
log.Debugf("kiro: streamToChannel received event type: %s", eventType)
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("kiro: streamToChannel event payload: %s", string(payload))
|
||||
}
|
||||
|
||||
// Check for error/exception events in the payload (Kiro API may return errors with HTTP 200)
|
||||
// These can appear as top-level fields or nested within the event
|
||||
if errType, hasErrType := event["_type"].(string); hasErrType {
|
||||
@@ -2148,6 +2384,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
return
|
||||
}
|
||||
|
||||
// Extract stop_reason from various event formats (streaming)
|
||||
// Kiro/Amazon Q API may include stop_reason in different locations
|
||||
if sr := getString(event, "stop_reason"); sr != "" {
|
||||
upstreamStopReason = sr
|
||||
log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason)
|
||||
}
|
||||
if sr := getString(event, "stopReason"); sr != "" {
|
||||
upstreamStopReason = sr
|
||||
log.Debugf("kiro: streamToChannel found stopReason (top-level): %s", upstreamStopReason)
|
||||
}
|
||||
|
||||
// Send message_start on first event
|
||||
if !messageStartSent {
|
||||
msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens)
|
||||
@@ -2166,6 +2413,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
log.Debugf("kiro: streamToChannel ignoring followupPrompt event")
|
||||
continue
|
||||
|
||||
case "messageStopEvent", "message_stop":
|
||||
// Handle message stop events which may contain stop_reason
|
||||
if sr := getString(event, "stop_reason"); sr != "" {
|
||||
upstreamStopReason = sr
|
||||
log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason)
|
||||
}
|
||||
if sr := getString(event, "stopReason"); sr != "" {
|
||||
upstreamStopReason = sr
|
||||
log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason)
|
||||
}
|
||||
|
||||
case "assistantResponseEvent":
|
||||
var contentDelta string
|
||||
var toolUses []map[string]interface{}
|
||||
@@ -2174,6 +2432,15 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
if c, ok := assistantResp["content"].(string); ok {
|
||||
contentDelta = c
|
||||
}
|
||||
// Extract stop_reason from assistantResponseEvent
|
||||
if sr := getString(assistantResp, "stop_reason"); sr != "" {
|
||||
upstreamStopReason = sr
|
||||
log.Debugf("kiro: streamToChannel found stop_reason in assistantResponseEvent: %s", upstreamStopReason)
|
||||
}
|
||||
if sr := getString(assistantResp, "stopReason"); sr != "" {
|
||||
upstreamStopReason = sr
|
||||
log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason)
|
||||
}
|
||||
// Extract tool uses from response
|
||||
if tus, ok := assistantResp["toolUses"].([]interface{}); ok {
|
||||
for _, tuRaw := range tus {
|
||||
@@ -2199,11 +2466,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
|
||||
// Handle text content with thinking mode support
|
||||
if contentDelta != "" {
|
||||
// DIAGNOSTIC: Check for thinking tags in response
|
||||
if strings.Contains(contentDelta, "<thinking>") || strings.Contains(contentDelta, "</thinking>") {
|
||||
log.Infof("kiro: DIAGNOSTIC - Found thinking tag in response (len: %d)", len(contentDelta))
|
||||
}
|
||||
|
||||
// NOTE: Duplicate content filtering was removed because it incorrectly
|
||||
// filtered out legitimate repeated content (like consecutive newlines "\n\n").
|
||||
// Streaming naturally can have identical chunks that are valid content.
|
||||
@@ -2211,6 +2473,52 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
outputLen += len(contentDelta)
|
||||
// Accumulate content for streaming token calculation
|
||||
accumulatedContent.WriteString(contentDelta)
|
||||
|
||||
// Real-time usage estimation: Check if we should send a usage update
|
||||
// This helps clients track context usage during long thinking sessions
|
||||
shouldSendUsageUpdate := false
|
||||
if accumulatedContent.Len()-lastUsageUpdateLen >= usageUpdateCharThreshold {
|
||||
shouldSendUsageUpdate = true
|
||||
} else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen {
|
||||
shouldSendUsageUpdate = true
|
||||
}
|
||||
|
||||
if shouldSendUsageUpdate {
|
||||
// Calculate current output tokens using tiktoken
|
||||
var currentOutputTokens int64
|
||||
if enc, encErr := getTokenizer(model); encErr == nil {
|
||||
if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil {
|
||||
currentOutputTokens = int64(tokenCount)
|
||||
}
|
||||
}
|
||||
// Fallback to character estimation if tiktoken fails
|
||||
if currentOutputTokens == 0 {
|
||||
currentOutputTokens = int64(accumulatedContent.Len() / 4)
|
||||
if currentOutputTokens == 0 {
|
||||
currentOutputTokens = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Only send update if token count has changed significantly (at least 10 tokens)
|
||||
if currentOutputTokens > lastReportedOutputTokens+10 {
|
||||
// Send ping event with usage information
|
||||
// This is a non-blocking update that clients can optionally process
|
||||
pingEvent := e.buildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||
}
|
||||
}
|
||||
|
||||
lastReportedOutputTokens = currentOutputTokens
|
||||
log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)",
|
||||
totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len())
|
||||
}
|
||||
|
||||
lastUsageUpdateLen = accumulatedContent.Len()
|
||||
lastUsageUpdateTime = time.Now()
|
||||
}
|
||||
|
||||
// Process content with thinking tag detection - based on amq2api implementation
|
||||
// This handles <thinking> and </thinking> tags that may span across chunks
|
||||
@@ -2577,10 +2885,10 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
|
||||
// Streaming token calculation - calculate output tokens from accumulated content
|
||||
// This provides more accurate token counting than simple character division
|
||||
// Only use local estimation if server didn't provide usage (server-side usage takes priority)
|
||||
if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 {
|
||||
// Try to use tiktoken for accurate counting
|
||||
if enc, err := tokenizerForModel(model); err == nil {
|
||||
if enc, err := getTokenizer(model); err == nil {
|
||||
if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil {
|
||||
totalUsage.OutputTokens = int64(tokenCount)
|
||||
log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens)
|
||||
@@ -2609,10 +2917,21 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens
|
||||
|
||||
// Determine stop reason based on whether tool uses were emitted
|
||||
stopReason := "end_turn"
|
||||
if hasToolUses {
|
||||
stopReason = "tool_use"
|
||||
// Determine stop reason: prefer upstream, then detect tool_use, default to end_turn
|
||||
stopReason := upstreamStopReason
|
||||
if stopReason == "" {
|
||||
if hasToolUses {
|
||||
stopReason = "tool_use"
|
||||
log.Debugf("kiro: streamToChannel using fallback stop_reason: tool_use")
|
||||
} else {
|
||||
stopReason = "end_turn"
|
||||
log.Debugf("kiro: streamToChannel using fallback stop_reason: end_turn")
|
||||
}
|
||||
}
|
||||
|
||||
// Log warning if response was truncated due to max_tokens
|
||||
if stopReason == "max_tokens" {
|
||||
log.Warnf("kiro: response truncated due to max_tokens limit (streamToChannel)")
|
||||
}
|
||||
|
||||
// Send message_delta event
|
||||
@@ -2758,6 +3077,24 @@ func (e *KiroExecutor) buildClaudeFinalEvent() []byte {
|
||||
return []byte("event: message_stop\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// buildClaudePingEventWithUsage creates a ping event with embedded usage information.
|
||||
// This is used for real-time usage estimation during streaming.
|
||||
// The usage field is a non-standard extension that clients can optionally process.
|
||||
// Clients that don't recognize the usage field will simply ignore it.
|
||||
func (e *KiroExecutor) buildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte {
|
||||
event := map[string]interface{}{
|
||||
"type": "ping",
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": inputTokens,
|
||||
"output_tokens": outputTokens,
|
||||
"total_tokens": inputTokens + outputTokens,
|
||||
"estimated": true, // Flag to indicate this is an estimate, not final
|
||||
},
|
||||
}
|
||||
result, _ := json.Marshal(event)
|
||||
return []byte("event: ping\ndata: " + string(result))
|
||||
}
|
||||
|
||||
// buildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility.
|
||||
// This is used when streaming thinking content wrapped in <thinking> tags.
|
||||
func (e *KiroExecutor) buildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte {
|
||||
@@ -2837,10 +3174,21 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
||||
// Also check if expires_at is now in the future with sufficient buffer
|
||||
if expiresAt, ok := auth.Metadata["expires_at"].(string); ok {
|
||||
if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil {
|
||||
// If token expires more than 2 minutes from now, it's still valid
|
||||
if time.Until(expTime) > 2*time.Minute {
|
||||
// If token expires more than 5 minutes from now, it's still valid
|
||||
if time.Until(expTime) > 5*time.Minute {
|
||||
log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime))
|
||||
return auth, nil
|
||||
// CRITICAL FIX: Set NextRefreshAfter to prevent frequent refresh checks
|
||||
// Without this, shouldRefresh() will return true again in 5 seconds
|
||||
updated := auth.Clone()
|
||||
// Set next refresh to 5 minutes before expiry, or at least 30 seconds from now
|
||||
nextRefresh := expTime.Add(-5 * time.Minute)
|
||||
minNextRefresh := time.Now().Add(30 * time.Second)
|
||||
if nextRefresh.Before(minNextRefresh) {
|
||||
nextRefresh = minNextRefresh
|
||||
}
|
||||
updated.NextRefreshAfter = nextRefresh
|
||||
log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh))
|
||||
return updated, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2924,9 +3272,9 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
|
||||
updated.Attributes["profile_arn"] = tokenData.ProfileArn
|
||||
}
|
||||
|
||||
// Set next refresh time to 30 minutes before expiry
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil {
|
||||
updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
||||
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
|
||||
}
|
||||
|
||||
log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt)
|
||||
@@ -2943,7 +3291,7 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c
|
||||
var translatorParam any
|
||||
|
||||
// Pre-calculate input tokens from request if possible
|
||||
if enc, err := tokenizerForModel(model); err == nil {
|
||||
if enc, err := getTokenizer(model); err == nil {
|
||||
// Try OpenAI format first, then fall back to raw byte count estimation
|
||||
if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 {
|
||||
totalUsage.InputTokens = inp
|
||||
|
||||
@@ -137,15 +137,25 @@ func buildProxyTransport(proxyURL string) *http.Transport {
|
||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return nil
|
||||
}
|
||||
// Set up a custom transport using the SOCKS5 dialer
|
||||
// Set up a custom transport using the SOCKS5 dialer with optimized connection pooling
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
|
||||
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
|
||||
// Configure HTTP or HTTPS proxy
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
|
||||
// Configure HTTP or HTTPS proxy with optimized connection pooling
|
||||
transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
|
||||
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
} else {
|
||||
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||
return nil
|
||||
|
||||
@@ -2,43 +2,107 @@ package executor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
// tokenizerCache stores tokenizer instances to avoid repeated creation
|
||||
var tokenizerCache sync.Map
|
||||
|
||||
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
|
||||
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
|
||||
type TokenizerWrapper struct {
|
||||
Codec tokenizer.Codec
|
||||
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
|
||||
}
|
||||
|
||||
// Count returns the token count with adjustment factor applied
|
||||
func (tw *TokenizerWrapper) Count(text string) (int, error) {
|
||||
count, err := tw.Codec.Count(text)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
|
||||
return int(float64(count) * tw.AdjustmentFactor), nil
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// getTokenizer returns a cached tokenizer for the given model.
|
||||
// This improves performance by avoiding repeated tokenizer creation.
|
||||
func getTokenizer(model string) (*TokenizerWrapper, error) {
|
||||
// Check cache first
|
||||
if cached, ok := tokenizerCache.Load(model); ok {
|
||||
return cached.(*TokenizerWrapper), nil
|
||||
}
|
||||
|
||||
// Cache miss, create new tokenizer
|
||||
wrapper, err := tokenizerForModel(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store in cache (use LoadOrStore to handle race conditions)
|
||||
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
|
||||
return actual.(*TokenizerWrapper), nil
|
||||
}
|
||||
|
||||
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
|
||||
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
|
||||
// Claude models use cl100k_base with 1.1 adjustment factor
|
||||
// because tiktoken may underestimate Claude's actual token count
|
||||
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
||||
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
|
||||
}
|
||||
|
||||
var enc tokenizer.Codec
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case sanitized == "":
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT41)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
|
||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4)
|
||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
case strings.HasPrefix(sanitized, "o1"):
|
||||
return tokenizer.ForModel(tokenizer.O1)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O1)
|
||||
case strings.HasPrefix(sanitized, "o3"):
|
||||
return tokenizer.ForModel(tokenizer.O3)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O3)
|
||||
case strings.HasPrefix(sanitized, "o4"):
|
||||
return tokenizer.ForModel(tokenizer.O4Mini)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
|
||||
default:
|
||||
return tokenizer.Get(tokenizer.O200kBase)
|
||||
enc, err = tokenizer.Get(tokenizer.O200kBase)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
|
||||
}
|
||||
|
||||
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
@@ -62,11 +126,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
}
|
||||
|
||||
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
|
||||
// This handles Claude's message format with system, messages, and tools.
|
||||
// Image tokens are estimated based on image dimensions when available.
|
||||
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
segments := make([]string, 0, 32)
|
||||
|
||||
// Collect system prompt (can be string or array of content blocks)
|
||||
collectClaudeSystem(root.Get("system"), &segments)
|
||||
|
||||
// Collect messages
|
||||
collectClaudeMessages(root.Get("messages"), &segments)
|
||||
|
||||
// Collect tools
|
||||
collectClaudeTools(root.Get("tools"), &segments)
|
||||
|
||||
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||
if joined == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
}
|
||||
|
||||
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
|
||||
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
|
||||
|
||||
// extractImageTokens extracts image token estimates from placeholder text.
|
||||
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
|
||||
func extractImageTokens(text string) int {
|
||||
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
|
||||
total := 0
|
||||
for _, match := range matches {
|
||||
if len(match) > 1 {
|
||||
if tokens, err := strconv.Atoi(match[1]); err == nil {
|
||||
total += tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
||||
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
||||
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
||||
func estimateImageTokens(width, height float64) int {
|
||||
if width <= 0 || height <= 0 {
|
||||
// No valid dimensions, use default estimate (medium-sized image)
|
||||
return 1000
|
||||
}
|
||||
|
||||
tokens := int(width * height / 750)
|
||||
|
||||
// Apply bounds
|
||||
if tokens < 85 {
|
||||
tokens = 85
|
||||
}
|
||||
if tokens > 1590 {
|
||||
tokens = 1590
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// collectClaudeSystem extracts text from Claude's system field.
|
||||
// System can be a string or an array of content blocks.
|
||||
func collectClaudeSystem(system gjson.Result, segments *[]string) {
|
||||
if !system.Exists() {
|
||||
return
|
||||
}
|
||||
if system.Type == gjson.String {
|
||||
addIfNotEmpty(segments, system.String())
|
||||
return
|
||||
}
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, block gjson.Result) bool {
|
||||
blockType := block.Get("type").String()
|
||||
if blockType == "text" || blockType == "" {
|
||||
addIfNotEmpty(segments, block.Get("text").String())
|
||||
}
|
||||
// Also handle plain string blocks
|
||||
if block.Type == gjson.String {
|
||||
addIfNotEmpty(segments, block.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeMessages extracts text from Claude's messages array.
|
||||
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
addIfNotEmpty(segments, message.Get("role").String())
|
||||
collectClaudeContent(message.Get("content"), segments)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// collectClaudeContent extracts text from Claude's content field.
|
||||
// Content can be a string or an array of content blocks.
|
||||
// For images, estimates token count based on dimensions when available.
|
||||
func collectClaudeContent(content gjson.Result, segments *[]string) {
|
||||
if !content.Exists() {
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
addIfNotEmpty(segments, content.String())
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
addIfNotEmpty(segments, part.Get("text").String())
|
||||
case "image":
|
||||
// Estimate image tokens based on dimensions if available
|
||||
source := part.Get("source")
|
||||
if source.Exists() {
|
||||
width := source.Get("width").Float()
|
||||
height := source.Get("height").Float()
|
||||
if width > 0 && height > 0 {
|
||||
tokens := estimateImageTokens(width, height)
|
||||
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
|
||||
} else {
|
||||
// No dimensions available, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
} else {
|
||||
// No source info, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
case "tool_use":
|
||||
addIfNotEmpty(segments, part.Get("id").String())
|
||||
addIfNotEmpty(segments, part.Get("name").String())
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
addIfNotEmpty(segments, input.Raw)
|
||||
}
|
||||
case "tool_result":
|
||||
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||
collectClaudeContent(part.Get("content"), segments)
|
||||
case "thinking":
|
||||
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||
default:
|
||||
// For unknown types, try to extract any text content
|
||||
if part.Type == gjson.String {
|
||||
addIfNotEmpty(segments, part.String())
|
||||
} else if part.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, part.Raw)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeTools extracts text from Claude's tools array.
|
||||
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return
|
||||
}
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
addIfNotEmpty(segments, tool.Get("name").String())
|
||||
addIfNotEmpty(segments, tool.Get("description").String())
|
||||
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||
addIfNotEmpty(segments, inputSchema.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -36,15 +37,25 @@ func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client {
|
||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
return httpClient
|
||||
}
|
||||
// Set up a custom transport using the SOCKS5 dialer.
|
||||
// Set up a custom transport using the SOCKS5 dialer with optimized connection pooling
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
|
||||
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
// Configure HTTP or HTTPS proxy.
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
// Configure HTTP or HTTPS proxy with optimized connection pooling
|
||||
transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
|
||||
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
}
|
||||
}
|
||||
// If a new transport was created, apply it to the HTTP client.
|
||||
|
||||
@@ -47,8 +47,9 @@ func (a *KiroAuthenticator) Provider() string {
|
||||
}
|
||||
|
||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||
// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh.
|
||||
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 30 * time.Minute
|
||||
d := 5 * time.Minute
|
||||
return &d
|
||||
}
|
||||
|
||||
@@ -103,7 +104,8 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
"source": "aws-builder-id",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
@@ -165,7 +167,8 @@ func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Con
|
||||
"source": "google-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
@@ -227,7 +230,8 @@ func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Con
|
||||
"source": "github-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
@@ -291,7 +295,8 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C
|
||||
"source": "kiro-ide-import",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
// Display the email if extracted
|
||||
@@ -351,7 +356,8 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
||||
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||
updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ type RefreshEvaluator interface {
|
||||
const (
|
||||
refreshCheckInterval = 5 * time.Second
|
||||
refreshPendingBackoff = time.Minute
|
||||
refreshFailureBackoff = 5 * time.Minute
|
||||
refreshFailureBackoff = 1 * time.Minute
|
||||
quotaBackoffBase = time.Second
|
||||
quotaBackoffMax = 30 * time.Minute
|
||||
)
|
||||
@@ -1471,7 +1471,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
updated.Runtime = auth.Runtime
|
||||
}
|
||||
updated.LastRefreshedAt = now
|
||||
updated.NextRefreshAfter = time.Time{}
|
||||
// Preserve NextRefreshAfter set by the Authenticator
|
||||
// If the Authenticator set a reasonable refresh time, it should not be overwritten
|
||||
// If the Authenticator did not set it (zero value), shouldRefresh will use default logic
|
||||
updated.LastError = nil
|
||||
updated.UpdatedAt = now
|
||||
_, _ = m.Update(ctx, updated)
|
||||
|
||||
Reference in New Issue
Block a user