feat: optimize connection pooling and improve Kiro executor reliability

## 中文说明

### 连接池优化
- 为 AMP 代理、SOCKS5 代理和 HTTP 代理配置优化的连接池参数
- MaxIdleConnsPerHost 从默认的 2 增加到 20,支持更多并发用户
- MaxConnsPerHost 设为 0(无限制),避免连接瓶颈
- 添加 IdleConnTimeout (90s) 和其他超时配置

### Kiro 执行器增强
- 添加 Event Stream 消息解析的边界保护,防止越界访问
- 实现实时使用量估算(每 5000 字符或 15 秒发送 ping 事件)
- 正确从上游事件中提取并传递 stop_reason
- 改进输入 token 计算,优先使用 Claude 格式解析
- 添加 max_tokens 截断警告日志

### Token 计算改进
- 添加 tokenizer 缓存(sync.Map)避免重复创建
- 为 Claude/Kiro/AmazonQ 模型添加 1.1 调整因子
- 新增 countClaudeChatTokens 函数支持 Claude API 格式
- 支持图像 token 估算(基于尺寸计算)

### 认证刷新优化
- RefreshLead 从 30 分钟改为 5 分钟,与 Antigravity 保持一致
- 修复 NextRefreshAfter 设置,防止频繁刷新检查
- refreshFailureBackoff 从 5 分钟改为 1 分钟,加快失败恢复

---

## English Description

### Connection Pool Optimization
- Configure optimized connection pool parameters for AMP proxy, SOCKS5 proxy, and HTTP proxy
- Increase MaxIdleConnsPerHost from default 2 to 20 to support more concurrent users
- Set MaxConnsPerHost to 0 (unlimited) to avoid connection bottlenecks
- Add IdleConnTimeout (90s) and other timeout configurations

### Kiro Executor Enhancements
- Add boundary protection for Event Stream message parsing to prevent out-of-bounds access
- Implement real-time usage estimation (send ping events every 5000 chars or 15 seconds)
- Correctly extract and pass stop_reason from upstream events
- Improve input token calculation, prioritize Claude format parsing
- Add max_tokens truncation warning logs

### Token Calculation Improvements
- Add tokenizer cache (sync.Map) to avoid repeated creation
- Add 1.1 adjustment factor for Claude/Kiro/AmazonQ models
- Add countClaudeChatTokens function to support Claude API format
- Support image token estimation (calculated based on dimensions)

### Authentication Refresh Optimization
- Change RefreshLead from 30 minutes to 5 minutes, consistent with Antigravity
- Fix NextRefreshAfter setting to prevent frequent refresh checks
- Change refreshFailureBackoff from 5 minutes to 1 minute for faster failure recovery
This commit is contained in:
Ravens2121
2025-12-13 10:19:53 +08:00
parent db80b20bc2
commit 58866b21cb
7 changed files with 840 additions and 164 deletions

View File

@@ -7,11 +7,13 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
@@ -36,6 +38,22 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
}
proxy := httputil.NewSingleHostReverseProxy(parsed)
// Configure custom Transport with optimized connection pooling for high concurrency
proxy.Transport = &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
IdleConnTimeout: 90 * time.Second,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 60 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
originalDirector := proxy.Director
// Modify outgoing requests to inject API key and fix routing
@@ -64,7 +82,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// Modify incoming responses to handle gzip without Content-Encoding
// This addresses the same issue as inline handler gzip handling, but at the proxy level
proxy.ModifyResponse = func(resp *http.Response) error {
// Only process successful responses
// Log upstream error responses for diagnostics (502, 503, etc.)
// These are NOT proxy connection errors - the upstream responded with an error status
if resp.StatusCode >= 500 {
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
} else if resp.StatusCode >= 400 {
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
}
// Only process successful responses for gzip decompression
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil
}
@@ -148,15 +174,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
return nil
}
// Error handler for proxy failures
// Error handler for proxy failures with detailed error classification for diagnostics
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
// Check if this is a client-side cancellation (normal behavior)
// Classify the error type for better diagnostics
var errType string
if errors.Is(err, context.DeadlineExceeded) {
errType = "timeout"
} else if errors.Is(err, context.Canceled) {
errType = "canceled"
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
errType = "dial_timeout"
} else if _, ok := err.(net.Error); ok {
errType = "network_error"
} else {
errType = "connection_error"
}
// Don't log as error for context canceled - it's usually client closing connection
if errors.Is(err, context.Canceled) {
log.Debugf("amp upstream proxy: client canceled request for %s %s", req.Method, req.URL.Path)
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
} else {
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
}
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusBadGateway)
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))

View File

@@ -36,6 +36,16 @@ const (
kiroAcceptStream = "*/*"
kiroMaxMessageSize = 10 * 1024 * 1024 // 10MB max message size for event stream
kiroMaxToolDescLen = 10237 // Kiro API limit is 10240 bytes, leave room for "..."
// Event Stream frame size constants for boundary protection
// AWS Event Stream binary format: prelude (12 bytes) + headers + payload + message_crc (4 bytes)
// Prelude consists of: total_length (4) + headers_length (4) + prelude_crc (4)
minEventStreamFrameSize = 16 // Minimum: 4(total_len) + 4(headers_len) + 4(prelude_crc) + 4(message_crc)
maxEventStreamMsgSize = 10 << 20 // Maximum message length: 10MB
// Event Stream error type constants
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
// kiroUserAgent matches amq2api format for User-Agent header
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
// kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api
@@ -102,6 +112,13 @@ You MUST follow these rules for ALL file operations. Violation causes server tim
REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.`
)
// Real-time usage estimation configuration.
// These control how often interim usage updates are sent during streaming:
// an update fires after usageUpdateCharThreshold new characters OR after
// usageUpdateTimeInterval has elapsed, whichever comes first.
// NOTE(review): declared as var rather than const — presumably to allow
// override in tests; confirm before converting to const.
var (
	usageUpdateCharThreshold = 5000             // Send usage update every 5000 characters
	usageUpdateTimeInterval  = 15 * time.Second // Or every 15 seconds, whichever comes first
)
// kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values.
// This solves the "triple mismatch" problem where different endpoints require matching
// Origin and X-Amz-Target header values.
@@ -495,7 +512,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
}
}()
content, toolUses, usageInfo, err := e.parseEventStream(httpResp.Body)
content, toolUses, usageInfo, stopReason, err := e.parseEventStream(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
@@ -503,14 +520,14 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
// Fallback for usage if missing from upstream
if usageInfo.TotalTokens == 0 {
if enc, encErr := tokenizerForModel(req.Model); encErr == nil {
if enc, encErr := getTokenizer(req.Model); encErr == nil {
if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil {
usageInfo.InputTokens = inp
}
}
if len(content) > 0 {
// Use tiktoken for more accurate output token calculation
if enc, encErr := tokenizerForModel(req.Model); encErr == nil {
if enc, encErr := getTokenizer(req.Model); encErr == nil {
if tokenCount, countErr := enc.Count(content); countErr == nil {
usageInfo.OutputTokens = int64(tokenCount)
}
@@ -530,7 +547,8 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
reporter.publish(ctx, usageInfo)
// Build response in Claude format for Kiro translator
kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo)
// stopReason is extracted from upstream response by parseEventStream
kiroResponse := e.buildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason)
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
@@ -970,11 +988,40 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
return "claude-sonnet-4.5"
}
// EventStreamError represents an Event Stream processing error, classified
// by Type: ErrStreamFatal for connection/read failures that are not
// recoverable, ErrStreamMalformed for frames whose format cannot be parsed.
// Cause optionally carries the underlying error.
type EventStreamError struct {
	Type    string // "fatal", "malformed"
	Message string
	Cause   error
}

// Error formats the error as "event stream <type>: <message>[: <cause>]".
func (e *EventStreamError) Error() string {
	if e.Cause != nil {
		return fmt.Sprintf("event stream %s: %s: %v", e.Type, e.Message, e.Cause)
	}
	return fmt.Sprintf("event stream %s: %s", e.Type, e.Message)
}

// Unwrap returns the underlying cause (may be nil), enabling callers to use
// errors.Is/errors.As through this wrapper — consistent with the errors.Is
// based classification used elsewhere in this change.
func (e *EventStreamError) Unwrap() error {
	return e.Cause
}
// eventStreamMessage represents a parsed AWS Event Stream message: one frame
// of the binary stream after boundary validation, with framing (prelude,
// header encoding, message CRC) already stripped.
type eventStreamMessage struct {
	EventType string // Value of the ":event-type" header (e.g., "assistantResponseEvent"); "" when absent or unparseable
	Payload   []byte // JSON payload of the message; nil when the frame carried no payload
}
// Kiro API request structs - field order determines JSON key order.
// kiroPayload is the top-level request body sent to the Kiro/Amazon Q API.
type kiroPayload struct {
	ConversationState kiroConversationState `json:"conversationState"`
	ProfileArn        string                `json:"profileArn,omitempty"`
	InferenceConfig   *kiroInferenceConfig  `json:"inferenceConfig,omitempty"` // experimental; nil omits the key entirely
}

// kiroInferenceConfig contains inference parameters for the Kiro API.
// NOTE: This is an experimental addition - Kiro/Amazon Q API may not support these parameters.
// If the API ignores or rejects these fields, response length is controlled internally by the model.
// NOTE(review): because of omitempty, a zero value is dropped from the JSON —
// an explicitly requested temperature of 0 cannot be transmitted; confirm this
// is acceptable or use a *float64 if 0 must be representable.
type kiroInferenceConfig struct {
	MaxTokens   int     `json:"maxTokens,omitempty"`   // Maximum output tokens (may be ignored by API)
	Temperature float64 `json:"temperature,omitempty"` // Sampling temperature (may be ignored by API)
}
type kiroConversationState struct {
@@ -1058,7 +1105,25 @@ type kiroToolUse struct {
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
// Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint.
//
// max_tokens support: Kiro/Amazon Q API may not officially support max_tokens parameter.
// We attempt to pass it via inferenceConfig.maxTokens, but the API may ignore it.
// Response truncation can be detected via stop_reason == "max_tokens" in the response.
func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) []byte {
// Extract max_tokens for potential use in inferenceConfig
var maxTokens int64
if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() {
maxTokens = mt.Int()
}
// Extract temperature if specified
var temperature float64
var hasTemperature bool
if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() {
temperature = temp.Float()
hasTemperature = true
}
// Normalize origin value for Kiro API compatibility
// Kiro API only accepts "CLI" or "AI_EDITOR" as valid origin values
switch origin {
@@ -1325,6 +1390,18 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn,
}}
}
// Build inferenceConfig if we have any inference parameters
var inferenceConfig *kiroInferenceConfig
if maxTokens > 0 || hasTemperature {
inferenceConfig = &kiroInferenceConfig{}
if maxTokens > 0 {
inferenceConfig.MaxTokens = int(maxTokens)
}
if hasTemperature {
inferenceConfig.Temperature = temperature
}
}
// Build payload with correct field order (matches struct definition)
payload := kiroPayload{
ConversationState: kiroConversationState{
@@ -1333,7 +1410,8 @@ func (e *KiroExecutor) buildKiroPayload(claudeBody []byte, modelID, profileArn,
CurrentMessage: currentMessage,
History: history, // Now always included (non-nil slice)
},
ProfileArn: profileArn,
ProfileArn: profileArn,
InferenceConfig: inferenceConfig,
}
result, err := json.Marshal(payload)
@@ -1493,12 +1571,14 @@ func (e *KiroExecutor) buildAssistantMessageStruct(msg gjson.Result) kiroAssista
// NOTE: Tool calling is now supported via userInputMessageContext.tools and toolResults
// parseEventStream parses AWS Event Stream binary format.
// Extracts text content and tool uses from the response.
// Extracts text content, tool uses, and stop_reason from the response.
// Supports embedded [Called ...] tool calls and input buffering for toolUseEvent.
func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, error) {
// Returns: content, toolUses, usageInfo, stopReason, error
func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse, usage.Detail, string, error) {
var content strings.Builder
var toolUses []kiroToolUse
var usageInfo usage.Detail
var stopReason string // Extracted from upstream response
reader := bufio.NewReader(body)
// Tool use state tracking for input buffering and deduplication
@@ -1506,59 +1586,28 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
var currentToolUse *toolUseState
for {
prelude := make([]byte, 8)
_, err := io.ReadFull(reader, prelude)
if err == io.EOF {
msg, eventErr := e.readEventStreamMessage(reader)
if eventErr != nil {
log.Errorf("kiro: parseEventStream error: %v", eventErr)
return content.String(), toolUses, usageInfo, stopReason, eventErr
}
if msg == nil {
// Normal end of stream (EOF)
break
}
if err != nil {
return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read prelude: %w", err)
}
totalLen := binary.BigEndian.Uint32(prelude[0:4])
if totalLen < 8 {
return content.String(), toolUses, usageInfo, fmt.Errorf("invalid message length: %d", totalLen)
}
if totalLen > kiroMaxMessageSize {
return content.String(), toolUses, usageInfo, fmt.Errorf("message too large: %d bytes", totalLen)
}
headersLen := binary.BigEndian.Uint32(prelude[4:8])
remaining := make([]byte, totalLen-8)
_, err = io.ReadFull(reader, remaining)
if err != nil {
return content.String(), toolUses, usageInfo, fmt.Errorf("failed to read message: %w", err)
}
// Validate headersLen to prevent slice out of bounds
if headersLen+4 > uint32(len(remaining)) {
log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining))
eventType := msg.EventType
payload := msg.Payload
if len(payload) == 0 {
continue
}
// Extract event type from headers
eventType := e.extractEventType(remaining[:headersLen+4])
payloadStart := 4 + headersLen
payloadEnd := uint32(len(remaining)) - 4
if payloadStart >= payloadEnd {
continue
}
payload := remaining[payloadStart:payloadEnd]
var event map[string]interface{}
if err := json.Unmarshal(payload, &event); err != nil {
log.Debugf("kiro: skipping malformed event: %v", err)
continue
}
// DIAGNOSTIC: Log all received event types for debugging
log.Debugf("kiro: parseEventStream received event type: %s", eventType)
if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("kiro: parseEventStream event payload: %s", string(payload))
}
// Check for error/exception events in the payload (Kiro API may return errors with HTTP 200)
// These can appear as top-level fields or nested within the event
if errType, hasErrType := event["_type"].(string); hasErrType {
@@ -1568,7 +1617,7 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
errMsg = msg
}
log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg)
return "", nil, usageInfo, fmt.Errorf("kiro API error: %s - %s", errType, errMsg)
return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s - %s", errType, errMsg)
}
if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") {
// Generic error event
@@ -1581,7 +1630,18 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
}
}
log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg)
return "", nil, usageInfo, fmt.Errorf("kiro API error: %s", errMsg)
return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s", errMsg)
}
// Extract stop_reason from various event formats
// Kiro/Amazon Q API may include stop_reason in different locations
if sr := getString(event, "stop_reason"); sr != "" {
stopReason = sr
log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason)
}
if sr := getString(event, "stopReason"); sr != "" {
stopReason = sr
log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason)
}
// Handle different event types
@@ -1596,6 +1656,15 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
if contentText, ok := assistantResp["content"].(string); ok {
content.WriteString(contentText)
}
// Extract stop_reason from assistantResponseEvent
if sr := getString(assistantResp, "stop_reason"); sr != "" {
stopReason = sr
log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason)
}
if sr := getString(assistantResp, "stopReason"); sr != "" {
stopReason = sr
log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason)
}
// Extract tool uses from response
if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok {
for _, tuRaw := range toolUsesRaw {
@@ -1661,6 +1730,17 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
if outputTokens, ok := event["outputTokens"].(float64); ok {
usageInfo.OutputTokens = int64(outputTokens)
}
case "messageStopEvent", "message_stop":
// Handle message stop events which may contain stop_reason
if sr := getString(event, "stop_reason"); sr != "" {
stopReason = sr
log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason)
}
if sr := getString(event, "stopReason"); sr != "" {
stopReason = sr
log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason)
}
}
// Also check nested supplementaryWebLinksEvent
@@ -1682,10 +1762,166 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroToolUse,
// Deduplicate all tool uses
toolUses = deduplicateToolUses(toolUses)
return cleanedContent, toolUses, usageInfo, nil
// Apply fallback logic for stop_reason if not provided by upstream
// Priority: upstream stopReason > tool_use detection > end_turn default
if stopReason == "" {
if len(toolUses) > 0 {
stopReason = "tool_use"
log.Debugf("kiro: parseEventStream using fallback stop_reason: tool_use (detected %d tool uses)", len(toolUses))
} else {
stopReason = "end_turn"
log.Debugf("kiro: parseEventStream using fallback stop_reason: end_turn")
}
}
// Log warning if response was truncated due to max_tokens
if stopReason == "max_tokens" {
log.Warnf("kiro: response truncated due to max_tokens limit")
}
return cleanedContent, toolUses, usageInfo, stopReason, nil
}
// readEventStreamMessage reads and validates a single AWS Event Stream message.
// Returns the parsed message, or a structured *EventStreamError classifying the
// failure (ErrStreamFatal for read failures, ErrStreamMalformed for frames
// that violate size bounds). A (nil, nil) return means the stream ended
// cleanly (EOF before any prelude byte was read).
// This function implements boundary protection and detailed error classification.
//
// AWS Event Stream binary format:
//   - Prelude (12 bytes): total_length (4) + headers_length (4) + prelude_crc (4)
//   - Headers (variable): header entries
//   - Payload (variable): JSON data
//   - Message CRC (4 bytes): CRC32C of entire message (not validated, just skipped)
func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, *EventStreamError) {
	// Read prelude (first 12 bytes: total_len + headers_len + prelude_crc)
	prelude := make([]byte, 12)
	_, err := io.ReadFull(reader, prelude)
	if err == io.EOF {
		return nil, nil // Normal end of stream
	}
	if err != nil {
		// Covers io.ErrUnexpectedEOF (stream cut mid-prelude) and transport errors.
		return nil, &EventStreamError{
			Type:    ErrStreamFatal,
			Message: "failed to read prelude",
			Cause:   err,
		}
	}
	totalLength := binary.BigEndian.Uint32(prelude[0:4])
	headersLength := binary.BigEndian.Uint32(prelude[4:8])
	// Note: prelude[8:12] is prelude_crc - we read it but don't validate (no CRC check per requirements)

	// Boundary check: minimum frame size
	if totalLength < minEventStreamFrameSize {
		return nil, &EventStreamError{
			Type:    ErrStreamMalformed,
			Message: fmt.Sprintf("invalid message length: %d (minimum is %d)", totalLength, minEventStreamFrameSize),
		}
	}
	// Boundary check: maximum message size
	if totalLength > maxEventStreamMsgSize {
		return nil, &EventStreamError{
			Type:    ErrStreamMalformed,
			Message: fmt.Sprintf("message too large: %d bytes (maximum is %d)", totalLength, maxEventStreamMsgSize),
		}
	}
	// Boundary check: headers length within message bounds
	// Message structure: prelude(12) + headers(headersLength) + payload + message_crc(4)
	// So: headersLength must be <= totalLength - 16 (12 for prelude + 4 for message_crc)
	// The minimum-size check above guarantees totalLength >= 16, so this
	// unsigned subtraction cannot underflow.
	if headersLength > totalLength-16 {
		return nil, &EventStreamError{
			Type:    ErrStreamMalformed,
			Message: fmt.Sprintf("headers length %d exceeds message bounds (total: %d)", headersLength, totalLength),
		}
	}
	// Read the rest of the message (total - 12 bytes already read)
	remaining := make([]byte, totalLength-12)
	_, err = io.ReadFull(reader, remaining)
	if err != nil {
		return nil, &EventStreamError{
			Type:    ErrStreamFatal,
			Message: "failed to read message body",
			Cause:   err,
		}
	}

	// Extract event type from headers
	// Headers start at beginning of 'remaining', length is headersLength
	// (the upper-bound condition is already guaranteed by the bounds check
	// above; kept as defense in depth before slicing).
	var eventType string
	if headersLength > 0 && headersLength <= uint32(len(remaining)) {
		eventType = e.extractEventTypeFromBytes(remaining[:headersLength])
	}

	// Calculate payload boundaries
	// Payload starts after headers, ends before message_crc (last 4 bytes)
	payloadStart := headersLength
	payloadEnd := uint32(len(remaining)) - 4 // Skip message_crc at end

	// Validate payload boundaries
	if payloadStart >= payloadEnd {
		// No payload, return empty message
		return &eventStreamMessage{
			EventType: eventType,
			Payload:   nil,
		}, nil
	}
	payload := remaining[payloadStart:payloadEnd]
	return &eventStreamMessage{
		EventType: eventType,
		Payload:   payload,
	}, nil
}
// extractEventTypeFromBytes extracts the ":event-type" header value from raw
// header bytes (the headers region only, without the prelude CRC prefix or
// trailing message CRC).
//
// Header wire format per entry: name_len (1 byte) + name + value_type (1 byte)
// + value. Only value type 7 (string: 2-byte big-endian length + bytes) is
// decoded; the scan stops at the first entry of any other type because its
// value length is unknown, so ":event-type" is only found when it precedes
// such an entry. Returns "" when the header is absent or the bytes are
// truncated/malformed.
func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string {
	offset := 0
	for offset < len(headers) {
		// The loop condition already guarantees offset is in range here
		// (a redundant re-check was removed as dead code).
		nameLen := int(headers[offset])
		offset++
		if offset+nameLen > len(headers) {
			break // truncated name
		}
		name := string(headers[offset : offset+nameLen])
		offset += nameLen
		if offset >= len(headers) {
			break // missing value type byte
		}
		valueType := headers[offset]
		offset++
		if valueType != 7 {
			// Non-string header types are not length-decoded here, so we
			// cannot safely skip past them; stop scanning.
			break
		}
		if offset+2 > len(headers) {
			break // truncated value length
		}
		valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2]))
		offset += 2
		if offset+valueLen > len(headers) {
			break // truncated value
		}
		value := string(headers[offset : offset+valueLen])
		offset += valueLen
		if name == ":event-type" {
			return value
		}
	}
	return ""
}
// extractEventType extracts the event type from AWS Event Stream headers
// Note: This is the legacy version that expects headerBytes to include prelude CRC prefix
func (e *KiroExecutor) extractEventType(headerBytes []byte) string {
// Skip prelude CRC (4 bytes)
if len(headerBytes) < 4 {
@@ -1746,7 +1982,8 @@ func getString(m map[string]interface{}, key string) string {
// buildClaudeResponse constructs a Claude-compatible response.
// Supports tool_use blocks when tools are present in the response.
// Supports thinking blocks - parses <thinking> tags and converts to Claude thinking content blocks.
func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail) []byte {
// stopReason is passed from upstream; fallback logic applied if empty.
func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
var contentBlocks []map[string]interface{}
// Extract thinking blocks and text from content
@@ -1782,10 +2019,18 @@ func (e *KiroExecutor) buildClaudeResponse(content string, toolUses []kiroToolUs
})
}
// Determine stop reason
stopReason := "end_turn"
if len(toolUses) > 0 {
stopReason = "tool_use"
// Use upstream stopReason; apply fallback logic if not provided
if stopReason == "" {
stopReason = "end_turn"
if len(toolUses) > 0 {
stopReason = "tool_use"
}
log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason)
}
// Log warning if response was truncated due to max_tokens
if stopReason == "max_tokens" {
log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)")
}
response := map[string]interface{}{
@@ -1906,10 +2151,12 @@ func (e *KiroExecutor) extractThinkingFromContent(content string) []map[string]i
// Supports tool calling - emits tool_use content blocks when tools are used.
// Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent.
// Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API).
// Extracts stop_reason from upstream events when available.
func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter) {
reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers
var totalUsage usage.Detail
var hasToolUses bool // Track if any tool uses were emitted
var hasToolUses bool // Track if any tool uses were emitted
var upstreamStopReason string // Track stop_reason from upstream events
// Tool use state tracking for input buffering and deduplication
processedIDs := make(map[string]bool)
@@ -1925,6 +2172,12 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
var accumulatedContent strings.Builder
accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations
// Real-time usage estimation state
// These track when to send periodic usage updates during streaming
var lastUsageUpdateLen int // Last accumulated content length when usage was sent
var lastUsageUpdateTime = time.Now() // Last time usage update was sent
var lastReportedOutputTokens int64 // Last reported output token count
// Translator param for maintaining tool call state across streaming events
// IMPORTANT: This must persist across all TranslateStream calls
var translatorParam any
@@ -1932,24 +2185,37 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
// Thinking mode state tracking - based on amq2api implementation
// Tracks whether we're inside a <thinking> block and handles partial tags
inThinkBlock := false
pendingStartTagChars := 0 // Number of chars that might be start of <thinking>
pendingEndTagChars := 0 // Number of chars that might be start of </thinking>
isThinkingBlockOpen := false // Track if thinking content block is open
thinkingBlockIndex := -1 // Index of the thinking content block
pendingStartTagChars := 0 // Number of chars that might be start of <thinking>
pendingEndTagChars := 0 // Number of chars that might be start of </thinking>
isThinkingBlockOpen := false // Track if thinking content block is open
thinkingBlockIndex := -1 // Index of the thinking content block
// Pre-calculate input tokens from request if possible
if enc, err := tokenizerForModel(model); err == nil {
// Try OpenAI format first, then fall back to raw byte count estimation
if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 {
totalUsage.InputTokens = inp
// Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback
if enc, err := getTokenizer(model); err == nil {
var inputTokens int64
var countMethod string
// Try Claude format first (Kiro uses Claude API format)
if inp, err := countClaudeChatTokens(enc, claudeBody); err == nil && inp > 0 {
inputTokens = inp
countMethod = "claude"
} else if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 {
// Fallback to OpenAI format (for OpenAI-compatible requests)
inputTokens = inp
countMethod = "openai"
} else {
// Fallback: estimate from raw request size (roughly 4 chars per token)
totalUsage.InputTokens = int64(len(originalReq) / 4)
if totalUsage.InputTokens == 0 && len(originalReq) > 0 {
totalUsage.InputTokens = 1
// Final fallback: estimate from raw request size (roughly 4 chars per token)
inputTokens = int64(len(claudeBody) / 4)
if inputTokens == 0 && len(claudeBody) > 0 {
inputTokens = 1
}
countMethod = "estimate"
}
log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (request size: %d bytes)", totalUsage.InputTokens, len(originalReq))
totalUsage.InputTokens = inputTokens
log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)",
totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq))
}
contentBlockIndex := -1
@@ -1969,9 +2235,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
default:
}
prelude := make([]byte, 8)
_, err := io.ReadFull(reader, prelude)
if err == io.EOF {
msg, eventErr := e.readEventStreamMessage(reader)
if eventErr != nil {
// Log the error
log.Errorf("kiro: streamToChannel error: %v", eventErr)
// Send error to channel for client notification
out <- cliproxyexecutor.StreamChunk{Err: eventErr}
return
}
if msg == nil {
// Normal end of stream (EOF)
// Flush any incomplete tool use before ending stream
if currentToolUse != nil && !processedIDs[currentToolUse.toolUseID] {
log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.name, currentToolUse.toolUseID)
@@ -2069,44 +2343,12 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
}
break
}
if err != nil {
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read prelude: %w", err)}
return
}
totalLen := binary.BigEndian.Uint32(prelude[0:4])
if totalLen < 8 {
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("invalid message length: %d", totalLen)}
return
}
if totalLen > kiroMaxMessageSize {
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("message too large: %d bytes", totalLen)}
return
}
headersLen := binary.BigEndian.Uint32(prelude[4:8])
remaining := make([]byte, totalLen-8)
_, err = io.ReadFull(reader, remaining)
if err != nil {
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("failed to read message: %w", err)}
return
}
// Validate headersLen to prevent slice out of bounds
if headersLen+4 > uint32(len(remaining)) {
log.Warnf("kiro: invalid headersLen %d exceeds remaining buffer %d", headersLen, len(remaining))
eventType := msg.EventType
payload := msg.Payload
if len(payload) == 0 {
continue
}
eventType := e.extractEventType(remaining[:headersLen+4])
payloadStart := 4 + headersLen
payloadEnd := uint32(len(remaining)) - 4
if payloadStart >= payloadEnd {
continue
}
payload := remaining[payloadStart:payloadEnd]
appendAPIResponseChunk(ctx, e.cfg, payload)
var event map[string]interface{}
@@ -2115,12 +2357,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
continue
}
// DIAGNOSTIC: Log all received event types for debugging
log.Debugf("kiro: streamToChannel received event type: %s", eventType)
if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("kiro: streamToChannel event payload: %s", string(payload))
}
// Check for error/exception events in the payload (Kiro API may return errors with HTTP 200)
// These can appear as top-level fields or nested within the event
if errType, hasErrType := event["_type"].(string); hasErrType {
@@ -2148,6 +2384,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
return
}
// Extract stop_reason from various event formats (streaming)
// Kiro/Amazon Q API may include stop_reason in different locations
if sr := getString(event, "stop_reason"); sr != "" {
upstreamStopReason = sr
log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason)
}
if sr := getString(event, "stopReason"); sr != "" {
upstreamStopReason = sr
log.Debugf("kiro: streamToChannel found stopReason (top-level): %s", upstreamStopReason)
}
// Send message_start on first event
if !messageStartSent {
msgStart := e.buildClaudeMessageStartEvent(model, totalUsage.InputTokens)
@@ -2166,6 +2413,17 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
log.Debugf("kiro: streamToChannel ignoring followupPrompt event")
continue
case "messageStopEvent", "message_stop":
// Handle message stop events which may contain stop_reason
if sr := getString(event, "stop_reason"); sr != "" {
upstreamStopReason = sr
log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason)
}
if sr := getString(event, "stopReason"); sr != "" {
upstreamStopReason = sr
log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason)
}
case "assistantResponseEvent":
var contentDelta string
var toolUses []map[string]interface{}
@@ -2174,6 +2432,15 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
if c, ok := assistantResp["content"].(string); ok {
contentDelta = c
}
// Extract stop_reason from assistantResponseEvent
if sr := getString(assistantResp, "stop_reason"); sr != "" {
upstreamStopReason = sr
log.Debugf("kiro: streamToChannel found stop_reason in assistantResponseEvent: %s", upstreamStopReason)
}
if sr := getString(assistantResp, "stopReason"); sr != "" {
upstreamStopReason = sr
log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason)
}
// Extract tool uses from response
if tus, ok := assistantResp["toolUses"].([]interface{}); ok {
for _, tuRaw := range tus {
@@ -2199,11 +2466,6 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
// Handle text content with thinking mode support
if contentDelta != "" {
// DIAGNOSTIC: Check for thinking tags in response
if strings.Contains(contentDelta, "<thinking>") || strings.Contains(contentDelta, "</thinking>") {
log.Infof("kiro: DIAGNOSTIC - Found thinking tag in response (len: %d)", len(contentDelta))
}
// NOTE: Duplicate content filtering was removed because it incorrectly
// filtered out legitimate repeated content (like consecutive newlines "\n\n").
// Streaming naturally can have identical chunks that are valid content.
@@ -2211,6 +2473,52 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
outputLen += len(contentDelta)
// Accumulate content for streaming token calculation
accumulatedContent.WriteString(contentDelta)
// Real-time usage estimation: Check if we should send a usage update
// This helps clients track context usage during long thinking sessions
shouldSendUsageUpdate := false
if accumulatedContent.Len()-lastUsageUpdateLen >= usageUpdateCharThreshold {
shouldSendUsageUpdate = true
} else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen {
shouldSendUsageUpdate = true
}
if shouldSendUsageUpdate {
// Calculate current output tokens using tiktoken
var currentOutputTokens int64
if enc, encErr := getTokenizer(model); encErr == nil {
if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil {
currentOutputTokens = int64(tokenCount)
}
}
// Fallback to character estimation if tiktoken fails
if currentOutputTokens == 0 {
currentOutputTokens = int64(accumulatedContent.Len() / 4)
if currentOutputTokens == 0 {
currentOutputTokens = 1
}
}
// Only send update if token count has changed significantly (at least 10 tokens)
if currentOutputTokens > lastReportedOutputTokens+10 {
// Send ping event with usage information
// This is a non-blocking update that clients can optionally process
pingEvent := e.buildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
for _, chunk := range sseData {
if chunk != "" {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
}
}
lastReportedOutputTokens = currentOutputTokens
log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)",
totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len())
}
lastUsageUpdateLen = accumulatedContent.Len()
lastUsageUpdateTime = time.Now()
}
// Process content with thinking tag detection - based on amq2api implementation
// This handles <thinking> and </thinking> tags that may span across chunks
@@ -2577,10 +2885,10 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
}
// Streaming token calculation - calculate output tokens from accumulated content
// This provides more accurate token counting than simple character division
// Only use local estimation if server didn't provide usage (server-side usage takes priority)
if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 {
// Try to use tiktoken for accurate counting
if enc, err := tokenizerForModel(model); err == nil {
if enc, err := getTokenizer(model); err == nil {
if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil {
totalUsage.OutputTokens = int64(tokenCount)
log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens)
@@ -2609,10 +2917,21 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
}
totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens
// Determine stop reason based on whether tool uses were emitted
stopReason := "end_turn"
if hasToolUses {
stopReason = "tool_use"
// Determine stop reason: prefer upstream, then detect tool_use, default to end_turn
stopReason := upstreamStopReason
if stopReason == "" {
if hasToolUses {
stopReason = "tool_use"
log.Debugf("kiro: streamToChannel using fallback stop_reason: tool_use")
} else {
stopReason = "end_turn"
log.Debugf("kiro: streamToChannel using fallback stop_reason: end_turn")
}
}
// Log warning if response was truncated due to max_tokens
if stopReason == "max_tokens" {
log.Warnf("kiro: response truncated due to max_tokens limit (streamToChannel)")
}
// Send message_delta event
@@ -2758,6 +3077,24 @@ func (e *KiroExecutor) buildClaudeFinalEvent() []byte {
return []byte("event: message_stop\ndata: " + string(result))
}
// buildClaudePingEventWithUsage creates a ping event with embedded usage
// information, used for real-time usage estimation during streaming.
// The usage field is a non-standard extension; clients that don't recognize
// it will simply ignore the extra data.
func (e *KiroExecutor) buildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte {
	usage := map[string]interface{}{
		"input_tokens":  inputTokens,
		"output_tokens": outputTokens,
		"total_tokens":  inputTokens + outputTokens,
		// Mark the figures as an in-flight estimate rather than a final count.
		"estimated": true,
	}
	payload, _ := json.Marshal(map[string]interface{}{
		"type":  "ping",
		"usage": usage,
	})
	return append([]byte("event: ping\ndata: "), payload...)
}
// buildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility.
// This is used when streaming thinking content wrapped in <thinking> tags.
func (e *KiroExecutor) buildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte {
@@ -2837,10 +3174,21 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
// Also check if expires_at is now in the future with sufficient buffer
if expiresAt, ok := auth.Metadata["expires_at"].(string); ok {
if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil {
// If token expires more than 2 minutes from now, it's still valid
if time.Until(expTime) > 2*time.Minute {
// If token expires more than 5 minutes from now, it's still valid
if time.Until(expTime) > 5*time.Minute {
log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime))
return auth, nil
// CRITICAL FIX: Set NextRefreshAfter to prevent frequent refresh checks
// Without this, shouldRefresh() will return true again in 5 seconds
updated := auth.Clone()
// Set next refresh to 5 minutes before expiry, or at least 30 seconds from now
nextRefresh := expTime.Add(-5 * time.Minute)
minNextRefresh := time.Now().Add(30 * time.Second)
if nextRefresh.Before(minNextRefresh) {
nextRefresh = minNextRefresh
}
updated.NextRefreshAfter = nextRefresh
log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh))
return updated, nil
}
}
}
@@ -2924,9 +3272,9 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
updated.Attributes["profile_arn"] = tokenData.ProfileArn
}
// Set next refresh time to 30 minutes before expiry
// NextRefreshAfter is aligned with RefreshLead (5min)
if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil {
updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
}
log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt)
@@ -2943,7 +3291,7 @@ func (e *KiroExecutor) streamEventStream(ctx context.Context, body io.Reader, c
var translatorParam any
// Pre-calculate input tokens from request if possible
if enc, err := tokenizerForModel(model); err == nil {
if enc, err := getTokenizer(model); err == nil {
// Try OpenAI format first, then fall back to raw byte count estimation
if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 {
totalUsage.InputTokens = inp

View File

@@ -137,15 +137,25 @@ func buildProxyTransport(proxyURL string) *http.Transport {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil
}
// Set up a custom transport using the SOCKS5 dialer
// Set up a custom transport using the SOCKS5 dialer with optimized connection pooling
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
MaxIdleConns: 100,
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
IdleConnTimeout: 90 * time.Second,
}
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
// Configure HTTP or HTTPS proxy with optimized connection pooling
transport = &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
IdleConnTimeout: 90 * time.Second,
}
} else {
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
return nil

View File

@@ -2,43 +2,107 @@ package executor
import (
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"github.com/tidwall/gjson"
"github.com/tiktoken-go/tokenizer"
)
// tokenizerCache memoizes TokenizerWrapper instances keyed by model name so
// repeated token counts don't rebuild the underlying codec.
var tokenizerCache sync.Map

// TokenizerWrapper pairs a tokenizer codec with a correction factor for
// models (such as Claude) whose real token counts tiktoken tends to
// underestimate.
type TokenizerWrapper struct {
	Codec            tokenizer.Codec
	AdjustmentFactor float64 // 1.0 = use raw count; >1.0 compensates for underestimation
}

// Count tokenizes text and scales the raw count by the adjustment factor
// when one is configured (positive and not exactly 1.0).
func (tw *TokenizerWrapper) Count(text string) (int, error) {
	raw, err := tw.Codec.Count(text)
	if err != nil {
		return 0, err
	}
	if f := tw.AdjustmentFactor; f > 0 && f != 1.0 {
		raw = int(float64(raw) * f)
	}
	return raw, nil
}
// getTokenizer returns a cached tokenizer for the given model, creating and
// memoizing one on first use. Caching avoids the cost of rebuilding a codec
// for every request.
func getTokenizer(model string) (*TokenizerWrapper, error) {
	if hit, ok := tokenizerCache.Load(model); ok {
		return hit.(*TokenizerWrapper), nil
	}
	wrapper, err := tokenizerForModel(model)
	if err != nil {
		return nil, err
	}
	// LoadOrStore keeps exactly one wrapper per model even when two
	// goroutines miss the cache at the same time.
	winner, _ := tokenizerCache.LoadOrStore(model, wrapper)
	return winner.(*TokenizerWrapper), nil
}
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
func tokenizerForModel(model string) (tokenizer.Codec, error) {
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
// Claude models use cl100k_base with 1.1 adjustment factor
// because tiktoken may underestimate Claude's actual token count
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
}
var enc tokenizer.Codec
var err error
switch {
case sanitized == "":
return tokenizer.Get(tokenizer.Cl100kBase)
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5"):
return tokenizer.ForModel(tokenizer.GPT5)
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5.1"):
return tokenizer.ForModel(tokenizer.GPT5)
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-4.1"):
return tokenizer.ForModel(tokenizer.GPT41)
enc, err = tokenizer.ForModel(tokenizer.GPT41)
case strings.HasPrefix(sanitized, "gpt-4o"):
return tokenizer.ForModel(tokenizer.GPT4o)
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
case strings.HasPrefix(sanitized, "gpt-4"):
return tokenizer.ForModel(tokenizer.GPT4)
enc, err = tokenizer.ForModel(tokenizer.GPT4)
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
return tokenizer.ForModel(tokenizer.GPT35Turbo)
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
case strings.HasPrefix(sanitized, "o1"):
return tokenizer.ForModel(tokenizer.O1)
enc, err = tokenizer.ForModel(tokenizer.O1)
case strings.HasPrefix(sanitized, "o3"):
return tokenizer.ForModel(tokenizer.O3)
enc, err = tokenizer.ForModel(tokenizer.O3)
case strings.HasPrefix(sanitized, "o4"):
return tokenizer.ForModel(tokenizer.O4Mini)
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
default:
return tokenizer.Get(tokenizer.O200kBase)
enc, err = tokenizer.Get(tokenizer.O200kBase)
}
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
}
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
@@ -62,11 +126,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return 0, nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
return int64(count), nil
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
}
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
// This handles Claude's message format with system, messages, and tools.
// Image tokens are estimated based on image dimensions when available.
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
if len(payload) == 0 {
return 0, nil
}
root := gjson.ParseBytes(payload)
segments := make([]string, 0, 32)
// Collect system prompt (can be string or array of content blocks)
collectClaudeSystem(root.Get("system"), &segments)
// Collect messages
collectClaudeMessages(root.Get("messages"), &segments)
// Collect tools
collectClaudeTools(root.Get("tools"), &segments)
joined := strings.TrimSpace(strings.Join(segments, "\n"))
if joined == "" {
return 0, nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
}
// imageTokenPattern matches "[IMAGE:xxx tokens]" placeholders so image token
// estimates embedded in collected text can be recovered.
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)

// extractImageTokens sums the token estimates carried by [IMAGE:n tokens]
// placeholders in text, returning 0 when none are present or parseable.
func extractImageTokens(text string) int {
	total := 0
	for _, m := range imageTokenPattern.FindAllStringSubmatch(text, -1) {
		if len(m) < 2 {
			continue
		}
		n, err := strconv.Atoi(m[1])
		if err != nil {
			// Regex guarantees digits, but Atoi can still overflow; skip.
			continue
		}
		total += n
	}
	return total
}
// estimateImageTokens calculates estimated tokens for an image based on its
// pixel dimensions, following Claude's rule of thumb:
//
//	tokens ≈ (width * height) / 750
//
// The result is clamped to [85, 1590]. (The previous comment attributed the
// 1590 cap to "1568x1568 images", but 1568*1568/750 ≈ 3278; 1590 actually
// corresponds to the ~1.15-megapixel budget Claude applies after downscaling
// oversized images.) When dimensions are unknown (zero or negative), a
// mid-sized default of 1000 tokens is returned.
func estimateImageTokens(width, height float64) int {
	if width <= 0 || height <= 0 {
		// No usable dimensions: assume a medium-sized image.
		return 1000
	}
	tokens := int(width * height / 750)
	// Clamp to Claude's practical per-image token bounds.
	if tokens < 85 {
		return 85
	}
	if tokens > 1590 {
		return 1590
	}
	return tokens
}
// collectClaudeSystem extracts text from Claude's "system" field, which may
// be a plain string or an array of content blocks, appending each non-empty
// piece to segments.
func collectClaudeSystem(system gjson.Result, segments *[]string) {
	if !system.Exists() {
		return
	}
	switch {
	case system.Type == gjson.String:
		addIfNotEmpty(segments, system.String())
	case system.IsArray():
		system.ForEach(func(_, block gjson.Result) bool {
			// Blocks with type "text" (or no type) contribute their text.
			if t := block.Get("type").String(); t == "text" || t == "" {
				addIfNotEmpty(segments, block.Get("text").String())
			}
			// Arrays may also contain bare strings instead of block objects.
			if block.Type == gjson.String {
				addIfNotEmpty(segments, block.String())
			}
			return true
		})
	}
}
// collectClaudeMessages extracts text from Claude's messages array.
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
if !messages.Exists() || !messages.IsArray() {
return
}
messages.ForEach(func(_, message gjson.Result) bool {
addIfNotEmpty(segments, message.Get("role").String())
collectClaudeContent(message.Get("content"), segments)
return true
})
}
// collectClaudeContent extracts text from Claude's "content" field and
// appends each non-empty piece to segments. Content can be a plain string
// or an array of typed content blocks (text, image, tool_use, tool_result,
// thinking). Image blocks contribute an "[IMAGE:n tokens]" placeholder whose
// estimate is later recovered by extractImageTokens; tool_result blocks are
// processed recursively since their content has the same shape.
func collectClaudeContent(content gjson.Result, segments *[]string) {
	if !content.Exists() {
		return
	}
	// Simple case: content is a bare string.
	if content.Type == gjson.String {
		addIfNotEmpty(segments, content.String())
		return
	}
	if content.IsArray() {
		content.ForEach(func(_, part gjson.Result) bool {
			partType := part.Get("type").String()
			switch partType {
			case "text":
				addIfNotEmpty(segments, part.Get("text").String())
			case "image":
				// Estimate image tokens based on dimensions if available.
				source := part.Get("source")
				if source.Exists() {
					width := source.Get("width").Float()
					height := source.Get("height").Float()
					if width > 0 && height > 0 {
						tokens := estimateImageTokens(width, height)
						addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
					} else {
						// Source present but no usable dimensions: default estimate.
						addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
					}
				} else {
					// No source info at all: default estimate.
					addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
				}
			case "tool_use":
				// Count the id, name, and raw input JSON of the tool call.
				addIfNotEmpty(segments, part.Get("id").String())
				addIfNotEmpty(segments, part.Get("name").String())
				if input := part.Get("input"); input.Exists() {
					addIfNotEmpty(segments, input.Raw)
				}
			case "tool_result":
				addIfNotEmpty(segments, part.Get("tool_use_id").String())
				// Recurse: tool_result content has the same string-or-blocks shape.
				collectClaudeContent(part.Get("content"), segments)
			case "thinking":
				addIfNotEmpty(segments, part.Get("thinking").String())
			default:
				// Unknown block types: count a bare string directly, or the raw
				// JSON of an object so its text still contributes to the estimate.
				if part.Type == gjson.String {
					addIfNotEmpty(segments, part.String())
				} else if part.Type == gjson.JSON {
					addIfNotEmpty(segments, part.Raw)
				}
			}
			return true
		})
	}
}
// collectClaudeTools walks Claude's "tools" array, appending each tool's
// name, description, and raw input schema JSON to segments.
func collectClaudeTools(tools gjson.Result, segments *[]string) {
	if !tools.Exists() || !tools.IsArray() {
		return
	}
	tools.ForEach(func(_, tool gjson.Result) bool {
		addIfNotEmpty(segments, tool.Get("name").String())
		addIfNotEmpty(segments, tool.Get("description").String())
		// The schema's raw JSON text contributes to the token estimate.
		schema := tool.Get("input_schema")
		if schema.Exists() {
			addIfNotEmpty(segments, schema.Raw)
		}
		return true
	})
}
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.

View File

@@ -8,6 +8,7 @@ import (
"net"
"net/http"
"net/url"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus"
@@ -36,15 +37,25 @@ func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return httpClient
}
// Set up a custom transport using the SOCKS5 dialer.
// Set up a custom transport using the SOCKS5 dialer with optimized connection pooling
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
MaxIdleConns: 100,
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
IdleConnTimeout: 90 * time.Second,
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
// Configure HTTP or HTTPS proxy with optimized connection pooling
transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users
MaxConnsPerHost: 0, // No limit on max concurrent connections per host
IdleConnTimeout: 90 * time.Second,
}
}
}
// If a new transport was created, apply it to the HTTP client.

View File

@@ -47,8 +47,9 @@ func (a *KiroAuthenticator) Provider() string {
}
// RefreshLead indicates how soon before expiry a refresh should be attempted.
// Five minutes matches Antigravity's lead time, avoiding overly frequent
// refresh checks while still refreshing tokens well before expiry.
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
	lead := 5 * time.Minute
	return &lead
}
@@ -103,7 +104,8 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
"source": "aws-builder-id",
"email": tokenData.Email,
},
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
// NextRefreshAfter is aligned with RefreshLead (5min)
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
}
if tokenData.Email != "" {
@@ -165,7 +167,8 @@ func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Con
"source": "google-oauth",
"email": tokenData.Email,
},
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
// NextRefreshAfter is aligned with RefreshLead (5min)
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
}
if tokenData.Email != "" {
@@ -227,7 +230,8 @@ func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Con
"source": "github-oauth",
"email": tokenData.Email,
},
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
// NextRefreshAfter is aligned with RefreshLead (5min)
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
}
if tokenData.Email != "" {
@@ -291,7 +295,8 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C
"source": "kiro-ide-import",
"email": tokenData.Email,
},
NextRefreshAfter: expiresAt.Add(-30 * time.Minute),
// NextRefreshAfter is aligned with RefreshLead (5min)
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
}
// Display the email if extracted
@@ -351,7 +356,8 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
updated.Metadata["refresh_token"] = tokenData.RefreshToken
updated.Metadata["expires_at"] = tokenData.ExpiresAt
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
updated.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
// NextRefreshAfter is aligned with RefreshLead (5min)
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
return updated, nil
}

View File

@@ -40,7 +40,7 @@ type RefreshEvaluator interface {
const (
refreshCheckInterval = 5 * time.Second
refreshPendingBackoff = time.Minute
refreshFailureBackoff = 5 * time.Minute
refreshFailureBackoff = 1 * time.Minute
quotaBackoffBase = time.Second
quotaBackoffMax = 30 * time.Minute
)
@@ -1471,7 +1471,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
updated.Runtime = auth.Runtime
}
updated.LastRefreshedAt = now
updated.NextRefreshAfter = time.Time{}
// Preserve NextRefreshAfter set by the Authenticator
// If the Authenticator set a reasonable refresh time, it should not be overwritten
// If the Authenticator did not set it (zero value), shouldRefresh will use default logic
updated.LastError = nil
updated.UpdatedAt = now
_, _ = m.Update(ctx, updated)