diff --git a/internal/cache/signature_cache.go b/internal/cache/signature_cache.go index 477b1245..ea98f8a0 100644 --- a/internal/cache/signature_cache.go +++ b/internal/cache/signature_cache.go @@ -96,17 +96,17 @@ func purgeExpiredSessions() { // CacheSignature stores a thinking signature for a given session and text. // Used for Claude models that require signed thinking blocks in multi-turn conversations. -func CacheSignature(modelName, sessionID, text, signature string) { - if sessionID == "" || text == "" || signature == "" { +func CacheSignature(modelName, text, signature string) { + if text == "" || signature == "" { return } if len(signature) < MinValidSignatureLen { return } - sc := getOrCreateSession(fmt.Sprintf("%s#%s", GetModelGroup(modelName), sessionID)) + text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text) textHash := hashText(text) - + sc := getOrCreateSession(textHash) sc.mu.Lock() defer sc.mu.Unlock() @@ -118,13 +118,21 @@ func CacheSignature(modelName, sessionID, text, signature string) { // GetCachedSignature retrieves a cached signature for a given session and text. // Returns empty string if not found or expired. 
-func GetCachedSignature(modelName, sessionID, text string) string { - if sessionID == "" || text == "" { +func GetCachedSignature(modelName, text string) string { + family := GetModelGroup(modelName) + + if text == "" { + if family == "gemini" { + return "skip_thought_signature_validator" + } return "" } - - val, ok := signatureCache.Load(fmt.Sprintf("%s#%s", GetModelGroup(modelName), sessionID)) + text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text) + val, ok := signatureCache.Load(hashText(text)) if !ok { + if family == "gemini" { + return "skip_thought_signature_validator" + } return "" } sc := val.(*sessionCache) @@ -137,11 +145,17 @@ func GetCachedSignature(modelName, sessionID, text string) string { entry, exists := sc.entries[textHash] if !exists { sc.mu.Unlock() + if family == "gemini" { + return "skip_thought_signature_validator" + } return "" } if now.Sub(entry.Timestamp) > SignatureCacheTTL { delete(sc.entries, textHash) sc.mu.Unlock() + if family == "gemini" { + return "skip_thought_signature_validator" + } return "" } @@ -156,7 +170,13 @@ func GetCachedSignature(modelName, sessionID, text string) string { // ClearSignatureCache clears signature cache for a specific session or all sessions. 
func ClearSignatureCache(sessionID string) { if sessionID != "" { - signatureCache.Delete(sessionID) + signatureCache.Range(func(key, _ any) bool { + kStr, ok := key.(string) + if ok && strings.HasSuffix(kStr, "#"+sessionID) { + signatureCache.Delete(key) + } + return true + }) } else { signatureCache.Range(func(key, _ any) bool { signatureCache.Delete(key) @@ -166,8 +186,8 @@ func ClearSignatureCache(sessionID string) { } // HasValidSignature checks if a signature is valid (non-empty and long enough) -func HasValidSignature(signature string) bool { - return signature != "" && len(signature) >= MinValidSignatureLen +func HasValidSignature(modelName, signature string) bool { + return (signature != "" && len(signature) >= MinValidSignatureLen) || (signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini") } func GetModelGroup(modelName string) string { diff --git a/internal/cache/signature_cache_test.go b/internal/cache/signature_cache_test.go index e4bddbe4..9388c2e0 100644 --- a/internal/cache/signature_cache_test.go +++ b/internal/cache/signature_cache_test.go @@ -8,15 +8,14 @@ import ( func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) { ClearSignatureCache("") - sessionID := "test-session-1" text := "This is some thinking text content" signature := "abc123validSignature1234567890123456789012345678901234567890" // Store signature - CacheSignature(sessionID, text, signature) + CacheSignature("test-model", text, signature) // Retrieve signature - retrieved := GetCachedSignature(sessionID, text) + retrieved := GetCachedSignature("test-model", text) if retrieved != signature { t.Errorf("Expected signature '%s', got '%s'", signature, retrieved) } @@ -29,13 +28,13 @@ func TestCacheSignature_DifferentSessions(t *testing.T) { sig1 := "signature1_1234567890123456789012345678901234567890123456" sig2 := "signature2_1234567890123456789012345678901234567890123456" - CacheSignature("session-a", text, sig1) - CacheSignature("session-b", 
text, sig2) + CacheSignature("test-model", text, sig1) + CacheSignature("test-model", text, sig2) - if GetCachedSignature("session-a", text) != sig1 { + if GetCachedSignature("test-model", text) != sig1 { t.Error("Session-a signature mismatch") } - if GetCachedSignature("session-b", text) != sig2 { + if GetCachedSignature("test-model", text) != sig2 { t.Error("Session-b signature mismatch") } } @@ -44,13 +43,13 @@ func TestCacheSignature_NotFound(t *testing.T) { ClearSignatureCache("") // Non-existent session - if got := GetCachedSignature("nonexistent", "some text"); got != "" { + if got := GetCachedSignature("test-model", "some text"); got != "" { t.Errorf("Expected empty string for nonexistent session, got '%s'", got) } // Existing session but different text - CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890") - if got := GetCachedSignature("session-x", "text-b"); got != "" { + CacheSignature("test-model", "text-a", "sigA12345678901234567890123456789012345678901234567890") + if got := GetCachedSignature("test-model", "text-b"); got != "" { t.Errorf("Expected empty string for different text, got '%s'", got) } } @@ -59,12 +58,12 @@ func TestCacheSignature_EmptyInputs(t *testing.T) { ClearSignatureCache("") // All empty/invalid inputs should be no-ops - CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890") - CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890") - CacheSignature("session", "text", "") - CacheSignature("session", "text", "short") // Too short + CacheSignature("test-model", "text", "sig12345678901234567890123456789012345678901234567890") + CacheSignature("test-model", "", "sig12345678901234567890123456789012345678901234567890") + CacheSignature("test-model", "text", "") + CacheSignature("test-model", "text", "short") // Too short - if got := GetCachedSignature("session", "text"); got != "" { + if got := GetCachedSignature("test-model", "text"); 
got != "" { t.Errorf("Expected empty after invalid cache attempts, got '%s'", got) } } @@ -72,13 +71,12 @@ func TestCacheSignature_EmptyInputs(t *testing.T) { func TestCacheSignature_ShortSignatureRejected(t *testing.T) { ClearSignatureCache("") - sessionID := "test-short-sig" text := "Some text" shortSig := "abc123" // Less than 50 chars - CacheSignature(sessionID, text, shortSig) + CacheSignature("test-model", text, shortSig) - if got := GetCachedSignature(sessionID, text); got != "" { + if got := GetCachedSignature("test-model", text); got != "" { t.Errorf("Short signature should be rejected, got '%s'", got) } } @@ -87,15 +85,15 @@ func TestClearSignatureCache_SpecificSession(t *testing.T) { ClearSignatureCache("") sig := "validSig1234567890123456789012345678901234567890123456" - CacheSignature("session-1", "text", sig) - CacheSignature("session-2", "text", sig) + CacheSignature("test-model", "text", sig) + CacheSignature("test-model", "text", sig) ClearSignatureCache("session-1") - if got := GetCachedSignature("session-1", "text"); got != "" { + if got := GetCachedSignature("test-model", "text"); got != "" { t.Error("session-1 should be cleared") } - if got := GetCachedSignature("session-2", "text"); got != sig { + if got := GetCachedSignature("test-model", "text"); got != sig { t.Error("session-2 should still exist") } } @@ -104,15 +102,15 @@ func TestClearSignatureCache_AllSessions(t *testing.T) { ClearSignatureCache("") sig := "validSig1234567890123456789012345678901234567890123456" - CacheSignature("session-1", "text", sig) - CacheSignature("session-2", "text", sig) + CacheSignature("test-model", "text", sig) + CacheSignature("test-model", "text", sig) ClearSignatureCache("") - if got := GetCachedSignature("session-1", "text"); got != "" { + if got := GetCachedSignature("test-model", "text"); got != "" { t.Error("session-1 should be cleared") } - if got := GetCachedSignature("session-2", "text"); got != "" { + if got := GetCachedSignature("test-model", 
"text"); got != "" { t.Error("session-2 should be cleared") } } @@ -132,7 +130,7 @@ func TestHasValidSignature(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := HasValidSignature(tt.signature) + result := HasValidSignature("claude-sonnet-4-5-thinking", tt.signature) if result != tt.expected { t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected) } @@ -143,21 +141,19 @@ func TestHasValidSignature(t *testing.T) { func TestCacheSignature_TextHashCollisionResistance(t *testing.T) { ClearSignatureCache("") - sessionID := "hash-test-session" - // Different texts should produce different hashes text1 := "First thinking text" text2 := "Second thinking text" sig1 := "signature1_1234567890123456789012345678901234567890123456" sig2 := "signature2_1234567890123456789012345678901234567890123456" - CacheSignature(sessionID, text1, sig1) - CacheSignature(sessionID, text2, sig2) + CacheSignature("test-model", text1, sig1) + CacheSignature("test-model", text2, sig2) - if GetCachedSignature(sessionID, text1) != sig1 { + if GetCachedSignature("test-model", text1) != sig1 { t.Error("text1 signature mismatch") } - if GetCachedSignature(sessionID, text2) != sig2 { + if GetCachedSignature("test-model", text2) != sig2 { t.Error("text2 signature mismatch") } } @@ -165,13 +161,12 @@ func TestCacheSignature_TextHashCollisionResistance(t *testing.T) { func TestCacheSignature_UnicodeText(t *testing.T) { ClearSignatureCache("") - sessionID := "unicode-session" text := "한글 텍스트와 이모지 🎉 그리고 特殊文字" sig := "unicodeSig123456789012345678901234567890123456789012345" - CacheSignature(sessionID, text, sig) + CacheSignature("test-model", text, sig) - if got := GetCachedSignature(sessionID, text); got != sig { + if got := GetCachedSignature("test-model", text); got != sig { t.Errorf("Unicode text signature retrieval failed, got '%s'", got) } } @@ -179,15 +174,14 @@ func TestCacheSignature_UnicodeText(t *testing.T) { func 
TestCacheSignature_Overwrite(t *testing.T) { ClearSignatureCache("") - sessionID := "overwrite-session" text := "Same text" sig1 := "firstSignature12345678901234567890123456789012345678901" sig2 := "secondSignature1234567890123456789012345678901234567890" - CacheSignature(sessionID, text, sig1) - CacheSignature(sessionID, text, sig2) // Overwrite + CacheSignature("test-model", text, sig1) + CacheSignature("test-model", text, sig2) // Overwrite - if got := GetCachedSignature(sessionID, text); got != sig2 { + if got := GetCachedSignature("test-model", text); got != sig2 { t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got) } } @@ -199,14 +193,13 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) { // This test verifies the expiration check exists // In a real scenario, we'd mock time.Now() - sessionID := "expiration-test" text := "text" sig := "validSig1234567890123456789012345678901234567890123456" - CacheSignature(sessionID, text, sig) + CacheSignature("test-model", text, sig) // Fresh entry should be retrievable - if got := GetCachedSignature(sessionID, text); got != sig { + if got := GetCachedSignature("test-model", text); got != sig { t.Errorf("Fresh entry should be retrievable, got '%s'", got) } diff --git a/internal/logging/gin_logger.go b/internal/logging/gin_logger.go index 2dfbcfc2..b94d7afe 100644 --- a/internal/logging/gin_logger.go +++ b/internal/logging/gin_logger.go @@ -4,6 +4,7 @@ package logging import ( + "errors" "fmt" "net/http" "runtime/debug" @@ -112,6 +113,11 @@ func isAIAPIPath(path string) bool { // - gin.HandlerFunc: A middleware handler for panic recovery func GinLogrusRecovery() gin.HandlerFunc { return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { + if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) { + // Let net/http handle ErrAbortHandler so the connection is aborted without noisy stack logs. 
+ panic(http.ErrAbortHandler) + } + log.WithFields(log.Fields{ "panic": recovered, "stack": string(debug.Stack()), diff --git a/internal/logging/gin_logger_test.go b/internal/logging/gin_logger_test.go new file mode 100644 index 00000000..7de18338 --- /dev/null +++ b/internal/logging/gin_logger_test.go @@ -0,0 +1,60 @@ +package logging + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(GinLogrusRecovery()) + engine.GET("/abort", func(c *gin.Context) { + panic(http.ErrAbortHandler) + }) + + req := httptest.NewRequest(http.MethodGet, "/abort", nil) + recorder := httptest.NewRecorder() + + defer func() { + recovered := recover() + if recovered == nil { + t.Fatalf("expected panic, got nil") + } + err, ok := recovered.(error) + if !ok { + t.Fatalf("expected error panic, got %T", recovered) + } + if !errors.Is(err, http.ErrAbortHandler) { + t.Fatalf("expected ErrAbortHandler, got %v", err) + } + if err != http.ErrAbortHandler { + t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err) + } + }() + + engine.ServeHTTP(recorder, req) +} + +func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) { + gin.SetMode(gin.TestMode) + + engine := gin.New() + engine.Use(GinLogrusRecovery()) + engine.GET("/panic", func(c *gin.Context) { + panic("boom") + }) + + req := httptest.NewRequest(http.MethodGet, "/panic", nil) + recorder := httptest.NewRecorder() + + engine.ServeHTTP(recorder, req) + if recorder.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", recorder.Code) + } +} diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 55cc1626..897004fb 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -1202,7 +1202,7 @@ func (e 
*AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau payload = geminiToAntigravity(modelName, payload, projectID) payload, _ = sjson.SetBytes(payload, "model", modelName) - if strings.Contains(modelName, "claude") { + if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") { strJSON := string(payload) paths := make([]string, 0) util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths) @@ -1405,9 +1405,9 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload)) template, _ = sjson.Delete(template, "request.safetySettings") - template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") + // template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") - if strings.Contains(modelName, "claude") { + if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") { gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool { tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool { if funcDecl.Get("parametersJsonSchema").Exists() { @@ -1419,7 +1419,9 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b }) return true }) - } else { + } + + if !strings.Contains(modelName, "claude") { template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens") } diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index 5b6ffe22..e87a7d6b 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -7,8 +7,6 @@ package claude import ( "bytes" - "crypto/sha256" - "encoding/hex" "strings" 
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache" @@ -19,37 +17,6 @@ import ( "github.com/tidwall/sjson" ) -// deriveSessionID generates a stable session ID from the request. -// Uses the hash of the first user message to identify the conversation. -func deriveSessionID(rawJSON []byte) string { - userIDResult := gjson.GetBytes(rawJSON, "metadata.user_id") - if userIDResult.Exists() { - userID := userIDResult.String() - idx := strings.Index(userID, "session_") - if idx != -1 { - return userID[idx+8:] - } - } - messages := gjson.GetBytes(rawJSON, "messages") - if !messages.IsArray() { - return "" - } - for _, msg := range messages.Array() { - if msg.Get("role").String() == "user" { - content := msg.Get("content").String() - if content == "" { - // Try to get text from content array - content = msg.Get("content.0.text").String() - } - if content != "" { - h := sha256.Sum256([]byte(content)) - return hex.EncodeToString(h[:16]) - } - } - } - return "" -} - // ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format. // It extracts the model name, system instruction, message contents, and tool declarations // from the raw JSON request and returns them in the format expected by the Gemini CLI API. 
@@ -72,9 +39,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ enableThoughtTranslate := true rawJSON := bytes.Clone(inputRawJSON) - // Derive session ID for signature caching - sessionID := deriveSessionID(rawJSON) - // system instruction systemInstructionJSON := "" hasSystemInstruction := false @@ -137,8 +101,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ // Always try cached signature first (more reliable than client-provided) // Client may send stale or invalid signatures from different sessions signature := "" - if sessionID != "" && thinkingText != "" { - if cachedSig := cache.GetCachedSignature(modelName, sessionID, thinkingText); cachedSig != "" { + if thinkingText != "" { + if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" { signature = cachedSig // log.Debugf("Using cached signature for thinking block") } @@ -156,19 +120,19 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ } } } - if cache.HasValidSignature(clientSignature) { + if cache.HasValidSignature(modelName, clientSignature) { signature = clientSignature } // log.Debugf("Using client-provided signature for thinking block") } // Store for subsequent tool_use in the same message - if cache.HasValidSignature(signature) { + if cache.HasValidSignature(modelName, signature) { currentMessageThinkingSignature = signature } // Skip trailing unsigned thinking blocks on last assistant message - isUnsigned := !cache.HasValidSignature(signature) + isUnsigned := !cache.HasValidSignature(modelName, signature) // If unsigned, skip entirely (don't convert to text) // Claude requires assistant messages to start with thinking blocks when thinking is enabled @@ -223,7 +187,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ // This is the approach used in opencode-google-antigravity-auth for Gemini // and also works for Claude through Antigravity 
API const skipSentinel = "skip_thought_signature_validator" - if cache.HasValidSignature(currentMessageThinkingSignature) { + if cache.HasValidSignature(modelName, currentMessageThinkingSignature) { partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature) } else { // No valid signature - use skip sentinel to bypass validation diff --git a/internal/translator/antigravity/claude/antigravity_claude_request_test.go b/internal/translator/antigravity/claude/antigravity_claude_request_test.go index 6e1fed1f..6eb58795 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request_test.go @@ -98,10 +98,7 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) { ] }`) - // Derive session ID and cache the signature - sessionID := deriveSessionID(inputJSON) - cache.CacheSignature(sessionID, thinkingText, validSignature) - defer cache.ClearSignatureCache(sessionID) + cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature) output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) outputStr := string(output) @@ -266,10 +263,7 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) { ] }`) - // Derive session ID and cache the signature - sessionID := deriveSessionID(inputJSON) - cache.CacheSignature(sessionID, thinkingText, validSignature) - defer cache.ClearSignatureCache(sessionID) + cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature) output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) outputStr := string(output) @@ -306,10 +300,7 @@ func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) { ] }`) - // Derive session ID and cache the signature - sessionID := deriveSessionID(inputJSON) - cache.CacheSignature(sessionID, thinkingText, validSignature) - defer 
cache.ClearSignatureCache(sessionID) + cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature) output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) outputStr := string(output) @@ -517,10 +508,7 @@ func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testin ] }`) - // Derive session ID and cache the signature - sessionID := deriveSessionID(inputJSON) - cache.CacheSignature(sessionID, thinkingText, validSignature) - defer cache.ClearSignatureCache(sessionID) + cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature) output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) outputStr := string(output) diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index e360f850..57eca78c 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -41,7 +41,6 @@ type Params struct { HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output // Signature caching support - SessionID string // Session ID derived from request for signature caching CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching } @@ -70,7 +69,6 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq HasFirstResponse: false, ResponseType: 0, ResponseIndex: 0, - SessionID: deriveSessionID(originalRequestRawJSON), } } modelName := gjson.GetBytes(requestRawJSON, "model").String() @@ -139,8 +137,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" { // log.Debug("Branch: signature_delta") - if params.SessionID != "" && 
params.CurrentThinkingText.Len() > 0 { - cache.CacheSignature(modelName, params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String()) + if params.CurrentThinkingText.Len() > 0 { + cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String()) // log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len()) params.CurrentThinkingText.Reset() } diff --git a/internal/translator/antigravity/claude/antigravity_claude_response_test.go b/internal/translator/antigravity/claude/antigravity_claude_response_test.go index afc3d937..9dd1eedd 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response_test.go @@ -97,6 +97,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) { cache.ClearSignatureCache("") requestJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", "messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}] }`) @@ -143,7 +144,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) { ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, &param) // Verify signature was cached - cachedSig := cache.GetCachedSignature(sessionID, thinkingText) + cachedSig := cache.GetCachedSignature("claude-sonnet-4-5-thinking", thinkingText) if cachedSig != validSignature { t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig) } @@ -158,6 +159,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T) cache.ClearSignatureCache("") requestJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", "messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}] }`) @@ -221,13 +223,12 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t 
*testing.T) // Process first thinking block ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, &param) params := param.(*Params) - sessionID := params.SessionID firstThinkingText := params.CurrentThinkingText.String() ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, &param) // Verify first signature cached - if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 { + if cache.GetCachedSignature("claude-sonnet-4-5-thinking", firstThinkingText) != validSig1 { t.Error("First thinking block signature should be cached") } @@ -241,76 +242,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T) ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, &param) // Verify second signature cached - if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 { + if cache.GetCachedSignature("claude-sonnet-4-5-thinking", secondThinkingText) != validSig2 { t.Error("Second thinking block signature should be cached") } } - -func TestDeriveSessionIDFromRequest(t *testing.T) { - tests := []struct { - name string - input []byte - wantEmpty bool - }{ - { - name: "valid user message", - input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`), - wantEmpty: false, - }, - { - name: "user message with content array", - input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`), - wantEmpty: false, - }, - { - name: "no user message", - input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`), - wantEmpty: true, - }, - { - name: "empty messages", - input: []byte(`{"messages": []}`), - wantEmpty: true, - }, - { - name: "no messages field", - input: []byte(`{}`), - wantEmpty: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := deriveSessionID(tt.input) - if tt.wantEmpty && 
result != "" { - t.Errorf("Expected empty session ID, got '%s'", result) - } - if !tt.wantEmpty && result == "" { - t.Error("Expected non-empty session ID") - } - }) - } -} - -func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) { - input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`) - - id1 := deriveSessionID(input) - id2 := deriveSessionID(input) - - if id1 != id2 { - t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2) - } -} - -func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) { - input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`) - input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`) - - id1 := deriveSessionID(input1) - id2 := deriveSessionID(input2) - - if id1 == id2 { - t.Error("Different messages should produce different session IDs") - } -} diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request.go b/internal/translator/antigravity/gemini/antigravity_gemini_request.go index a83c177d..2ad9bd80 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_request.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request.go @@ -8,6 +8,7 @@ package gemini import ( "bytes" "fmt" + "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" @@ -32,12 +33,12 @@ import ( // // Returns: // - []byte: The transformed request data in Gemini API format -func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []byte { +func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte { rawJSON := bytes.Clone(inputRawJSON) template := "" template = `{"project":"","request":{},"model":""}` template, _ = sjson.SetRaw(template, "request", string(rawJSON)) - template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) + template, _ = 
sjson.Set(template, "model", modelName) template, _ = sjson.Delete(template, "request.model") template, errFixCLIToolResponse := fixCLIToolResponse(template) @@ -97,37 +98,40 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) [] } } - // Gemini-specific handling: add skip_thought_signature_validator to functionCall parts - // and remove thinking blocks entirely (Gemini doesn't need to preserve them) - const skipSentinel = "skip_thought_signature_validator" + // Gemini-specific handling for non-Claude models: + // - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation. + // - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them). + if !strings.Contains(modelName, "claude") { + const skipSentinel = "skip_thought_signature_validator" - gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool { - if content.Get("role").String() == "model" { - // First pass: collect indices of thinking parts to remove - var thinkingIndicesToRemove []int64 - content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool { - // Mark thinking blocks for removal - if part.Get("thought").Bool() { - thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int()) - } - // Add skip sentinel to functionCall parts - if part.Get("functionCall").Exists() { - existingSig := part.Get("thoughtSignature").String() - if existingSig == "" || len(existingSig) < 50 { - rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel) + gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool { + if content.Get("role").String() == "model" { + // First pass: collect indices of thinking parts to mark with skip sentinel + var thinkingIndicesToSkipSignature []int64 + content.Get("parts").ForEach(func(partIdx, part 
gjson.Result) bool { + // Collect indices of thinking blocks to mark with skip sentinel + if part.Get("thought").Bool() { + thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int()) } - } - return true - }) + // Add skip sentinel to functionCall parts + if part.Get("functionCall").Exists() { + existingSig := part.Get("thoughtSignature").String() + if existingSig == "" || len(existingSig) < 50 { + rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel) + } + } + return true + }) - // Remove thinking blocks in reverse order to preserve indices - for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- { - idx := thinkingIndicesToRemove[i] - rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx)) + // Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices + for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- { + idx := thinkingIndicesToSkipSignature[i] + rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel) + } } - } - return true - }) + return true + }) + } return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings") } diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go index 41279977..5277b71b 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go @@ -298,6 +298,15 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte } functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse) out, _ = sjson.SetRaw(out, "contents.-1", functionContent) + + case "reasoning": + 
thoughtContent := `{"role":"model","parts":[]}` + thought := `{"text":"","thoughtSignature":"","thought":true}` + thought, _ = sjson.Set(thought, "text", item.Get("summary.0.text").String()) + thought, _ = sjson.Set(thought, "thoughtSignature", item.Get("encrypted_content").String()) + + thoughtContent, _ = sjson.SetRaw(thoughtContent, "parts.-1", thought) + out, _ = sjson.SetRaw(out, "contents.-1", thoughtContent) } } } else if input.Exists() && input.Type == gjson.String { diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go index 5529d52a..985897fa 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go @@ -20,6 +20,7 @@ type geminiToResponsesState struct { // message aggregation MsgOpened bool + MsgClosed bool MsgIndex int CurrentMsgID string TextBuf strings.Builder @@ -29,6 +30,7 @@ type geminiToResponsesState struct { ReasoningOpened bool ReasoningIndex int ReasoningItemID string + ReasoningEnc string ReasoningBuf strings.Builder ReasoningClosed bool @@ -37,6 +39,7 @@ type geminiToResponsesState struct { FuncArgsBuf map[int]*strings.Builder FuncNames map[int]string FuncCallIDs map[int]string + FuncDone map[int]bool } // responseIDCounter provides a process-wide unique counter for synthesized response identifiers. @@ -45,6 +48,39 @@ var responseIDCounter uint64 // funcCallIDCounter provides a process-wide unique counter for function call identifiers. 
var funcCallIDCounter uint64 +func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte { + if len(originalRequestRawJSON) > 0 && gjson.ValidBytes(originalRequestRawJSON) { + return originalRequestRawJSON + } + if len(requestRawJSON) > 0 && gjson.ValidBytes(requestRawJSON) { + return requestRawJSON + } + return nil +} + +func unwrapRequestRoot(root gjson.Result) gjson.Result { + req := root.Get("request") + if !req.Exists() { + return root + } + if req.Get("model").Exists() || req.Get("input").Exists() || req.Get("instructions").Exists() { + return req + } + return root +} + +func unwrapGeminiResponseRoot(root gjson.Result) gjson.Result { + resp := root.Get("response") + if !resp.Exists() { + return root + } + // Vertex-style Gemini responses wrap the actual payload in a "response" object. + if resp.Get("candidates").Exists() || resp.Get("responseId").Exists() || resp.Get("usageMetadata").Exists() { + return resp + } + return root +} + func emitEvent(event string, payload string) string { return fmt.Sprintf("event: %s\ndata: %s", event, payload) } @@ -56,18 +92,37 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, FuncArgsBuf: make(map[int]*strings.Builder), FuncNames: make(map[int]string), FuncCallIDs: make(map[int]string), + FuncDone: make(map[int]bool), } } st := (*param).(*geminiToResponsesState) + if st.FuncArgsBuf == nil { + st.FuncArgsBuf = make(map[int]*strings.Builder) + } + if st.FuncNames == nil { + st.FuncNames = make(map[int]string) + } + if st.FuncCallIDs == nil { + st.FuncCallIDs = make(map[int]string) + } + if st.FuncDone == nil { + st.FuncDone = make(map[int]bool) + } if bytes.HasPrefix(rawJSON, []byte("data:")) { rawJSON = bytes.TrimSpace(rawJSON[5:]) } + rawJSON = bytes.TrimSpace(rawJSON) + if len(rawJSON) == 0 || bytes.Equal(rawJSON, []byte("[DONE]")) { + return []string{} + } + root := gjson.ParseBytes(rawJSON) if !root.Exists() { return []string{} } + root = unwrapGeminiResponseRoot(root) var 
out []string nextSeq := func() int { st.Seq++; return st.Seq } @@ -98,19 +153,54 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq()) itemDone, _ = sjson.Set(itemDone, "item.id", st.ReasoningItemID) itemDone, _ = sjson.Set(itemDone, "output_index", st.ReasoningIndex) + itemDone, _ = sjson.Set(itemDone, "item.encrypted_content", st.ReasoningEnc) itemDone, _ = sjson.Set(itemDone, "item.summary.0.text", full) out = append(out, emitEvent("response.output_item.done", itemDone)) st.ReasoningClosed = true } + // Helper to finalize the assistant message in correct order. + // It emits response.output_text.done, response.content_part.done, + // and response.output_item.done exactly once. + finalizeMessage := func() { + if !st.MsgOpened || st.MsgClosed { + return + } + fullText := st.ItemTextBuf.String() + done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}` + done, _ = sjson.Set(done, "sequence_number", nextSeq()) + done, _ = sjson.Set(done, "item_id", st.CurrentMsgID) + done, _ = sjson.Set(done, "output_index", st.MsgIndex) + done, _ = sjson.Set(done, "text", fullText) + out = append(out, emitEvent("response.output_text.done", done)) + partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}` + partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq()) + partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID) + partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex) + partDone, _ = sjson.Set(partDone, "part.text", fullText) + out = append(out, emitEvent("response.content_part.done", partDone)) + final := 
`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}` + final, _ = sjson.Set(final, "sequence_number", nextSeq()) + final, _ = sjson.Set(final, "output_index", st.MsgIndex) + final, _ = sjson.Set(final, "item.id", st.CurrentMsgID) + final, _ = sjson.Set(final, "item.content.0.text", fullText) + out = append(out, emitEvent("response.output_item.done", final)) + + st.MsgClosed = true + } + // Initialize per-response fields and emit created/in_progress once if !st.Started { - if v := root.Get("responseId"); v.Exists() { - st.ResponseID = v.String() + st.ResponseID = root.Get("responseId").String() + if st.ResponseID == "" { + st.ResponseID = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1)) + } + if !strings.HasPrefix(st.ResponseID, "resp_") { + st.ResponseID = fmt.Sprintf("resp_%s", st.ResponseID) } if v := root.Get("createTime"); v.Exists() { - if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil { + if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil { st.CreatedAt = t.Unix() } } @@ -143,15 +233,21 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, // Ignore any late thought chunks after reasoning is finalized. 
return true } + if sig := part.Get("thoughtSignature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature { + st.ReasoningEnc = sig.String() + } else if sig = part.Get("thought_signature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature { + st.ReasoningEnc = sig.String() + } if !st.ReasoningOpened { st.ReasoningOpened = true st.ReasoningIndex = st.NextIndex st.NextIndex++ st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, st.ReasoningIndex) - item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}` + item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","encrypted_content":"","summary":[]}}` item, _ = sjson.Set(item, "sequence_number", nextSeq()) item, _ = sjson.Set(item, "output_index", st.ReasoningIndex) item, _ = sjson.Set(item, "item.id", st.ReasoningItemID) + item, _ = sjson.Set(item, "item.encrypted_content", st.ReasoningEnc) out = append(out, emitEvent("response.output_item.added", item)) partAdded := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}` partAdded, _ = sjson.Set(partAdded, "sequence_number", nextSeq()) @@ -191,9 +287,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex) out = append(out, emitEvent("response.content_part.added", partAdded)) st.ItemTextBuf.Reset() - st.ItemTextBuf.WriteString(t.String()) } st.TextBuf.WriteString(t.String()) + st.ItemTextBuf.WriteString(t.String()) msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}` msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ 
= sjson.Set(msg, "item_id", st.CurrentMsgID) @@ -205,8 +301,10 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, // Function call if fc := part.Get("functionCall"); fc.Exists() { - // Before emitting function-call outputs, finalize reasoning if open. + // Before emitting function-call outputs, finalize reasoning and the message (if open). + // Responses streaming requires message done events before the next output_item.added. finalizeReasoning() + finalizeMessage() name := fc.Get("name").String() idx := st.NextIndex st.NextIndex++ @@ -219,6 +317,14 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, } st.FuncNames[idx] = name + argsJSON := "{}" + if args := fc.Get("args"); args.Exists() { + argsJSON = args.Raw + } + if st.FuncArgsBuf[idx].Len() == 0 && argsJSON != "" { + st.FuncArgsBuf[idx].WriteString(argsJSON) + } + // Emit item.added for function call item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}` item, _ = sjson.Set(item, "sequence_number", nextSeq()) @@ -228,10 +334,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, item, _ = sjson.Set(item, "item.name", name) out = append(out, emitEvent("response.output_item.added", item)) - // Emit arguments delta (full args in one chunk) - if args := fc.Get("args"); args.Exists() { - argsJSON := args.Raw - st.FuncArgsBuf[idx].WriteString(argsJSON) + // Emit arguments delta (full args in one chunk). + // When Gemini omits args, emit "{}" to keep Responses streaming event order consistent. 
+ if argsJSON != "" { ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}` ad, _ = sjson.Set(ad, "sequence_number", nextSeq()) ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx])) @@ -240,6 +345,27 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, out = append(out, emitEvent("response.function_call_arguments.delta", ad)) } + // Gemini emits the full function call payload at once, so we can finalize it immediately. + if !st.FuncDone[idx] { + fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}` + fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq()) + fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx])) + fcDone, _ = sjson.Set(fcDone, "output_index", idx) + fcDone, _ = sjson.Set(fcDone, "arguments", argsJSON) + out = append(out, emitEvent("response.function_call_arguments.done", fcDone)) + + itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}` + itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq()) + itemDone, _ = sjson.Set(itemDone, "output_index", idx) + itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx])) + itemDone, _ = sjson.Set(itemDone, "item.arguments", argsJSON) + itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx]) + itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx]) + out = append(out, emitEvent("response.output_item.done", itemDone)) + + st.FuncDone[idx] = true + } + return true } @@ -251,28 +377,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, if fr := root.Get("candidates.0.finishReason"); fr.Exists() && fr.String() != "" { // Finalize reasoning first to keep ordering tight with 
last delta finalizeReasoning() - // Close message output if opened - if st.MsgOpened { - fullText := st.ItemTextBuf.String() - done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}` - done, _ = sjson.Set(done, "sequence_number", nextSeq()) - done, _ = sjson.Set(done, "item_id", st.CurrentMsgID) - done, _ = sjson.Set(done, "output_index", st.MsgIndex) - done, _ = sjson.Set(done, "text", fullText) - out = append(out, emitEvent("response.output_text.done", done)) - partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}` - partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq()) - partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID) - partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex) - partDone, _ = sjson.Set(partDone, "part.text", fullText) - out = append(out, emitEvent("response.content_part.done", partDone)) - final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}` - final, _ = sjson.Set(final, "sequence_number", nextSeq()) - final, _ = sjson.Set(final, "output_index", st.MsgIndex) - final, _ = sjson.Set(final, "item.id", st.CurrentMsgID) - final, _ = sjson.Set(final, "item.content.0.text", fullText) - out = append(out, emitEvent("response.output_item.done", final)) - } + finalizeMessage() // Close function calls if len(st.FuncArgsBuf) > 0 { @@ -289,6 +394,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, } } for _, idx := range idxs { + if st.FuncDone[idx] { + continue + } args := "{}" if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 { args = b.String() @@ -308,6 +416,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, 
modelName string, itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx]) itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx]) out = append(out, emitEvent("response.output_item.done", itemDone)) + + st.FuncDone[idx] = true } } @@ -319,8 +429,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, completed, _ = sjson.Set(completed, "response.id", st.ResponseID) completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt) - if requestRawJSON != nil { - req := gjson.ParseBytes(requestRawJSON) + if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 { + req := unwrapRequestRoot(gjson.ParseBytes(reqJSON)) if v := req.Get("instructions"); v.Exists() { completed, _ = sjson.Set(completed, "response.instructions", v.String()) } @@ -383,41 +493,34 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, } } - // Compose outputs in encountered order: reasoning, message, function_calls + // Compose outputs in output_index order. 
outputsWrapper := `{"arr":[]}` - if st.ReasoningOpened { - item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}` - item, _ = sjson.Set(item, "id", st.ReasoningItemID) - item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String()) - outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) - } - if st.MsgOpened { - item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}` - item, _ = sjson.Set(item, "id", st.CurrentMsgID) - item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String()) - outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) - } - if len(st.FuncArgsBuf) > 0 { - idxs := make([]int, 0, len(st.FuncArgsBuf)) - for idx := range st.FuncArgsBuf { - idxs = append(idxs, idx) + for idx := 0; idx < st.NextIndex; idx++ { + if st.ReasoningOpened && idx == st.ReasoningIndex { + item := `{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]}` + item, _ = sjson.Set(item, "id", st.ReasoningItemID) + item, _ = sjson.Set(item, "encrypted_content", st.ReasoningEnc) + item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String()) + outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) + continue } - for i := 0; i < len(idxs); i++ { - for j := i + 1; j < len(idxs); j++ { - if idxs[j] < idxs[i] { - idxs[i], idxs[j] = idxs[j], idxs[i] - } - } + if st.MsgOpened && idx == st.MsgIndex { + item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}` + item, _ = sjson.Set(item, "id", st.CurrentMsgID) + item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String()) + outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) + continue } - for _, idx := range idxs { - args := "" - if b := st.FuncArgsBuf[idx]; b != nil { + + if callID, ok := st.FuncCallIDs[idx]; ok && 
callID != "" { + args := "{}" + if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 { args = b.String() } item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}` - item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx])) + item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID)) item, _ = sjson.Set(item, "arguments", args) - item, _ = sjson.Set(item, "call_id", st.FuncCallIDs[idx]) + item, _ = sjson.Set(item, "call_id", callID) item, _ = sjson.Set(item, "name", st.FuncNames[idx]) outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item) } @@ -431,8 +534,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, // input tokens = prompt + thoughts input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int() completed, _ = sjson.Set(completed, "response.usage.input_tokens", input) - // cached_tokens not provided by Gemini; default to 0 for structure compatibility - completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0) + // cached token details: align with OpenAI "cached_tokens" semantics. + completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int()) // output tokens if v := um.Get("candidatesTokenCount"); v.Exists() { completed, _ = sjson.Set(completed, "response.usage.output_tokens", v.Int()) @@ -460,6 +563,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, // ConvertGeminiResponseToOpenAIResponsesNonStream aggregates Gemini response JSON into a single OpenAI Responses JSON object. 
func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { root := gjson.ParseBytes(rawJSON) + root = unwrapGeminiResponseRoot(root) // Base response scaffold resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}` @@ -478,15 +582,15 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string // created_at: map from createTime if available createdAt := time.Now().Unix() if v := root.Get("createTime"); v.Exists() { - if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil { + if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil { createdAt = t.Unix() } } resp, _ = sjson.Set(resp, "created_at", createdAt) // Echo request fields when present; fallback model from response modelVersion - if len(requestRawJSON) > 0 { - req := gjson.ParseBytes(requestRawJSON) + if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 { + req := unwrapRequestRoot(gjson.ParseBytes(reqJSON)) if v := req.Get("instructions"); v.Exists() { resp, _ = sjson.Set(resp, "instructions", v.String()) } @@ -636,8 +740,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string // input tokens = prompt + thoughts input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int() resp, _ = sjson.Set(resp, "usage.input_tokens", input) - // cached_tokens not provided by Gemini; default to 0 for structure compatibility - resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", 0) + // cached token details: align with OpenAI "cached_tokens" semantics. 
+ resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int()) // output tokens if v := um.Get("candidatesTokenCount"); v.Exists() { resp, _ = sjson.Set(resp, "usage.output_tokens", v.Int()) diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_response_test.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_response_test.go new file mode 100644 index 00000000..9899c594 --- /dev/null +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_response_test.go @@ -0,0 +1,353 @@ +package responses + +import ( + "context" + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +func parseSSEEvent(t *testing.T, chunk string) (string, gjson.Result) { + t.Helper() + + lines := strings.Split(chunk, "\n") + if len(lines) < 2 { + t.Fatalf("unexpected SSE chunk: %q", chunk) + } + + event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:")) + dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:")) + if !gjson.Valid(dataLine) { + t.Fatalf("invalid SSE data JSON: %q", dataLine) + } + return event, gjson.Parse(dataLine) +} + +func TestConvertGeminiResponseToOpenAIResponses_UnwrapAndAggregateText(t *testing.T) { + // Vertex-style Gemini stream wraps the actual response payload under "response". + // This test ensures we unwrap and that output_text.done contains the full text. 
+ in := []string{ + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"让"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"我先"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"了解"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"mcp__serena__list_dir","args":{"recursive":false,"relative_path":"internal"},"id":"toolu_1"}}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":2},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + } + + originalReq := []byte(`{"instructions":"test instructions","model":"gpt-5","max_output_tokens":123}`) + + var param any + var out []string + for _, line := range in { + out = append(out, 
ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", originalReq, nil, []byte(line), &param)...) + } + + var ( + gotTextDone bool + gotMessageDone bool + gotResponseDone bool + gotFuncDone bool + + textDone string + messageText string + responseID string + instructions string + cachedTokens int64 + + funcName string + funcArgs string + + posTextDone = -1 + posPartDone = -1 + posMessageDone = -1 + posFuncAdded = -1 + ) + + for i, chunk := range out { + ev, data := parseSSEEvent(t, chunk) + switch ev { + case "response.output_text.done": + gotTextDone = true + if posTextDone == -1 { + posTextDone = i + } + textDone = data.Get("text").String() + case "response.content_part.done": + if posPartDone == -1 { + posPartDone = i + } + case "response.output_item.done": + switch data.Get("item.type").String() { + case "message": + gotMessageDone = true + if posMessageDone == -1 { + posMessageDone = i + } + messageText = data.Get("item.content.0.text").String() + case "function_call": + gotFuncDone = true + funcName = data.Get("item.name").String() + funcArgs = data.Get("item.arguments").String() + } + case "response.output_item.added": + if data.Get("item.type").String() == "function_call" && posFuncAdded == -1 { + posFuncAdded = i + } + case "response.completed": + gotResponseDone = true + responseID = data.Get("response.id").String() + instructions = data.Get("response.instructions").String() + cachedTokens = data.Get("response.usage.input_tokens_details.cached_tokens").Int() + } + } + + if !gotTextDone { + t.Fatalf("missing response.output_text.done event") + } + if posTextDone == -1 || posPartDone == -1 || posMessageDone == -1 || posFuncAdded == -1 { + t.Fatalf("missing ordering events: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded) + } + if !(posTextDone < posPartDone && posPartDone < posMessageDone && posMessageDone < posFuncAdded) { + t.Fatalf("unexpected message/function ordering: 
textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded) + } + if !gotMessageDone { + t.Fatalf("missing message response.output_item.done event") + } + if !gotFuncDone { + t.Fatalf("missing function_call response.output_item.done event") + } + if !gotResponseDone { + t.Fatalf("missing response.completed event") + } + + if textDone != "让我先了解" { + t.Fatalf("unexpected output_text.done text: got %q", textDone) + } + if messageText != "让我先了解" { + t.Fatalf("unexpected message done text: got %q", messageText) + } + + if responseID != "resp_req_vrtx_1" { + t.Fatalf("unexpected response id: got %q", responseID) + } + if instructions != "test instructions" { + t.Fatalf("unexpected instructions echo: got %q", instructions) + } + if cachedTokens != 2 { + t.Fatalf("unexpected cached token count: got %d", cachedTokens) + } + + if funcName != "mcp__serena__list_dir" { + t.Fatalf("unexpected function name: got %q", funcName) + } + if !gjson.Valid(funcArgs) { + t.Fatalf("invalid function arguments JSON: %q", funcArgs) + } + if gjson.Get(funcArgs, "recursive").Bool() != false { + t.Fatalf("unexpected recursive arg: %v", gjson.Get(funcArgs, "recursive").Value()) + } + if gjson.Get(funcArgs, "relative_path").String() != "internal" { + t.Fatalf("unexpected relative_path arg: %q", gjson.Get(funcArgs, "relative_path").String()) + } +} + +func TestConvertGeminiResponseToOpenAIResponses_ReasoningEncryptedContent(t *testing.T) { + sig := "RXE0RENrZ0lDeEFDR0FJcVFOZDdjUzlleGFuRktRdFcvSzNyZ2MvWDNCcDQ4RmxSbGxOWUlOVU5kR1l1UHMrMGdkMVp0Vkg3ekdKU0g4YVljc2JjN3lNK0FrdGpTNUdqamI4T3Z0VVNETzdQd3pmcFhUOGl3U3hXUEJvTVFRQ09mWTFyMEtTWGZxUUlJakFqdmFGWk83RW1XRlBKckJVOVpkYzdDKw==" + in := []string{ + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"thoughtSignature":"` + sig + `","text":""}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`, + `data: 
{"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"a"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`, + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hello"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`, + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`, + } + + var param any + var out []string + for _, line := range in { + out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), &param)...) + } + + var ( + addedEnc string + doneEnc string + ) + for _, chunk := range out { + ev, data := parseSSEEvent(t, chunk) + switch ev { + case "response.output_item.added": + if data.Get("item.type").String() == "reasoning" { + addedEnc = data.Get("item.encrypted_content").String() + } + case "response.output_item.done": + if data.Get("item.type").String() == "reasoning" { + doneEnc = data.Get("item.encrypted_content").String() + } + } + } + + if addedEnc != sig { + t.Fatalf("unexpected encrypted_content in response.output_item.added: got %q", addedEnc) + } + if doneEnc != sig { + t.Fatalf("unexpected encrypted_content in response.output_item.done: got %q", doneEnc) + } +} + +func TestConvertGeminiResponseToOpenAIResponses_FunctionCallEventOrder(t *testing.T) { + in := []string{ + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool1"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + `data: 
{"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool2","args":{"a":1}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + `data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`, + } + + var param any + var out []string + for _, line := range in { + out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), &param)...) + } + + posAdded := []int{-1, -1, -1} + posArgsDelta := []int{-1, -1, -1} + posArgsDone := []int{-1, -1, -1} + posItemDone := []int{-1, -1, -1} + posCompleted := -1 + deltaByIndex := map[int]string{} + + for i, chunk := range out { + ev, data := parseSSEEvent(t, chunk) + switch ev { + case "response.output_item.added": + if data.Get("item.type").String() != "function_call" { + continue + } + idx := int(data.Get("output_index").Int()) + if idx >= 0 && idx < len(posAdded) { + posAdded[idx] = i + } + case "response.function_call_arguments.delta": + idx := int(data.Get("output_index").Int()) + if idx >= 0 && idx < len(posArgsDelta) { + posArgsDelta[idx] = i + deltaByIndex[idx] = data.Get("delta").String() + } + case "response.function_call_arguments.done": + idx := int(data.Get("output_index").Int()) + if idx >= 0 && idx < len(posArgsDone) { + posArgsDone[idx] = i + } + case "response.output_item.done": + if data.Get("item.type").String() != "function_call" { + continue + } + idx := int(data.Get("output_index").Int()) + if idx >= 0 && idx < len(posItemDone) { + posItemDone[idx] = i + } + case "response.completed": + posCompleted = i + + output := data.Get("response.output") + if !output.Exists() || !output.IsArray() { + t.Fatalf("missing response.output in response.completed") + } + if 
len(output.Array()) != 3 { + t.Fatalf("unexpected response.output length: got %d", len(output.Array())) + } + if data.Get("response.output.0.name").String() != "tool0" || data.Get("response.output.0.arguments").String() != "{}" { + t.Fatalf("unexpected output[0]: %s", data.Get("response.output.0").Raw) + } + if data.Get("response.output.1.name").String() != "tool1" || data.Get("response.output.1.arguments").String() != "{}" { + t.Fatalf("unexpected output[1]: %s", data.Get("response.output.1").Raw) + } + if data.Get("response.output.2.name").String() != "tool2" { + t.Fatalf("unexpected output[2] name: %s", data.Get("response.output.2").Raw) + } + if !gjson.Valid(data.Get("response.output.2.arguments").String()) { + t.Fatalf("unexpected output[2] arguments: %q", data.Get("response.output.2.arguments").String()) + } + } + } + + if posCompleted == -1 { + t.Fatalf("missing response.completed event") + } + for idx := 0; idx < 3; idx++ { + if posAdded[idx] == -1 || posArgsDelta[idx] == -1 || posArgsDone[idx] == -1 || posItemDone[idx] == -1 { + t.Fatalf("missing function call events for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx]) + } + if !(posAdded[idx] < posArgsDelta[idx] && posArgsDelta[idx] < posArgsDone[idx] && posArgsDone[idx] < posItemDone[idx]) { + t.Fatalf("unexpected ordering for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx]) + } + if idx > 0 && !(posItemDone[idx-1] < posAdded[idx]) { + t.Fatalf("function call events overlap between %d and %d: prevDone=%d nextAdded=%d", idx-1, idx, posItemDone[idx-1], posAdded[idx]) + } + } + + if deltaByIndex[0] != "{}" { + t.Fatalf("unexpected delta for output_index 0: got %q", deltaByIndex[0]) + } + if deltaByIndex[1] != "{}" { + t.Fatalf("unexpected delta for output_index 1: got %q", deltaByIndex[1]) + } + if deltaByIndex[2] == "" || 
!gjson.Valid(deltaByIndex[2]) || gjson.Get(deltaByIndex[2], "a").Int() != 1 { + t.Fatalf("unexpected delta for output_index 2: got %q", deltaByIndex[2]) + } + if !(posItemDone[2] < posCompleted) { + t.Fatalf("response.completed should be after last output_item.done: last=%d completed=%d", posItemDone[2], posCompleted) + } +}
+
+// TestConvertGeminiResponseToOpenAIResponses_ResponseOutputOrdering feeds three
+// SSE data lines (a function call, a "hi" text part, and a final empty text part
+// with finishReason STOP) through ConvertGeminiResponseToOpenAIResponses and
+// checks the order of the results: the function_call item at output_index 0 must
+// be done before the message item at output_index 1 is added, response.completed
+// must come after that, and response.output in the completed event must list the
+// function_call first and the message (text "hi") second.
+func TestConvertGeminiResponseToOpenAIResponses_ResponseOutputOrdering(t *testing.T) {
+	in := []string{
+		`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0","args":{"x":"y"}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
+		`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hi"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
+		`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
+	}
+
+	// The same &param is handed to every call; presumably the converter keeps
+	// its per-stream state behind it (confirm against the converter).
+	var param any
+	var out []string
+	for _, line := range in {
+		out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), &param)...)
+	}
+
+	// Position in out of each event of interest; -1 means "not seen".
+	posFuncDone := -1
+	posMsgAdded := -1
+	posCompleted := -1
+
+	for i, chunk := range out {
+		ev, data := parseSSEEvent(t, chunk)
+		switch ev {
+		case "response.output_item.done":
+			if data.Get("item.type").String() == "function_call" && data.Get("output_index").Int() == 0 {
+				posFuncDone = i
+			}
+		case "response.output_item.added":
+			if data.Get("item.type").String() == "message" && data.Get("output_index").Int() == 1 {
+				posMsgAdded = i
+			}
+		case "response.completed":
+			posCompleted = i
+			// The final snapshot must list items in the order they were produced.
+			if data.Get("response.output.0.type").String() != "function_call" {
+				t.Fatalf("expected response.output[0] to be function_call: %s", data.Get("response.output.0").Raw)
+			}
+			if data.Get("response.output.1.type").String() != "message" {
+				t.Fatalf("expected response.output[1] to be message: %s", data.Get("response.output.1").Raw)
+			}
+			if data.Get("response.output.1.content.0.text").String() != "hi" {
+				t.Fatalf("unexpected message text in response.output[1]: %s", data.Get("response.output.1").Raw)
+			}
+		}
+	}
+
+	if posFuncDone == -1 || posMsgAdded == -1 || posCompleted == -1 {
+		t.Fatalf("missing required events: funcDone=%d msgAdded=%d completed=%d", posFuncDone, posMsgAdded, posCompleted)
+	}
+	if !(posFuncDone < posMsgAdded) {
+		t.Fatalf("expected function_call to complete before message is added: funcDone=%d msgAdded=%d", posFuncDone, posMsgAdded)
+	}
+	if !(posMsgAdded < posCompleted) {
+		t.Fatalf("expected response.completed after message added: msgAdded=%d completed=%d", posMsgAdded, posCompleted)
+	}
+}