diff --git a/config.example.yaml b/config.example.yaml index 057acf0d..0938753d 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -78,6 +78,11 @@ routing: # When true, enable authentication for the WebSocket API (/v1/ws). ws-auth: false +# Streaming behavior (SSE keep-alives + safe bootstrap retries). +# streaming: +# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives. +# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent. + # Gemini API keys # gemini-api-key: # - api-key: "AIzaSy...01" diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go index f6f20d5c..7f019520 100644 --- a/internal/config/sdk_config.go +++ b/internal/config/sdk_config.go @@ -22,6 +22,21 @@ type SDKConfig struct { // Access holds request authentication provider configuration. Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"` + + // Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries). + Streaming StreamingConfig `yaml:"streaming" json:"streaming"` +} + +// StreamingConfig holds server streaming behavior configuration. +type StreamingConfig struct { + // KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n"). + // nil means default (0, disabled). <= 0 disables keep-alives. + KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"` + + // BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent, + // to allow auth rotation / transient recovery. + // nil means default (0). 0 disables bootstrap retries. + BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"` } // AccessConfig groups request authentication providers. 
diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index d3b5ca22..2b4ec748 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -42,7 +42,7 @@ const ( antigravityModelsPath = "/v1internal:fetchAvailableModels" antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityClientSecret = "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" - defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" + defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64" antigravityAuthType = "antigravity" refreshSkew = 3000 * time.Second ) diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index ea4ba38e..80660378 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -35,6 +35,7 @@ type Params struct { CandidatesTokenCount int64 // Cached candidate token count from usage metadata ThoughtsTokenCount int64 // Cached thinking token count from usage metadata TotalTokenCount int64 // Cached total token count from usage metadata + CachedTokenCount int64 // Cached content token count (indicates prompt caching) HasSentFinalEvents bool // Indicates if final content/message events have been sent HasToolUse bool // Indicates if tool use was observed in the stream HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output @@ -274,6 +275,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int() params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int() params.TotalTokenCount = usageResult.Get("totalTokenCount").Int() + params.CachedTokenCount = 
usageResult.Get("cachedContentTokenCount").Int() if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 { params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount if params.CandidatesTokenCount < 0 { @@ -322,6 +324,14 @@ func appendFinalEvents(params *Params, output *string, force bool) { *output = *output + "event: message_delta\n" *output = *output + "data: " delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens) + // Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working) + if params.CachedTokenCount > 0 { + var err error + delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount) + if err != nil { + log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err) + } + } *output = *output + delta + "\n\n\n" params.HasSentFinalEvents = true @@ -361,6 +371,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int() totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int() + cachedTokens := root.Get("response.usageMetadata.cachedContentTokenCount").Int() outputTokens := candidateTokens + thoughtTokens if outputTokens == 0 && totalTokens > 0 { outputTokens = totalTokens - promptTokens @@ -374,6 +385,14 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String()) responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens) responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens) + // Add cache_read_input_tokens if cached 
tokens are present (indicates prompt caching is working) + if cachedTokens > 0 { + var err error + responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens) + if err != nil { + log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err) + } + } contentArrayInitialized := false ensureContentArray := func() { diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go index b699a4a0..7282ebc8 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go @@ -13,6 +13,8 @@ import ( "sync/atomic" "time" + log "github.com/sirupsen/logrus" + . "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -93,10 +95,19 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq } promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } + // Include cached token count if present (indicates prompt caching is working) + if cachedTokenCount > 0 { + var err error + template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + if err != nil { + log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err) + } + } } // Process the main content part of the response. 
diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go index 7b0c5571..feb80f65 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go @@ -244,7 +244,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo out, _ = sjson.SetRawBytes(out, "request.contents.-1", node) // Append a single tool content combining name + response per function - toolNode := []byte(`{"role":"tool","parts":[]}`) + toolNode := []byte(`{"role":"user","parts":[]}`) pp := 0 for _, fid := range fIDs { if name, ok := tcID2Name[fid]; ok { diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go index 195b0ae6..7b8c5c68 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go @@ -286,7 +286,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) out, _ = sjson.SetRawBytes(out, "contents.-1", node) // Append a single tool content combining name + response per function - toolNode := []byte(`{"role":"tool","parts":[]}`) + toolNode := []byte(`{"role":"user","parts":[]}`) pp := 0 for _, fid := range fIDs { if name, ok := tcID2Name[fid]; ok { diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go index 3c615b54..d710b1d6 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + log 
"github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -96,10 +97,19 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR } promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } + // Include cached token count if present (indicates prompt caching is working) + if cachedTokenCount > 0 { + var err error + template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + if err != nil { + log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err) + } + } } // Process the main content part of the response. @@ -240,10 +250,19 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina } promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } + // Include cached token count if present (indicates prompt caching is working) + if cachedTokenCount > 0 { + var err error + template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount) + if err != nil { + log.Warnf("gemini openai response: failed to set cached_tokens in non-streaming: %v", err) + } + } } // Process the main content part of the response. 
diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 8a4c4806..6554cc9a 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -14,7 +14,6 @@ import ( "fmt" "io" "net/http" - "time" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -185,14 +184,6 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO // - c: The Gin context for the request. // - rawJSON: The raw JSON request body. func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - // Set up Server-Sent Events (SSE) headers for streaming response - // These headers are essential for maintaining a persistent connection - // and enabling real-time streaming of chat completions - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. 
// This is crucial for streaming as it allows immediate sending of data chunks flusher, ok := c.Writer.(http.Flusher) @@ -213,58 +204,82 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") - h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) - return -} + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } -func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - // OpenAI-style stream forwarding: write each SSE chunk and flush immediately. - // This guarantees clients see incremental output even for small responses. + // Peek at the first chunk to determine success or failure before setting headers for { select { case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) + cliCancel(c.Request.Context().Err()) return - - case chunk, ok := <-data: + case errMsg, ok := <-errChan: if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send DONE or just headers. + setSSEHeaders() flusher.Flush() - cancel(nil) + cliCancel(nil) return } + + // Success! Set headers now. 
+ setSSEHeaders() + + // Write the first chunk if len(chunk) > 0 { _, _ = c.Writer.Write(chunk) flusher.Flush() } - case errMsg, ok := <-errs: - if !ok { - continue - } - if errMsg != nil { - status := http.StatusInternalServerError - if errMsg.StatusCode > 0 { - status = errMsg.StatusCode - } - c.Status(status) - - // An error occurred: emit as a proper SSE error event - errorBytes, _ := json.Marshal(h.toClaudeError(errMsg)) - _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes) - flusher.Flush() - } - - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cancel(execErr) + // Continue streaming the rest + h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) return - case <-time.After(500 * time.Millisecond): } } } +func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + WriteChunk: func(chunk []byte) { + if len(chunk) == 0 { + return + } + _, _ = c.Writer.Write(chunk) + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return + } + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode + } + c.Status(status) + + errorBytes, _ := json.Marshal(h.toClaudeError(errMsg)) + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes) + }, + }) +} + type claudeErrorDetail struct { Type string `json:"type"` Message string `json:"message"` diff --git a/sdk/api/handlers/gemini/gemini-cli_handlers.go b/sdk/api/handlers/gemini/gemini-cli_handlers.go index 5224faf8..ea78657d 100644 --- a/sdk/api/handlers/gemini/gemini-cli_handlers.go +++ b/sdk/api/handlers/gemini/gemini-cli_handlers.go @@ -182,19 +182,18 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ } func (h *GeminiCLIAPIHandler) 
forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - cancel(nil) - return - } + var keepAliveInterval *time.Duration + if alt != "" { + disabled := time.Duration(0) + keepAliveInterval = &disabled + } + + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + KeepAliveInterval: keepAliveInterval, + WriteChunk: func(chunk []byte) { if alt == "" { if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) { - continue + return } if !bytes.HasPrefix(chunk, []byte("data:")) { @@ -206,22 +205,25 @@ func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flus } else { _, _ = c.Writer.Write(chunk) } - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - var execErr error - if errMsg != nil { - execErr = errMsg.Error + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + if alt == "" { + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) + } else { + _, _ = c.Writer.Write(body) + } + }, + }) } diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index 901421b5..2b17a9f2 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -226,13 +226,6 @@ func (h 
*GeminiAPIHandler) GeminiHandler(c *gin.Context) { func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) { alt := h.GetAlt(c) - if alt == "" { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - } - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -247,8 +240,65 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) - h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan) - return + + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk + for { + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg, ok := <-errChan: + if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Closed without data + if alt == "" { + setSSEHeaders() + } + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers. 
+ if alt == "" { + setSSEHeaders() + } + + // Write first chunk + if alt == "" { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + } else { + _, _ = c.Writer.Write(chunk) + } + flusher.Flush() + + // Continue + h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan) + return + } + } } // handleCountTokens handles token counting requests for Gemini models. @@ -297,16 +347,15 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin } func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - cancel(nil) - return - } + var keepAliveInterval *time.Duration + if alt != "" { + disabled := time.Duration(0) + keepAliveInterval = &disabled + } + + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + KeepAliveInterval: keepAliveInterval, + WriteChunk: func(chunk []byte) { if alt == "" { _, _ = c.Writer.Write([]byte("data: ")) _, _ = c.Writer.Write(chunk) @@ -314,22 +363,25 @@ func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flus } else { _, _ = c.Writer.Write(chunk) } - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - var execErr error - if errMsg != nil { - execErr = errMsg.Error + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - cancel(execErr) - return - case 
<-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + if alt == "" { + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) + } else { + _, _ = c.Writer.Write(body) + } + }, + }) } diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 839f060b..20f85585 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -9,8 +9,10 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" @@ -40,6 +42,117 @@ type ErrorDetail struct { Code string `json:"code,omitempty"` } +const idempotencyKeyMetadataKey = "idempotency_key" + +const ( + defaultStreamingKeepAliveSeconds = 0 + defaultStreamingBootstrapRetries = 0 +) + +// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body. +// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads. 
+func BuildErrorResponseBody(status int, errText string) []byte { + if status <= 0 { + status = http.StatusInternalServerError + } + if strings.TrimSpace(errText) == "" { + errText = http.StatusText(status) + } + + trimmed := strings.TrimSpace(errText) + if trimmed != "" && json.Valid([]byte(trimmed)) { + return []byte(trimmed) + } + + errType := "invalid_request_error" + var code string + switch status { + case http.StatusUnauthorized: + errType = "authentication_error" + code = "invalid_api_key" + case http.StatusForbidden: + errType = "permission_error" + code = "insufficient_quota" + case http.StatusTooManyRequests: + errType = "rate_limit_error" + code = "rate_limit_exceeded" + case http.StatusNotFound: + errType = "invalid_request_error" + code = "model_not_found" + default: + if status >= http.StatusInternalServerError { + errType = "server_error" + code = "internal_server_error" + } + } + + payload, err := json.Marshal(ErrorResponse{ + Error: ErrorDetail{ + Message: errText, + Type: errType, + Code: code, + }, + }) + if err != nil { + return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText)) + } + return payload +} + +// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server. +// Returning 0 disables keep-alives (default when unset). +func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration { + seconds := defaultStreamingKeepAliveSeconds + if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil { + seconds = *cfg.Streaming.KeepAliveSeconds + } + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + +// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent. 
+func StreamingBootstrapRetries(cfg *config.SDKConfig) int { + retries := defaultStreamingBootstrapRetries + if cfg != nil && cfg.Streaming.BootstrapRetries != nil { + retries = *cfg.Streaming.BootstrapRetries + } + if retries < 0 { + retries = 0 + } + return retries +} + +func requestExecutionMetadata(ctx context.Context) map[string]any { + // Idempotency-Key is an optional client-supplied header used to correlate retries. + // It is forwarded as execution metadata; when absent we generate a UUID. + key := "" + if ctx != nil { + if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { + key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key")) + } + } + if key == "" { + key = uuid.NewString() + } + return map[string]any{idempotencyKeyMetadataKey: key} +} + +func mergeMetadata(base, overlay map[string]any) map[string]any { + if len(base) == 0 && len(overlay) == 0 { + return nil + } + out := make(map[string]any, len(base)+len(overlay)) + for k, v := range base { + out[k] = v + } + for k, v := range overlay { + out[k] = v + } + return out +} + // BaseAPIHandler contains the handlers for API endpoints. // It holds a pool of clients to interact with the backend service and manages // load balancing, client selection, and configuration. 
@@ -183,6 +296,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType if errMsg != nil { return nil, errMsg } + reqMeta := requestExecutionMetadata(ctx) req := coreexecutor.Request{ Model: normalizedModel, Payload: cloneBytes(rawJSON), @@ -196,9 +310,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } - if cloned := cloneMetadata(metadata); cloned != nil { - opts.Metadata = cloned - } + opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta) resp, err := h.AuthManager.Execute(ctx, providers, req, opts) if err != nil { status := http.StatusInternalServerError @@ -225,6 +337,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle if errMsg != nil { return nil, errMsg } + reqMeta := requestExecutionMetadata(ctx) req := coreexecutor.Request{ Model: normalizedModel, Payload: cloneBytes(rawJSON), @@ -238,9 +351,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } - if cloned := cloneMetadata(metadata); cloned != nil { - opts.Metadata = cloned - } + opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta) resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) if err != nil { status := http.StatusInternalServerError @@ -270,6 +381,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl close(errChan) return nil, errChan } + reqMeta := requestExecutionMetadata(ctx) req := coreexecutor.Request{ Model: normalizedModel, Payload: cloneBytes(rawJSON), @@ -283,9 +395,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } - if cloned := cloneMetadata(metadata); cloned != nil { - opts.Metadata = 
cloned - } + opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta) chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { errChan := make(chan *interfaces.ErrorMessage, 1) @@ -310,31 +420,94 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl go func() { defer close(dataChan) defer close(errChan) - for chunk := range chunks { - if chunk.Err != nil { - status := http.StatusInternalServerError - if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil { - if code := se.StatusCode(); code > 0 { - status = code - } - } - var addon http.Header - if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil { - if hdr := he.Headers(); hdr != nil { - addon = hdr.Clone() - } - } - errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon} - return + sentPayload := false + bootstrapRetries := 0 + maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg) + + bootstrapEligible := func(err error) bool { + status := statusFromError(err) + if status == 0 { + return true } - if len(chunk.Payload) > 0 { - dataChan <- cloneBytes(chunk.Payload) + switch status { + case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired, + http.StatusRequestTimeout, http.StatusTooManyRequests: + return true + default: + return status >= http.StatusInternalServerError + } + } + + outer: + for { + for { + var chunk coreexecutor.StreamChunk + var ok bool + if ctx != nil { + select { + case <-ctx.Done(): + return + case chunk, ok = <-chunks: + } + } else { + chunk, ok = <-chunks + } + if !ok { + return + } + if chunk.Err != nil { + streamErr := chunk.Err + // Safe bootstrap recovery: if the upstream fails before any payload bytes are sent, + // retry a few times (to allow auth rotation / transient recovery). 
+ if !sentPayload { + if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) { + bootstrapRetries++ + retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + if retryErr == nil { + chunks = retryChunks + continue outer + } + streamErr = retryErr + } + } + + status := http.StatusInternalServerError + if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + var addon http.Header + if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } + } + errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon} + return + } + if len(chunk.Payload) > 0 { + sentPayload = true + dataChan <- cloneBytes(chunk.Payload) + } } } }() return dataChan, errChan } +func statusFromError(err error) int { + if err == nil { + return 0 + } + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + return code + } + } + return 0 +} + func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) { // Resolve "auto" model to an actual available model first resolvedModelName := util.ResolveAutoModel(modelName) @@ -418,38 +591,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro } } - // Prefer preserving upstream JSON error bodies when possible. 
- buildJSONBody := func() []byte { - trimmed := strings.TrimSpace(errText) - if trimmed != "" && json.Valid([]byte(trimmed)) { - return []byte(trimmed) - } - errType := "invalid_request_error" - switch status { - case http.StatusUnauthorized: - errType = "authentication_error" - case http.StatusForbidden: - errType = "permission_error" - case http.StatusTooManyRequests: - errType = "rate_limit_error" - default: - if status >= http.StatusInternalServerError { - errType = "server_error" - } - } - payload, err := json.Marshal(ErrorResponse{ - Error: ErrorDetail{ - Message: errText, - Type: errType, - }, - }) - if err != nil { - return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText)) - } - return payload - } - - body := buildJSONBody() + body := BuildErrorResponseBody(status, errText) c.Set("API_RESPONSE", bytes.Clone(body)) if !c.Writer.Written() { diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go new file mode 100644 index 00000000..7f910447 --- /dev/null +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -0,0 +1,125 @@ +package handlers + +import ( + "context" + "net/http" + "sync" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +type failOnceStreamExecutor struct { + mu sync.Mutex + calls int +} + +func (e *failOnceStreamExecutor) Identifier() string { return "codex" } + +func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} +} + +func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, 
coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { + e.mu.Lock() + e.calls++ + call := e.calls + e.mu.Unlock() + + ch := make(chan coreexecutor.StreamChunk, 1) + if call == 1 { + ch <- coreexecutor.StreamChunk{ + Err: &coreauth.Error{ + Code: "unauthorized", + Message: "unauthorized", + Retryable: false, + HTTPStatus: http.StatusUnauthorized, + }, + } + close(ch) + return ch, nil + } + + ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} + close(ch) + return ch, nil +} + +func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} +} + +func (e *failOnceStreamExecutor) Calls() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.calls +} + +func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { + executor := &failOnceStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth1", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + auth2 := &coreauth.Auth{ + ID: "auth2", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test2@example.com"}, + } + if _, err := manager.Register(context.Background(), auth2); err != nil { + t.Fatalf("manager.Register(auth2): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, 
[]*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + registry.GetGlobalRegistry().UnregisterClient(auth2.ID) + }) + + bootstrapRetries := 1 + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ + Streaming: sdkconfig.StreamingConfig{ + BootstrapRetries: &bootstrapRetries, + }, + }, manager, nil) + dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + + for msg := range errChan { + if msg != nil { + t.Fatalf("unexpected error: %+v", msg) + } + } + + if string(got) != "ok" { + t.Fatalf("expected payload ok, got %q", string(got)) + } + if executor.Calls() != 2 { + t.Fatalf("expected 2 stream attempts, got %d", executor.Calls()) + } +} diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go index ae925f91..65936be7 100644 --- a/sdk/api/handlers/openai/openai_handlers.go +++ b/sdk/api/handlers/openai/openai_handlers.go @@ -11,7 +11,7 @@ import ( "encoding/json" "fmt" "net/http" - "time" + "sync" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -443,11 +443,6 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON [] // - c: The Gin context containing the HTTP request and response // - rawJSON: The raw JSON bytes of the OpenAI-compatible request func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. 
flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -463,7 +458,55 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) - h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk to determine success or failure before setting headers + for { + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg, ok := <-errChan: + if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send DONE or just headers. + setSSEHeaders() + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Commit to streaming headers. + setSSEHeaders() + + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) + flusher.Flush() + + // Continue streaming the rest + h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + return + } + } } // handleCompletionsNonStreamingResponse handles non-streaming completions responses. 
@@ -500,11 +543,6 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, // - c: The Gin context containing the HTTP request and response // - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -524,71 +562,109 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk for { select { case <-c.Request.Context().Done(): cliCancel(c.Request.Context().Err()) return - case chunk, isOk := <-dataChan: - if !isOk { + case errMsg, ok := <-errChan: + if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + setSSEHeaders() _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() - cliCancel() + cliCancel(nil) return } + + // Success! Set headers. 
+ setSSEHeaders() + + // Write the first chunk converted := convertChatCompletionsStreamChunkToCompletions(chunk) if converted != nil { _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted)) flusher.Flush() } - case errMsg, isOk := <-errChan: - if !isOk { - continue - } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() - } - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cliCancel(execErr) + + done := make(chan struct{}) + var doneOnce sync.Once + stop := func() { doneOnce.Do(func() { close(done) }) } + + convertedChan := make(chan []byte) + go func() { + defer close(convertedChan) + for { + select { + case <-done: + return + case chunk, ok := <-dataChan: + if !ok { + return + } + converted := convertChatCompletionsStreamChunkToCompletions(chunk) + if converted == nil { + continue + } + select { + case <-done: + return + case convertedChan <- converted: + } + } + } + }() + + h.handleStreamResult(c, flusher, func(err error) { + stop() + cliCancel(err) + }, convertedChan, errChan) return - case <-time.After(500 * time.Millisecond): } } } func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") - flusher.Flush() - cancel(nil) + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + WriteChunk: func(chunk []byte) { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { return } - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = 
errMsg.StatusCode } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body)) + }, + WriteDone: func() { + _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") + }, + }) } diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index ace02313..b6d7c8f2 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -11,7 +11,6 @@ import ( "context" "fmt" "net/http" - "time" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -128,11 +127,6 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r // - c: The Gin context containing the HTTP request and response // - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. 
flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -149,46 +143,88 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") - h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) - return -} -func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk for { select { case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) + cliCancel(c.Request.Context().Err()) return - case chunk, ok := <-data: + case errMsg, ok := <-errChan: if !ok { + // Err channel closed cleanly; wait for data channel. + errChan = nil + continue + } + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send headers and done. + setSSEHeaders() _, _ = c.Writer.Write([]byte("\n")) flusher.Flush() - cancel(nil) + cliCancel(nil) return } + // Success! Set headers. 
+ setSSEHeaders() + + // Write first chunk logic (matching forwardResponsesStream) if bytes.HasPrefix(chunk, []byte("event:")) { _, _ = c.Writer.Write([]byte("\n")) } _, _ = c.Writer.Write(chunk) _, _ = c.Writer.Write([]byte("\n")) - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue - } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() - } - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cancel(execErr) + + // Continue + h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) return - case <-time.After(500 * time.Millisecond): } } } + +func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + WriteChunk: func(chunk []byte) { + if bytes.HasPrefix(chunk, []byte("event:")) { + _, _ = c.Writer.Write([]byte("\n")) + } + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n")) + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return + } + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode + } + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() + } + body := handlers.BuildErrorResponseBody(status, errText) + _, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body)) + }, + WriteDone: func() { + _, _ = c.Writer.Write([]byte("\n")) + }, + }) +} diff --git a/sdk/api/handlers/stream_forwarder.go b/sdk/api/handlers/stream_forwarder.go new file mode 100644 index 00000000..401baca8 --- /dev/null +++ b/sdk/api/handlers/stream_forwarder.go @@ -0,0 +1,121 @@ +package handlers + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + 
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" +) + +type StreamForwardOptions struct { + // KeepAliveInterval overrides the configured streaming keep-alive interval. + // If nil, the configured default is used. If set to <= 0, keep-alives are disabled. + KeepAliveInterval *time.Duration + + // WriteChunk writes a single data chunk to the response body. It should not flush. + WriteChunk func(chunk []byte) + + // WriteTerminalError writes an error payload to the response body when streaming fails + // after headers have already been committed. It should not flush. + WriteTerminalError func(errMsg *interfaces.ErrorMessage) + + // WriteDone optionally writes a terminal marker when the upstream data channel closes + // without an error (e.g. OpenAI's `[DONE]`). It should not flush. + WriteDone func() + + // WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush. + // When nil, a standard SSE comment heartbeat is used. + WriteKeepAlive func() +} + +func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) { + if c == nil { + return + } + if cancel == nil { + return + } + + writeChunk := opts.WriteChunk + if writeChunk == nil { + writeChunk = func([]byte) {} + } + + writeKeepAlive := opts.WriteKeepAlive + if writeKeepAlive == nil { + writeKeepAlive = func() { + _, _ = c.Writer.Write([]byte(": keep-alive\n\n")) + } + } + + keepAliveInterval := StreamingKeepAliveInterval(h.Cfg) + if opts.KeepAliveInterval != nil { + keepAliveInterval = *opts.KeepAliveInterval + } + var keepAlive *time.Ticker + var keepAliveC <-chan time.Time + if keepAliveInterval > 0 { + keepAlive = time.NewTicker(keepAliveInterval) + defer keepAlive.Stop() + keepAliveC = keepAlive.C + } + + var terminalErr *interfaces.ErrorMessage + for { + select { + case <-c.Request.Context().Done(): + cancel(c.Request.Context().Err()) + return + case chunk, ok 
:= <-data: + if !ok { + // Prefer surfacing a terminal error if one is pending. + if terminalErr == nil { + select { + case errMsg, ok := <-errs: + if ok && errMsg != nil { + terminalErr = errMsg + } + default: + } + } + if terminalErr != nil { + if opts.WriteTerminalError != nil { + opts.WriteTerminalError(terminalErr) + } + flusher.Flush() + cancel(terminalErr.Error) + return + } + if opts.WriteDone != nil { + opts.WriteDone() + } + flusher.Flush() + cancel(nil) + return + } + writeChunk(chunk) + flusher.Flush() + case errMsg, ok := <-errs: + if !ok { + continue + } + if errMsg != nil { + terminalErr = errMsg + if opts.WriteTerminalError != nil { + opts.WriteTerminalError(errMsg) + flusher.Flush() + } + } + var execErr error + if errMsg != nil { + execErr = errMsg.Error + } + cancel(execErr) + return + case <-keepAliveC: + writeKeepAlive() + flusher.Flush() + } + } +} diff --git a/sdk/config/config.go b/sdk/config/config.go index 6e4efad5..b471e5e0 100644 --- a/sdk/config/config.go +++ b/sdk/config/config.go @@ -12,6 +12,7 @@ type AccessProvider = internalconfig.AccessProvider type Config = internalconfig.Config +type StreamingConfig = internalconfig.StreamingConfig type TLSConfig = internalconfig.TLSConfig type RemoteManagement = internalconfig.RemoteManagement type AmpCode = internalconfig.AmpCode