From 9058d406a35d691b97324c6200839bb776d65916 Mon Sep 17 00:00:00 2001 From: evann Date: Fri, 19 Dec 2025 16:33:41 +0700 Subject: [PATCH 1/9] feat(antigravity): enhance prompt caching support and update agent version --- internal/runtime/executor/antigravity_executor.go | 13 ++++--------- .../claude/antigravity_claude_response.go | 11 +++++++++++ .../chat-completions/antigravity_openai_response.go | 5 +++++ .../chat-completions/gemini_openai_response.go | 10 ++++++++++ 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 8b4e37ee..4ccc65df 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -39,7 +39,7 @@ const ( antigravityModelsPath = "/v1internal:fetchAvailableModels" antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityClientSecret = "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" - defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" + defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64" antigravityAuthType = "antigravity" refreshSkew = 3000 * time.Second ) @@ -1145,10 +1145,11 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string { if base := resolveCustomAntigravityBaseURL(auth); base != "" { return []string{base} } + // Production endpoint first (matches antigravity.js plugin behavior) + // Production may have better caching support return []string{ - antigravityBaseURLDaily, - // antigravityBaseURLAutopush, antigravityBaseURLProd, + antigravityBaseURLDaily, } } @@ -1183,7 +1184,6 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b template, _ = sjson.Set(template, "project", generateProjectID()) } template, _ = sjson.Set(template, "requestId", generateRequestID()) - template, _ = sjson.Set(template, "request.sessionId", generateSessionID()) template, _ = sjson.Delete(template, "request.safetySettings") template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") @@ -1218,11 +1218,6 @@ func generateRequestID() string { return "agent-" + uuid.NewString() } -func generateSessionID() string { - n := randSource.Int63n(9_000_000_000_000_000_000) - return "-" + strconv.FormatInt(n, 10) -} - func generateProjectID() string { adjectives := []string{"useful", "bright", "swift", "calm", "bold"} nouns := []string{"fuze", "wave", "spark", "flow", "core"} diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 52fc358e..30d0b164 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -32,6 +32,7 @@ type Params struct { CandidatesTokenCount int64 // Cached candidate token count from usage metadata ThoughtsTokenCount int64 // Cached thinking token count from usage metadata TotalTokenCount int64 // Cached total token count from usage metadata + CachedTokenCount int64 // Cached content token count (indicates prompt caching) HasSentFinalEvents bool // Indicates if final content/message events have been sent HasToolUse bool // Indicates if tool use was observed in the stream HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output @@ -254,6 +255,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int() params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int() params.TotalTokenCount = usageResult.Get("totalTokenCount").Int() + params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int() if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 { params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount if params.CandidatesTokenCount < 0 { @@ -302,6 +304,10 @@ func appendFinalEvents(params *Params, output *string, force bool) { *output = *output + "event: message_delta\n" *output = *output + "data: " delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens) + // Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working) + if params.CachedTokenCount > 0 { + delta, _ = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount) + } *output = *output + delta + "\n\n\n" params.HasSentFinalEvents = true @@ -341,6 +347,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int() totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int() + cachedTokens := root.Get("response.usageMetadata.cachedContentTokenCount").Int() outputTokens := candidateTokens + thoughtTokens if outputTokens == 0 && totalTokens > 0 { outputTokens = totalTokens - promptTokens @@ -354,6 +361,10 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String()) responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens) responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens) + // Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working) + if cachedTokens > 0 { + responseJSON, _ = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens) + } contentArrayInitialized := false ensureContentArray := func() { diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go index 24694e1d..59a08621 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go @@ -94,10 +94,15 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq } promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } + // Include cached token count if present (indicates prompt caching is working) + if cachedTokenCount > 0 { + template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + } } // Process the main content part of the response. diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go index a1ebc855..e0ce4636 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go @@ -97,10 +97,15 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR } promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } + // Include cached token count if present (indicates prompt caching is working) + if cachedTokenCount > 0 { + template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + } } // Process the main content part of the response. @@ -248,10 +253,15 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina } promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } + // Include cached token count if present (indicates prompt caching is working) + if cachedTokenCount > 0 { + template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + } } // Process the main content part of the response. From 404546ce9346c24e3cfb4f4a317b66be10991441 Mon Sep 17 00:00:00 2001 From: evann Date: Fri, 19 Dec 2025 16:36:54 +0700 Subject: [PATCH 2/9] refactor(antigravity): regarding production endpoint caching --- internal/runtime/executor/antigravity_executor.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 4ccc65df..1aaf7ba0 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -1145,8 +1145,6 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string { if base := resolveCustomAntigravityBaseURL(auth); base != "" { return []string{base} } - // Production endpoint first (matches antigravity.js plugin behavior) - // Production may have better caching support return []string{ antigravityBaseURLProd, antigravityBaseURLDaily, From bc6c4cdbfc68cecbf426742365f608af9be2c7d2 Mon Sep 17 00:00:00 2001 From: evann Date: Fri, 19 Dec 2025 16:49:50 +0700 Subject: [PATCH 3/9] feat(antigravity): add logging for cached token setting errors in responses --- .../claude/antigravity_claude_response.go | 13 +++++++++++-- .../chat-completions/antigravity_openai_response.go | 8 +++++++- .../chat-completions/gemini_openai_response.go | 13 +++++++++++-- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 30d0b164..bb06eba9 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -306,7 +307,11 @@ func appendFinalEvents(params *Params, output *string, force bool) { delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens) // Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working) if params.CachedTokenCount > 0 { - delta, _ = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount) + var err error + delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount) + if err != nil { + log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err) + } } *output = *output + delta + "\n\n\n" @@ -363,7 +368,11 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens) // Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working) if cachedTokens > 0 { - responseJSON, _ = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens) + var err error + responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens) + if err != nil { + log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err) + } } contentArrayInitialized := false diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go index 59a08621..f9f5dea4 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go @@ -14,6 +14,8 @@ import ( "sync/atomic" "time" + log "github.com/sirupsen/logrus" + . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -101,7 +103,11 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq } // Include cached token count if present (indicates prompt caching is working) if cachedTokenCount > 0 { - template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + var err error + template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + if err != nil { + log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err) + } } } diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go index e0ce4636..b2a44e9e 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -104,7 +105,11 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR } // Include cached token count if present (indicates prompt caching is working) if cachedTokenCount > 0 { - template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + var err error + template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + if err != nil { + log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err) + } } } @@ -260,7 +265,11 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina } // Include cached token count if present (indicates prompt caching is working) if cachedTokenCount > 0 { - template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + var err error + template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + if err != nil { + log.Warnf("gemini openai response: failed to set cached_tokens in non-streaming: %v", err) + } } } From a87f09bad21ce33c6b38c93028f66d99458116e5 Mon Sep 17 00:00:00 2001 From: Evan Nguyen Date: Sun, 21 Dec 2025 17:50:41 +0700 Subject: [PATCH 4/9] feat(antigravity): add session ID generation and mutex for random source --- .../runtime/executor/antigravity_executor.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 1aaf7ba0..afc6d24a 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -15,6 +15,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "github.com/google/uuid" @@ -44,7 +45,10 @@ const ( refreshSkew = 3000 * time.Second ) -var randSource = rand.New(rand.NewSource(time.Now().UnixNano())) +var ( + randSource = rand.New(rand.NewSource(time.Now().UnixNano())) + randSourceMutex sync.Mutex +) // AntigravityExecutor proxies requests to the antigravity upstream. type AntigravityExecutor struct { @@ -1146,8 +1150,9 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string { return []string{base} } return []string{ - antigravityBaseURLProd, antigravityBaseURLDaily, + // antigravityBaseURLAutopush, + antigravityBaseURLProd, } } @@ -1182,6 +1187,7 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b template, _ = sjson.Set(template, "project", generateProjectID()) } template, _ = sjson.Set(template, "requestId", generateRequestID()) + template, _ = sjson.Set(template, "request.sessionId", generateSessionID()) template, _ = sjson.Delete(template, "request.safetySettings") template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") @@ -1216,11 +1222,20 @@ func generateRequestID() string { return "agent-" + uuid.NewString() } +func generateSessionID() string { + randSourceMutex.Lock() + n := randSource.Int63n(9_000_000_000_000_000_000) + randSourceMutex.Unlock() + return "-" + strconv.FormatInt(n, 10) +} + func generateProjectID() string { adjectives := []string{"useful", "bright", "swift", "calm", "bold"} nouns := []string{"fuze", "wave", "spark", "flow", "core"} + randSourceMutex.Lock() adj := adjectives[randSource.Intn(len(adjectives))] noun := nouns[randSource.Intn(len(nouns))] + randSourceMutex.Unlock() randomPart := strings.ToLower(uuid.NewString())[:5] return adj + "-" + noun + "-" + randomPart } From 71a6dffbb6299c92646ed371becb16514c50bce6 Mon Sep 17 00:00:00 2001 From: gwizz Date: Mon, 22 Dec 2025 17:21:29 +1100 Subject: [PATCH 5/9] fix: improve streaming bootstrap and forwarding --- internal/config/sdk_config.go | 15 ++ sdk/api/handlers/claude/code_handlers.go | 111 ++++---- .../handlers/gemini/gemini-cli_handlers.go | 54 ++-- sdk/api/handlers/gemini/gemini_handlers.go | 112 +++++--- sdk/api/handlers/handlers.go | 245 +++++++++++++----- .../handlers_stream_bootstrap_test.go | 120 +++++++++ sdk/api/handlers/openai/openai_handlers.go | 202 ++++++++++----- .../openai/openai_responses_handlers.go | 102 +++++--- sdk/api/handlers/stream_forwarder.go | 121 +++++++++ sdk/config/config.go | 1 + 10 files changed, 804 insertions(+), 279 deletions(-) create mode 100644 sdk/api/handlers/handlers_stream_bootstrap_test.go create mode 100644 sdk/api/handlers/stream_forwarder.go diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go index f6f20d5c..7f019520 100644 --- a/internal/config/sdk_config.go +++ b/internal/config/sdk_config.go @@ -22,6 +22,21 @@ type SDKConfig struct { // Access holds request authentication provider configuration. Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"` + + // Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries). + Streaming StreamingConfig `yaml:"streaming" json:"streaming"` +} + +// StreamingConfig holds server streaming behavior configuration. +type StreamingConfig struct { + // KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n"). + // nil means default (15 seconds). <= 0 disables keep-alives. + KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"` + + // BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent, + // to allow auth rotation / transient recovery. + // nil means default (2). 0 disables bootstrap retries. + BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"` } // AccessConfig groups request authentication providers. diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 8a4c4806..bdf7c9c7 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -14,7 +14,6 @@ import ( "fmt" "io" "net/http" - "time" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -185,14 +184,6 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO // - c: The Gin context for the request. // - rawJSON: The raw JSON request body. func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - // Set up Server-Sent Events (SSE) headers for streaming response - // These headers are essential for maintaining a persistent connection - // and enabling real-time streaming of chat completions - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. // This is crucial for streaming as it allows immediate sending of data chunks flusher, ok := c.Writer.(http.Flusher) @@ -213,56 +204,72 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") - h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) - return + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk to determine success or failure before setting headers + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send DONE or just headers. + setSSEHeaders() + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers now. + setSSEHeaders() + + // Write the first chunk + if len(chunk) > 0 { + _, _ = c.Writer.Write(chunk) + flusher.Flush() + } + + // Continue streaming the rest + h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + } } func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - // OpenAI-style stream forwarding: write each SSE chunk and flush immediately. - // This guarantees clients see incremental output even for small responses. - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - - case chunk, ok := <-data: - if !ok { - flusher.Flush() - cancel(nil) + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + WriteChunk: func(chunk []byte) { + if len(chunk) == 0 { return } - if len(chunk) > 0 { - _, _ = c.Writer.Write(chunk) - flusher.Flush() + _, _ = c.Writer.Write(chunk) + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode + } + c.Status(status) - case errMsg, ok := <-errs: - if !ok { - continue - } - if errMsg != nil { - status := http.StatusInternalServerError - if errMsg.StatusCode > 0 { - status = errMsg.StatusCode - } - c.Status(status) - - // An error occurred: emit as a proper SSE error event - errorBytes, _ := json.Marshal(h.toClaudeError(errMsg)) - _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes) - flusher.Flush() - } - - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + errorBytes, _ := json.Marshal(h.toClaudeError(errMsg)) + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes) + }, + }) } type claudeErrorDetail struct { diff --git a/sdk/api/handlers/gemini/gemini-cli_handlers.go b/sdk/api/handlers/gemini/gemini-cli_handlers.go index 5224faf8..ea78657d 100644 --- a/sdk/api/handlers/gemini/gemini-cli_handlers.go +++ b/sdk/api/handlers/gemini/gemini-cli_handlers.go @@ -182,19 +182,18 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ } func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - cancel(nil) - return - } + var keepAliveInterval *time.Duration + if alt != "" { + disabled := time.Duration(0) + keepAliveInterval = &disabled + } + + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + KeepAliveInterval: keepAliveInterval, + WriteChunk: func(chunk []byte) { if alt == "" { if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) { - continue + return } if !bytes.HasPrefix(chunk, []byte("data:")) { @@ -206,22 +205,25 @@ func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flus } else { _, _ = c.Writer.Write(chunk) } - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - var execErr error - if errMsg != nil { - execErr = errMsg.Error + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + if alt == "" { + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) + } else { + _, _ = c.Writer.Write(body) + } + }, + }) } diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index 901421b5..baf68aac 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -226,13 +226,6 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) { func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) { alt := h.GetAlt(c) - if alt == "" { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - } - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -247,8 +240,57 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) - h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan) - return + + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Closed without data + if alt == "" { + setSSEHeaders() + } + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers. + if alt == "" { + setSSEHeaders() + } + + // Write first chunk + if alt == "" { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + } else { + _, _ = c.Writer.Write(chunk) + } + flusher.Flush() + + // Continue + h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan) + } } // handleCountTokens handles token counting requests for Gemini models. @@ -297,16 +339,15 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin } func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - cancel(nil) - return - } + var keepAliveInterval *time.Duration + if alt != "" { + disabled := time.Duration(0) + keepAliveInterval = &disabled + } + + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + KeepAliveInterval: keepAliveInterval, + WriteChunk: func(chunk []byte) { if alt == "" { _, _ = c.Writer.Write([]byte("data: ")) _, _ = c.Writer.Write(chunk) @@ -314,22 +355,25 @@ func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flus } else { _, _ = c.Writer.Write(chunk) } - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - var execErr error - if errMsg != nil { - execErr = errMsg.Error + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + if alt == "" { + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) + } else { + _, _ = c.Writer.Write(body) + } + }, + }) } diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index e5b4fc93..5d33fe0e 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -9,8 +9,10 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" @@ -40,6 +42,115 @@ type ErrorDetail struct { Code string `json:"code,omitempty"` } +const idempotencyKeyMetadataKey = "idempotency_key" + +const ( + defaultStreamingKeepAliveSeconds = 15 + defaultStreamingBootstrapRetries = 2 +) + +// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body. +// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads. +func BuildErrorResponseBody(status int, errText string) []byte { + if status <= 0 { + status = http.StatusInternalServerError + } + if strings.TrimSpace(errText) == "" { + errText = http.StatusText(status) + } + + trimmed := strings.TrimSpace(errText) + if trimmed != "" && json.Valid([]byte(trimmed)) { + return []byte(trimmed) + } + + errType := "invalid_request_error" + var code string + switch status { + case http.StatusUnauthorized: + errType = "authentication_error" + code = "invalid_api_key" + case http.StatusForbidden: + errType = "permission_error" + code = "insufficient_quota" + case http.StatusTooManyRequests: + errType = "rate_limit_error" + code = "rate_limit_exceeded" + case http.StatusNotFound: + errType = "invalid_request_error" + code = "model_not_found" + default: + if status >= http.StatusInternalServerError { + errType = "server_error" + code = "internal_server_error" + } + } + + payload, err := json.Marshal(ErrorResponse{ + Error: ErrorDetail{ + Message: errText, + Type: errType, + Code: code, + }, + }) + if err != nil { + return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText)) + } + return payload +} + +// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server. +// Returning 0 disables keep-alives. +func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration { + seconds := defaultStreamingKeepAliveSeconds + if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil { + seconds = *cfg.Streaming.KeepAliveSeconds + } + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + +// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent. +func StreamingBootstrapRetries(cfg *config.SDKConfig) int { + retries := defaultStreamingBootstrapRetries + if cfg != nil && cfg.Streaming.BootstrapRetries != nil { + retries = *cfg.Streaming.BootstrapRetries + } + if retries < 0 { + retries = 0 + } + return retries +} + +func requestExecutionMetadata(ctx context.Context) map[string]any { + key := "" + if ctx != nil { + if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { + key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key")) + } + } + if key == "" { + key = uuid.NewString() + } + return map[string]any{idempotencyKeyMetadataKey: key} +} + +func mergeMetadata(base, overlay map[string]any) map[string]any { + if len(base) == 0 && len(overlay) == 0 { + return nil + } + out := make(map[string]any, len(base)+len(overlay)) + for k, v := range base { + out[k] = v + } + for k, v := range overlay { + out[k] = v + } + return out +} + // BaseAPIHandler contains the handlers for API endpoints. // It holds a pool of clients to interact with the backend service and manages // load balancing, client selection, and configuration. @@ -182,6 +293,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType if errMsg != nil { return nil, errMsg } + reqMeta := requestExecutionMetadata(ctx) req := coreexecutor.Request{ Model: normalizedModel, Payload: cloneBytes(rawJSON), @@ -195,9 +307,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } - if cloned := cloneMetadata(metadata); cloned != nil { - opts.Metadata = cloned - } + opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta) resp, err := h.AuthManager.Execute(ctx, providers, req, opts) if err != nil { status := http.StatusInternalServerError @@ -224,6 +334,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle if errMsg != nil { return nil, errMsg } + reqMeta := requestExecutionMetadata(ctx) req := coreexecutor.Request{ Model: normalizedModel, Payload: cloneBytes(rawJSON), @@ -237,9 +348,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } - if cloned := cloneMetadata(metadata); cloned != nil { - opts.Metadata = cloned - } + opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta) resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) if err != nil { status := http.StatusInternalServerError @@ -269,6 +378,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl close(errChan) return nil, errChan } + reqMeta := requestExecutionMetadata(ctx) req := coreexecutor.Request{ Model: normalizedModel, Payload: cloneBytes(rawJSON), @@ -282,9 +392,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } - if cloned := cloneMetadata(metadata); cloned != nil { - opts.Metadata = cloned - } + opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta) chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { errChan := make(chan *interfaces.ErrorMessage, 1) @@ -309,31 +417,81 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl go func() { defer close(dataChan) defer close(errChan) - for chunk := range chunks { - if chunk.Err != nil { - status := http.StatusInternalServerError - if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil { - if code := se.StatusCode(); code > 0 { - status = code - } - } - var addon http.Header - if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil { - if hdr := he.Headers(); hdr != nil { - addon = hdr.Clone() - } - } - errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon} - return + sentPayload := false + bootstrapRetries := 0 + maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg) + + bootstrapEligible := func(err error) bool { + status := statusFromError(err) + if status == 0 { + return true } - if len(chunk.Payload) > 0 { - dataChan <- cloneBytes(chunk.Payload) + switch status { + case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired, + http.StatusRequestTimeout, http.StatusTooManyRequests: + return true + default: + return status >= http.StatusInternalServerError } } + + outer: + for { + for chunk := range chunks { + if chunk.Err != nil { + streamErr := chunk.Err + // Safe bootstrap recovery: if the upstream fails before any payload bytes are sent, + // retry a few times (to allow auth rotation / transient recovery) and then attempt model fallback. + if !sentPayload { + if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) { + bootstrapRetries++ + retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + if retryErr == nil { + chunks = retryChunks + continue outer + } + streamErr = retryErr + } + } + + status := http.StatusInternalServerError + if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + var addon http.Header + if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } + } + errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon} + return + } + if len(chunk.Payload) > 0 { + sentPayload = true + dataChan <- cloneBytes(chunk.Payload) + } + } + return + } }() return dataChan, errChan } +func statusFromError(err error) int { + if err == nil { + return 0 + } + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + return code + } + } + return 0 +} + func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) { // Resolve "auto" model to an actual available model first resolvedModelName := util.ResolveAutoModel(modelName) @@ -417,38 +575,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro } } - // Prefer preserving upstream JSON error bodies when possible. - buildJSONBody := func() []byte { - trimmed := strings.TrimSpace(errText) - if trimmed != "" && json.Valid([]byte(trimmed)) { - return []byte(trimmed) - } - errType := "invalid_request_error" - switch status { - case http.StatusUnauthorized: - errType = "authentication_error" - case http.StatusForbidden: - errType = "permission_error" - case http.StatusTooManyRequests: - errType = "rate_limit_error" - default: - if status >= http.StatusInternalServerError { - errType = "server_error" - } - } - payload, err := json.Marshal(ErrorResponse{ - Error: ErrorDetail{ - Message: errText, - Type: errType, - }, - }) - if err != nil { - return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText)) - } - return payload - } - - body := buildJSONBody() + body := BuildErrorResponseBody(status, errText) c.Set("API_RESPONSE", bytes.Clone(body)) if !c.Writer.Written() { diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go new file mode 100644 index 00000000..cd2fdf4d --- /dev/null +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -0,0 +1,120 @@ +package handlers + +import ( + "context" + "net/http" + "sync" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +type failOnceStreamExecutor struct { + mu sync.Mutex + calls int +} + +func (e *failOnceStreamExecutor) Identifier() string { return "codex" } + +func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} +} + +func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { + e.mu.Lock() + e.calls++ + call := e.calls + e.mu.Unlock() + + ch := make(chan coreexecutor.StreamChunk, 1) + if call == 1 { + ch <- coreexecutor.StreamChunk{ + Err: &coreauth.Error{ + Code: "unauthorized", + Message: "unauthorized", + Retryable: false, + HTTPStatus: http.StatusUnauthorized, + }, + } + close(ch) + return ch, nil + } + + ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} + close(ch) + return ch, nil +} + +func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} +} + +func (e *failOnceStreamExecutor) Calls() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.calls +} + +func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { + executor := &failOnceStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth1", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + auth2 := &coreauth.Auth{ + ID: "auth2", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test2@example.com"}, + } + if _, err := manager.Register(context.Background(), auth2); err != nil { + t.Fatalf("manager.Register(auth2): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + registry.GetGlobalRegistry().UnregisterClient(auth2.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager, nil) + dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + + for msg := range errChan { + if msg != nil { + t.Fatalf("unexpected error: %+v", msg) + } + } + + if string(got) != "ok" { + t.Fatalf("expected payload ok, got %q", string(got)) + } + if executor.Calls() != 2 { + t.Fatalf("expected 2 stream attempts, got %d", executor.Calls()) + } +} diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go index ae925f91..d5962ea7 100644 --- a/sdk/api/handlers/openai/openai_handlers.go +++ b/sdk/api/handlers/openai/openai_handlers.go @@ -11,7 +11,7 @@ import ( "encoding/json" "fmt" "net/http" - "time" + "sync" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -443,11 +443,6 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON [] // - c: The Gin context containing the HTTP request and response // - rawJSON: The raw JSON bytes of the OpenAI-compatible request func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -463,7 +458,47 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) - h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk to determine success or failure before setting headers + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send DONE or just headers. + setSSEHeaders() + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Commit to streaming headers. + setSSEHeaders() + + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) + flusher.Flush() + + // Continue streaming the rest + h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + } } // handleCompletionsNonStreamingResponse handles non-streaming completions responses. @@ -500,11 +535,6 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, // - c: The Gin context containing the HTTP request and response // - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -524,71 +554,101 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") - for { - select { - case <-c.Request.Context().Done(): - cliCancel(c.Request.Context().Err()) - return - case chunk, isOk := <-dataChan: - if !isOk { - _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") - flusher.Flush() - cliCancel() - return - } - converted := convertChatCompletionsStreamChunkToCompletions(chunk) - if converted != nil { - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted)) - flusher.Flush() - } - case errMsg, isOk := <-errChan: - if !isOk { - continue - } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() - } - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cliCancel(execErr) - return - case <-time.After(500 * time.Millisecond): + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) } + return + case chunk, ok := <-dataChan: + if !ok { + setSSEHeaders() + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers. + setSSEHeaders() + + // Write the first chunk + converted := convertChatCompletionsStreamChunkToCompletions(chunk) + if converted != nil { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted)) + flusher.Flush() + } + + done := make(chan struct{}) + var doneOnce sync.Once + stop := func() { doneOnce.Do(func() { close(done) }) } + + convertedChan := make(chan []byte) + go func() { + defer close(convertedChan) + for { + select { + case <-done: + return + case chunk, ok := <-dataChan: + if !ok { + return + } + converted := convertChatCompletionsStreamChunkToCompletions(chunk) + if converted == nil { + continue + } + select { + case <-done: + return + case convertedChan <- converted: + } + } + } + }() + + h.handleStreamResult(c, flusher, func(err error) { + stop() + cliCancel(err) + }, convertedChan, errChan) } } func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") - flusher.Flush() - cancel(nil) + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + WriteChunk: func(chunk []byte) { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { return } - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body)) + }, + WriteDone: func() { + _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") + }, + }) } diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index ace02313..dd63deeb 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -11,7 +11,6 @@ import ( "context" "fmt" "net/http" - "time" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -128,11 +127,6 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r // - c: The Gin context containing the HTTP request and response // - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -149,46 +143,80 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") - h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) - return + + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send headers and done. + setSSEHeaders() + _, _ = c.Writer.Write([]byte("\n")) + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers. + setSSEHeaders() + + // Write first chunk logic (matching forwardResponsesStream) + if bytes.HasPrefix(chunk, []byte("event:")) { + _, _ = c.Writer.Write([]byte("\n")) + } + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n")) + flusher.Flush() + + // Continue + h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + } } func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - _, _ = c.Writer.Write([]byte("\n")) - flusher.Flush() - cancel(nil) - return - } - + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + WriteChunk: func(chunk []byte) { if bytes.HasPrefix(chunk, []byte("event:")) { _, _ = c.Writer.Write([]byte("\n")) } _, _ = c.Writer.Write(chunk) _, _ = c.Writer.Write([]byte("\n")) - - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - var execErr error - if errMsg != nil { - execErr = errMsg.Error + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + _, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body)) + }, + WriteDone: func() { + _, _ = c.Writer.Write([]byte("\n")) + }, + }) } diff --git a/sdk/api/handlers/stream_forwarder.go b/sdk/api/handlers/stream_forwarder.go new file mode 100644 index 00000000..401baca8 --- /dev/null +++ b/sdk/api/handlers/stream_forwarder.go @@ -0,0 +1,121 @@ +package handlers + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" +) + +type StreamForwardOptions struct { + // KeepAliveInterval overrides the configured streaming keep-alive interval. + // If nil, the configured default is used. If set to <= 0, keep-alives are disabled. + KeepAliveInterval *time.Duration + + // WriteChunk writes a single data chunk to the response body. It should not flush. + WriteChunk func(chunk []byte) + + // WriteTerminalError writes an error payload to the response body when streaming fails + // after headers have already been committed. It should not flush. + WriteTerminalError func(errMsg *interfaces.ErrorMessage) + + // WriteDone optionally writes a terminal marker when the upstream data channel closes + // without an error (e.g. OpenAI's `[DONE]`). It should not flush. + WriteDone func() + + // WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush. + // When nil, a standard SSE comment heartbeat is used. + WriteKeepAlive func() +} + +func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) { + if c == nil { + return + } + if cancel == nil { + return + } + + writeChunk := opts.WriteChunk + if writeChunk == nil { + writeChunk = func([]byte) {} + } + + writeKeepAlive := opts.WriteKeepAlive + if writeKeepAlive == nil { + writeKeepAlive = func() { + _, _ = c.Writer.Write([]byte(": keep-alive\n\n")) + } + } + + keepAliveInterval := StreamingKeepAliveInterval(h.Cfg) + if opts.KeepAliveInterval != nil { + keepAliveInterval = *opts.KeepAliveInterval + } + var keepAlive *time.Ticker + var keepAliveC <-chan time.Time + if keepAliveInterval > 0 { + keepAlive = time.NewTicker(keepAliveInterval) + defer keepAlive.Stop() + keepAliveC = keepAlive.C + } + + var terminalErr *interfaces.ErrorMessage + for { + select { + case <-c.Request.Context().Done(): + cancel(c.Request.Context().Err()) + return + case chunk, ok := <-data: + if !ok { + // Prefer surfacing a terminal error if one is pending. + if terminalErr == nil { + select { + case errMsg, ok := <-errs: + if ok && errMsg != nil { + terminalErr = errMsg + } + default: + } + } + if terminalErr != nil { + if opts.WriteTerminalError != nil { + opts.WriteTerminalError(terminalErr) + } + flusher.Flush() + cancel(terminalErr.Error) + return + } + if opts.WriteDone != nil { + opts.WriteDone() + } + flusher.Flush() + cancel(nil) + return + } + writeChunk(chunk) + flusher.Flush() + case errMsg, ok := <-errs: + if !ok { + continue + } + if errMsg != nil { + terminalErr = errMsg + if opts.WriteTerminalError != nil { + opts.WriteTerminalError(errMsg) + flusher.Flush() + } + } + var execErr error + if errMsg != nil { + execErr = errMsg.Error + } + cancel(execErr) + return + case <-keepAliveC: + writeKeepAlive() + flusher.Flush() + } + } +} diff --git a/sdk/config/config.go b/sdk/config/config.go index 6e4efad5..b471e5e0 100644 --- a/sdk/config/config.go +++ b/sdk/config/config.go @@ -12,6 +12,7 @@ type AccessProvider = internalconfig.AccessProvider type Config = internalconfig.Config +type StreamingConfig = internalconfig.StreamingConfig type TLSConfig = internalconfig.TLSConfig type RemoteManagement = internalconfig.RemoteManagement type AmpCode = internalconfig.AmpCode From 4442574e53becb7850a0778de18d715157bdde05 Mon Sep 17 00:00:00 2001 From: gwizz Date: Tue, 23 Dec 2025 00:37:55 +1100 Subject: [PATCH 6/9] fix: stop streaming loop on context cancel --- sdk/api/handlers/handlers.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 5d33fe0e..50005055 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -437,7 +437,21 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl outer: for { - for chunk := range chunks { + for { + var chunk coreexecutor.StreamChunk + var ok bool + if ctx != nil { + select { + case <-ctx.Done(): + return + case chunk, ok = <-chunks: + } + } else { + chunk, ok = <-chunks + } + if !ok { + return + } if chunk.Err != nil { streamErr := chunk.Err // Safe bootstrap recovery: if the upstream fails before any payload bytes are sent, @@ -474,7 +488,6 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl dataChan <- cloneBytes(chunk.Payload) } } - return } }() return dataChan, errChan From 5bf89dd757bcc591ae1359aff4b50d935188673a Mon Sep 17 00:00:00 2001 From: gwizz Date: Tue, 23 Dec 2025 00:53:18 +1100 Subject: [PATCH 7/9] fix: keep streaming defaults legacy-safe --- config.example.yaml | 5 +++++ sdk/api/handlers/handlers.go | 8 +++++--- sdk/api/handlers/handlers_stream_bootstrap_test.go | 7 ++++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 1e084cb4..aca7f4e1 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -69,6 +69,11 @@ quota-exceeded: # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false +# Streaming behavior (SSE keep-alives + safe bootstrap retries). +# streaming: +# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives. +# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent. + # Gemini API keys # gemini-api-key: # - api-key: "AIzaSy...01" diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 50005055..7857f736 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -45,8 +45,8 @@ type ErrorDetail struct { const idempotencyKeyMetadataKey = "idempotency_key" const ( - defaultStreamingKeepAliveSeconds = 15 - defaultStreamingBootstrapRetries = 2 + defaultStreamingKeepAliveSeconds = 0 + defaultStreamingBootstrapRetries = 0 ) // BuildErrorResponseBody builds an OpenAI-compatible JSON error response body. @@ -100,7 +100,7 @@ func BuildErrorResponseBody(status int, errText string) []byte { } // StreamingKeepAliveInterval returns the SSE keep-alive interval for this server. -// Returning 0 disables keep-alives. +// Returning 0 disables keep-alives (default when unset). func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration { seconds := defaultStreamingKeepAliveSeconds if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil { @@ -125,6 +125,8 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int { } func requestExecutionMetadata(ctx context.Context) map[string]any { + // Idempotency-Key is an optional client-supplied header used to correlate retries. + // It is forwarded as execution metadata; when absent we generate a UUID. key := "" if ctx != nil { if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index cd2fdf4d..7f910447 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -94,7 +94,12 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { registry.GetGlobalRegistry().UnregisterClient(auth2.ID) }) - handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager, nil) + bootstrapRetries := 1 + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ + Streaming: sdkconfig.StreamingConfig{ + BootstrapRetries: &bootstrapRetries, + }, + }, manager, nil) dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") From f413feec618ebe3288d98f2449bfb8f0ab21b6a2 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 24 Dec 2025 04:07:24 +0800 Subject: [PATCH 8/9] refactor(handlers): streamline error and data channel handling in streaming logic Improved consistency across OpenAI, Claude, and Gemini handlers by replacing initial `select` statement with a `for` loop for better readability and error-handling robustness. --- sdk/api/handlers/claude/code_handlers.go | 68 ++++--- sdk/api/handlers/gemini/gemini_handlers.go | 78 ++++---- sdk/api/handlers/openai/openai_handlers.go | 174 ++++++++++-------- .../openai/openai_responses_handlers.go | 70 +++---- 4 files changed, 215 insertions(+), 175 deletions(-) diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index bdf7c9c7..6554cc9a 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -212,39 +212,47 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ } // Peek at the first chunk to determine success or failure before setting headers - select { - case <-c.Request.Context().Done(): - cliCancel(c.Request.Context().Err()) - return - case errMsg := <-errChan: - // Upstream failed immediately. Return proper error status and JSON. - h.WriteErrorResponse(c, errMsg) - if errMsg != nil { - cliCancel(errMsg.Error) - } else { - cliCancel(nil) - } - return - case chunk, ok := <-dataChan: - if !ok { - // Stream closed without data? Send DONE or just headers. + for { + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg, ok := <-errChan: + if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send DONE or just headers. + setSSEHeaders() + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers now. setSSEHeaders() - flusher.Flush() - cliCancel(nil) + + // Write the first chunk + if len(chunk) > 0 { + _, _ = c.Writer.Write(chunk) + flusher.Flush() + } + + // Continue streaming the rest + h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) return } - - // Success! Set headers now. - setSSEHeaders() - - // Write the first chunk - if len(chunk) > 0 { - _, _ = c.Writer.Write(chunk) - flusher.Flush() - } - - // Continue streaming the rest - h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) } } diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index baf68aac..2b17a9f2 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -249,47 +249,55 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName } // Peek at the first chunk - select { - case <-c.Request.Context().Done(): - cliCancel(c.Request.Context().Err()) - return - case errMsg := <-errChan: - // Upstream failed immediately. Return proper error status and JSON. - h.WriteErrorResponse(c, errMsg) - if errMsg != nil { - cliCancel(errMsg.Error) - } else { - cliCancel(nil) - } - return - case chunk, ok := <-dataChan: - if !ok { - // Closed without data + for { + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg, ok := <-errChan: + if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Closed without data + if alt == "" { + setSSEHeaders() + } + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers. if alt == "" { setSSEHeaders() } + + // Write first chunk + if alt == "" { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + } else { + _, _ = c.Writer.Write(chunk) + } flusher.Flush() - cliCancel(nil) + + // Continue + h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan) return } - - // Success! Set headers. - if alt == "" { - setSSEHeaders() - } - - // Write first chunk - if alt == "" { - _, _ = c.Writer.Write([]byte("data: ")) - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n\n")) - } else { - _, _ = c.Writer.Write(chunk) - } - flusher.Flush() - - // Continue - h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan) } } diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go index d5962ea7..65936be7 100644 --- a/sdk/api/handlers/openai/openai_handlers.go +++ b/sdk/api/handlers/openai/openai_handlers.go @@ -467,37 +467,45 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt } // Peek at the first chunk to determine success or failure before setting headers - select { - case <-c.Request.Context().Done(): - cliCancel(c.Request.Context().Err()) - return - case errMsg := <-errChan: - // Upstream failed immediately. Return proper error status and JSON. - h.WriteErrorResponse(c, errMsg) - if errMsg != nil { - cliCancel(errMsg.Error) - } else { - cliCancel(nil) - } - return - case chunk, ok := <-dataChan: - if !ok { - // Stream closed without data? Send DONE or just headers. + for { + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg, ok := <-errChan: + if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send DONE or just headers. + setSSEHeaders() + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Commit to streaming headers. setSSEHeaders() - _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) flusher.Flush() - cliCancel(nil) + + // Continue streaming the rest + h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) return } - - // Success! Commit to streaming headers. - setSSEHeaders() - - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) - flusher.Flush() - - // Continue streaming the rest - h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) } } @@ -562,69 +570,77 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra } // Peek at the first chunk - select { - case <-c.Request.Context().Done(): - cliCancel(c.Request.Context().Err()) - return - case errMsg := <-errChan: - h.WriteErrorResponse(c, errMsg) - if errMsg != nil { - cliCancel(errMsg.Error) - } else { - cliCancel(nil) - } - return - case chunk, ok := <-dataChan: - if !ok { - setSSEHeaders() - _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") - flusher.Flush() - cliCancel(nil) + for { + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) return - } + case errMsg, ok := <-errChan: + if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + setSSEHeaders() + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + cliCancel(nil) + return + } - // Success! Set headers. - setSSEHeaders() + // Success! Set headers. + setSSEHeaders() - // Write the first chunk - converted := convertChatCompletionsStreamChunkToCompletions(chunk) - if converted != nil { - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted)) - flusher.Flush() - } + // Write the first chunk + converted := convertChatCompletionsStreamChunkToCompletions(chunk) + if converted != nil { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted)) + flusher.Flush() + } - done := make(chan struct{}) - var doneOnce sync.Once - stop := func() { doneOnce.Do(func() { close(done) }) } + done := make(chan struct{}) + var doneOnce sync.Once + stop := func() { doneOnce.Do(func() { close(done) }) } - convertedChan := make(chan []byte) - go func() { - defer close(convertedChan) - for { - select { - case <-done: - return - case chunk, ok := <-dataChan: - if !ok { - return - } - converted := convertChatCompletionsStreamChunkToCompletions(chunk) - if converted == nil { - continue - } + convertedChan := make(chan []byte) + go func() { + defer close(convertedChan) + for { select { case <-done: return - case convertedChan <- converted: + case chunk, ok := <-dataChan: + if !ok { + return + } + converted := convertChatCompletionsStreamChunkToCompletions(chunk) + if converted == nil { + continue + } + select { + case <-done: + return + case convertedChan <- converted: + } } } - } - }() + }() - h.handleStreamResult(c, flusher, func(err error) { - stop() - cliCancel(err) - }, convertedChan, errChan) + h.handleStreamResult(c, flusher, func(err error) { + stop() + cliCancel(err) + }, convertedChan, errChan) + return + } } } func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index dd63deeb..b6d7c8f2 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -152,42 +152,50 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ } // Peek at the first chunk - select { - case <-c.Request.Context().Done(): - cliCancel(c.Request.Context().Err()) - return - case errMsg := <-errChan: - // Upstream failed immediately. Return proper error status and JSON. - h.WriteErrorResponse(c, errMsg) - if errMsg != nil { - cliCancel(errMsg.Error) - } else { - cliCancel(nil) - } - return - case chunk, ok := <-dataChan: - if !ok { - // Stream closed without data? Send headers and done. + for { + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg, ok := <-errChan: + if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send headers and done. + setSSEHeaders() + _, _ = c.Writer.Write([]byte("\n")) + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers. setSSEHeaders() + + // Write first chunk logic (matching forwardResponsesStream) + if bytes.HasPrefix(chunk, []byte("event:")) { + _, _ = c.Writer.Write([]byte("\n")) + } + _, _ = c.Writer.Write(chunk) _, _ = c.Writer.Write([]byte("\n")) flusher.Flush() - cliCancel(nil) + + // Continue + h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) return } - - // Success! Set headers. - setSSEHeaders() - - // Write first chunk logic (matching forwardResponsesStream) - if bytes.HasPrefix(chunk, []byte("event:")) { - _, _ = c.Writer.Write([]byte("\n")) - } - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n")) - flusher.Flush() - - // Continue - h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) } } From 66769ec657e0a636625a57456d8c0183d4fcf20f Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 24 Dec 2025 04:24:07 +0800 Subject: [PATCH 9/9] fix(translators): update role from `tool` to `user` in Gemini and Gemini-CLI requests --- .../openai/chat-completions/gemini-cli_openai_request.go | 2 +- .../gemini/openai/chat-completions/gemini_openai_request.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go index 7b0c5571..feb80f65 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go @@ -244,7 +244,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo out, _ = sjson.SetRawBytes(out, "request.contents.-1", node) // Append a single tool content combining name + response per function - toolNode := []byte(`{"role":"tool","parts":[]}`) + toolNode := []byte(`{"role":"user","parts":[]}`) pp := 0 for _, fid := range fIDs { if name, ok := tcID2Name[fid]; ok { diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go index 195b0ae6..7b8c5c68 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go @@ -286,7 +286,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) out, _ = sjson.SetRawBytes(out, "contents.-1", node) // Append a single tool content combining name + response per function - toolNode := []byte(`{"role":"tool","parts":[]}`) + toolNode := []byte(`{"role":"user","parts":[]}`) pp := 0 for _, fid := range fIDs { if name, ok := tcID2Name[fid]; ok {