From d9c6317c84b4b19ee2090610bcc19c7a2d9a3b4e Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 21 Jan 2026 18:30:05 +0800 Subject: [PATCH] refactor(cache, translator): refine signature caching logic and tests, replace session-based logic with model group handling --- internal/cache/signature_cache.go | 89 +++++++++---------- internal/cache/signature_cache_test.go | 46 +++++----- .../claude/antigravity_claude_request.go | 48 +++++----- .../claude/antigravity_claude_request_test.go | 4 +- .../claude/antigravity_claude_response.go | 2 +- .../antigravity_claude_response_test.go | 16 ++-- .../gemini/antigravity_gemini_request.go | 50 ++++++----- 7 files changed, 129 insertions(+), 126 deletions(-) diff --git a/internal/cache/signature_cache.go b/internal/cache/signature_cache.go index ea98f8a0..af5371bf 100644 --- a/internal/cache/signature_cache.go +++ b/internal/cache/signature_cache.go @@ -3,7 +3,6 @@ package cache import ( "crypto/sha256" "encoding/hex" - "fmt" "strings" "sync" "time" @@ -25,18 +24,18 @@ const ( // MinValidSignatureLen is the minimum length for a signature to be considered valid MinValidSignatureLen = 50 - // SessionCleanupInterval controls how often stale sessions are purged - SessionCleanupInterval = 10 * time.Minute + // CacheCleanupInterval controls how often stale entries are purged + CacheCleanupInterval = 10 * time.Minute ) -// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry +// signatureCache stores signatures by model group -> textHash -> SignatureEntry var signatureCache sync.Map -// sessionCleanupOnce ensures the background cleanup goroutine starts only once -var sessionCleanupOnce sync.Once +// cacheCleanupOnce ensures the background cleanup goroutine starts only once +var cacheCleanupOnce sync.Once -// sessionCache is the inner map type -type sessionCache struct { +// groupCache is the inner map type +type groupCache struct { mu sync.RWMutex entries map[string]SignatureEntry } @@ -47,36 +46,36 @@ func hashText(text string) string { return hex.EncodeToString(h[:])[:SignatureTextHashLen] } -// getOrCreateSession gets or creates a session cache -func getOrCreateSession(sessionID string) *sessionCache { +// getOrCreateGroupCache gets or creates a cache bucket for a model group +func getOrCreateGroupCache(groupKey string) *groupCache { // Start background cleanup on first access - sessionCleanupOnce.Do(startSessionCleanup) + cacheCleanupOnce.Do(startCacheCleanup) - if val, ok := signatureCache.Load(sessionID); ok { - return val.(*sessionCache) + if val, ok := signatureCache.Load(groupKey); ok { + return val.(*groupCache) } - sc := &sessionCache{entries: make(map[string]SignatureEntry)} - actual, _ := signatureCache.LoadOrStore(sessionID, sc) - return actual.(*sessionCache) + sc := &groupCache{entries: make(map[string]SignatureEntry)} + actual, _ := signatureCache.LoadOrStore(groupKey, sc) + return actual.(*groupCache) } -// startSessionCleanup launches a background goroutine that periodically -// removes sessions where all entries have expired. -func startSessionCleanup() { +// startCacheCleanup launches a background goroutine that periodically +// removes caches where all entries have expired. +func startCacheCleanup() { go func() { - ticker := time.NewTicker(SessionCleanupInterval) + ticker := time.NewTicker(CacheCleanupInterval) defer ticker.Stop() for range ticker.C { - purgeExpiredSessions() + purgeExpiredCaches() } }() } -// purgeExpiredSessions removes sessions with no valid (non-expired) entries. -func purgeExpiredSessions() { +// purgeExpiredCaches removes caches with no valid (non-expired) entries. +func purgeExpiredCaches() { now := time.Now() signatureCache.Range(func(key, value any) bool { - sc := value.(*sessionCache) + sc := value.(*groupCache) sc.mu.Lock() // Remove expired entries for k, entry := range sc.entries { @@ -86,7 +85,7 @@ func purgeExpiredSessions() { } isEmpty := len(sc.entries) == 0 sc.mu.Unlock() - // Remove session if empty + // Remove cache bucket if empty if isEmpty { signatureCache.Delete(key) } @@ -94,7 +93,7 @@ func purgeExpiredSessions() { }) } -// CacheSignature stores a thinking signature for a given session and text. +// CacheSignature stores a thinking signature for a given model group and text. // Used for Claude models that require signed thinking blocks in multi-turn conversations. func CacheSignature(modelName, text, signature string) { if text == "" || signature == "" { @@ -104,9 +103,9 @@ func CacheSignature(modelName, text, signature string) { return } - text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text) + groupKey := GetModelGroup(modelName) textHash := hashText(text) - sc := getOrCreateSession(textHash) + sc := getOrCreateGroupCache(groupKey) sc.mu.Lock() defer sc.mu.Unlock() @@ -116,26 +115,25 @@ func CacheSignature(modelName, text, signature string) { } } -// GetCachedSignature retrieves a cached signature for a given session and text. +// GetCachedSignature retrieves a cached signature for a given model group and text. // Returns empty string if not found or expired. func GetCachedSignature(modelName, text string) string { - family := GetModelGroup(modelName) + groupKey := GetModelGroup(modelName) if text == "" { - if family == "gemini" { + if groupKey == "gemini" { return "skip_thought_signature_validator" } return "" } - text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text) - val, ok := signatureCache.Load(hashText(text)) + val, ok := signatureCache.Load(groupKey) if !ok { - if family == "gemini" { + if groupKey == "gemini" { return "skip_thought_signature_validator" } return "" } - sc := val.(*sessionCache) + sc := val.(*groupCache) textHash := hashText(text) @@ -145,7 +143,7 @@ func GetCachedSignature(modelName, text string) string { entry, exists := sc.entries[textHash] if !exists { sc.mu.Unlock() - if family == "gemini" { + if groupKey == "gemini" { return "skip_thought_signature_validator" } return "" @@ -153,7 +151,7 @@ func GetCachedSignature(modelName, text string) string { if now.Sub(entry.Timestamp) > SignatureCacheTTL { delete(sc.entries, textHash) sc.mu.Unlock() - if family == "gemini" { + if groupKey == "gemini" { return "skip_thought_signature_validator" } return "" @@ -167,22 +165,17 @@ func GetCachedSignature(modelName, text string) string { return entry.Signature } -// ClearSignatureCache clears signature cache for a specific session or all sessions. -func ClearSignatureCache(sessionID string) { - if sessionID != "" { - signatureCache.Range(func(key, _ any) bool { - kStr, ok := key.(string) - if ok && strings.HasSuffix(kStr, "#"+sessionID) { - signatureCache.Delete(key) - } - return true - }) - } else { +// ClearSignatureCache clears signature cache for a specific model group or all groups. +func ClearSignatureCache(modelName string) { + if modelName == "" { signatureCache.Range(func(key, _ any) bool { signatureCache.Delete(key) return true }) + return } + groupKey := GetModelGroup(modelName) + signatureCache.Delete(groupKey) } // HasValidSignature checks if a signature is valid (non-empty and long enough) diff --git a/internal/cache/signature_cache_test.go b/internal/cache/signature_cache_test.go index 9388c2e0..368d3195 100644 --- a/internal/cache/signature_cache_test.go +++ b/internal/cache/signature_cache_test.go @@ -21,33 +21,33 @@ func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) { } } -func TestCacheSignature_DifferentSessions(t *testing.T) { +func TestCacheSignature_DifferentModelGroups(t *testing.T) { ClearSignatureCache("") - text := "Same text in different sessions" + text := "Same text across models" sig1 := "signature1_1234567890123456789012345678901234567890123456" sig2 := "signature2_1234567890123456789012345678901234567890123456" - CacheSignature("test-model", text, sig1) - CacheSignature("test-model", text, sig2) + CacheSignature("claude-sonnet-4-5-thinking", text, sig1) + CacheSignature("gpt-4o", text, sig2) - if GetCachedSignature("test-model", text) != sig1 { - t.Error("Session-a signature mismatch") + if GetCachedSignature("claude-sonnet-4-5-thinking", text) != sig1 { + t.Error("Claude signature mismatch") } - if GetCachedSignature("test-model", text) != sig2 { - t.Error("Session-b signature mismatch") + if GetCachedSignature("gpt-4o", text) != sig2 { + t.Error("GPT signature mismatch") } } func TestCacheSignature_NotFound(t *testing.T) { ClearSignatureCache("") - // Non-existent session + // Non-existent cache entry if got := GetCachedSignature("test-model", "some text"); got != "" { - t.Errorf("Expected empty string for nonexistent session, got '%s'", got) + t.Errorf("Expected empty string for missing entry, got '%s'", got) } - // Existing session but different text + // Existing cache but different text CacheSignature("test-model", "text-a", "sigA12345678901234567890123456789012345678901234567890") if got := GetCachedSignature("test-model", "text-b"); got != "" { t.Errorf("Expected empty string for different text, got '%s'", got) @@ -58,7 +58,6 @@ func TestCacheSignature_EmptyInputs(t *testing.T) { ClearSignatureCache("") // All empty/invalid inputs should be no-ops - CacheSignature("test-model", "text", "sig12345678901234567890123456789012345678901234567890") CacheSignature("test-model", "", "sig12345678901234567890123456789012345678901234567890") CacheSignature("test-model", "text", "") CacheSignature("test-model", "text", "short") // Too short @@ -81,20 +80,21 @@ func TestCacheSignature_ShortSignatureRejected(t *testing.T) { } } -func TestClearSignatureCache_SpecificSession(t *testing.T) { +func TestClearSignatureCache_ModelGroup(t *testing.T) { ClearSignatureCache("") - sig := "validSig1234567890123456789012345678901234567890123456" - CacheSignature("test-model", "text", sig) - CacheSignature("test-model", "text", sig) + sigClaude := "validSig1234567890123456789012345678901234567890123456" + sigGpt := "validSig9876543210987654321098765432109876543210987654" + CacheSignature("claude-sonnet-4-5-thinking", "text", sigClaude) + CacheSignature("gpt-4o", "text", sigGpt) - ClearSignatureCache("session-1") + ClearSignatureCache("claude-sonnet-4-5-thinking") - if got := GetCachedSignature("test-model", "text"); got != "" { - t.Error("session-1 should be cleared") + if got := GetCachedSignature("claude-sonnet-4-5-thinking", "text"); got != "" { + t.Error("Claude cache should be cleared") } - if got := GetCachedSignature("test-model", "text"); got != sig { - t.Error("session-2 should still exist") + if got := GetCachedSignature("gpt-4o", "text"); got != sigGpt { + t.Error("GPT cache should still exist") } } @@ -108,10 +108,10 @@ func TestClearSignatureCache_AllSessions(t *testing.T) { ClearSignatureCache("") if got := GetCachedSignature("test-model", "text"); got != "" { - t.Error("session-1 should be cleared") + t.Error("cache should be cleared") } if got := GetCachedSignature("test-model", "text"); got != "" { - t.Error("session-2 should be cleared") + t.Error("cache should be cleared") } } diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index e87a7d6b..bce76892 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -98,32 +98,38 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ // Use GetThinkingText to handle wrapped thinking objects thinkingText := thinking.GetThinkingText(contentResult) - // Always try cached signature first (more reliable than client-provided) - // Client may send stale or invalid signatures from different sessions signature := "" - if thinkingText != "" { - if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" { - signature = cachedSig - // log.Debugf("Using cached signature for thinking block") - } - } + signatureResult := contentResult.Get("signature") + hasClientSignature := signatureResult.Exists() && signatureResult.String() != "" - // Fallback to client signature only if cache miss and client signature is valid - if signature == "" { - signatureResult := contentResult.Get("signature") - clientSignature := "" - if signatureResult.Exists() && signatureResult.String() != "" { - arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2) - if len(arrayClientSignatures) == 2 { - if modelName == arrayClientSignatures[0] { - clientSignature = arrayClientSignatures[1] - } + // Only consider cached signatures when the client provided a signature. + // Unsigned thinking blocks must be dropped. + if hasClientSignature { + // Always try cached signature first (more reliable than client-provided) + // Client may send stale or invalid signatures from other requests + if thinkingText != "" { + if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" { + signature = cachedSig + // log.Debugf("Using cached signature for thinking block") } } - if cache.HasValidSignature(modelName, clientSignature) { - signature = clientSignature + + // Fallback to client signature only if cache miss and client signature is valid + if signature == "" { + clientSignature := "" + if signatureResult.Exists() && signatureResult.String() != "" { + arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2) + if len(arrayClientSignatures) == 2 { + if modelName == arrayClientSignatures[0] { + clientSignature = arrayClientSignatures[1] + } + } + } + if cache.HasValidSignature(modelName, clientSignature) { + signature = clientSignature + } + // log.Debugf("Using client-provided signature for thinking block") } - // log.Debugf("Using client-provided signature for thinking block") } // Store for subsequent tool_use in the same message diff --git a/internal/translator/antigravity/claude/antigravity_claude_request_test.go b/internal/translator/antigravity/claude/antigravity_claude_request_test.go index 6eb58795..7831b8bd 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request_test.go @@ -78,9 +78,7 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) { validSignature := "abc123validSignature1234567890123456789012345678901234567890" thinkingText := "Let me think..." - // Pre-cache the signature (simulating a response from the same session) - // The session ID is derived from the first user message hash - // Since there's no user message in this test, we need to add one + // Pre-cache the signature (simulating a previous response for the same thinking text) inputJSON := []byte(`{ "model": "claude-sonnet-4-5-thinking", "messages": [ diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 57eca78c..3c834f6f 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -139,7 +139,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq if params.CurrentThinkingText.Len() > 0 { cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String()) - // log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len()) + // log.Debugf("Cached signature for thinking block (textLen=%d)", params.CurrentThinkingText.Len()) params.CurrentThinkingText.Reset() } diff --git a/internal/translator/antigravity/claude/antigravity_claude_response_test.go b/internal/translator/antigravity/claude/antigravity_claude_response_test.go index 9dd1eedd..c561c557 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response_test.go @@ -12,10 +12,10 @@ import ( // Signature Caching Tests // ============================================================================ -func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) { +func TestConvertAntigravityResponseToClaude_ParamsInitialized(t *testing.T) { cache.ClearSignatureCache("") - // Request with user message - should derive session ID + // Request with user message - should initialize params requestJSON := []byte(`{ "messages": [ {"role": "user", "content": [{"type": "text", "text": "Hello world"}]} @@ -37,10 +37,12 @@ func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) { ctx := context.Background() ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, ¶m) - // Verify session ID was set params := param.(*Params) - if params.SessionID == "" { - t.Error("SessionID should be derived from request") + if !params.HasFirstResponse { + t.Error("HasFirstResponse should be set after first chunk") + } + if params.CurrentThinkingText.Len() == 0 { + t.Error("Thinking text should be accumulated") } } @@ -130,12 +132,8 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) { // Process thinking chunk ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, ¶m) params := param.(*Params) - sessionID := params.SessionID thinkingText := params.CurrentThinkingText.String() - if sessionID == "" { - t.Fatal("SessionID should be set") - } if thinkingText == "" { t.Fatal("Thinking text should be accumulated") } diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request.go b/internal/translator/antigravity/gemini/antigravity_gemini_request.go index 2ad9bd80..37346119 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_request.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request.go @@ -99,36 +99,44 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ } // Gemini-specific handling for non-Claude models: + // - Remove thinking parts entirely. // - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation. - // - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them). if !strings.Contains(modelName, "claude") { const skipSentinel = "skip_thought_signature_validator" gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool { - if content.Get("role").String() == "model" { - // First pass: collect indices of thinking parts to mark with skip sentinel - var thinkingIndicesToSkipSignature []int64 - content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool { - // Collect indices of thinking blocks to mark with skip sentinel - if part.Get("thought").Bool() { - thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int()) - } - // Add skip sentinel to functionCall parts - if part.Get("functionCall").Exists() { - existingSig := part.Get("thoughtSignature").String() - if existingSig == "" || len(existingSig) < 50 { - rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel) + if content.Get("role").String() != "model" { + return true + } + partsResult := content.Get("parts") + if !partsResult.IsArray() { + return true + } + + parts := partsResult.Array() + newParts := make([]interface{}, 0, len(parts)) + for _, part := range parts { + if part.Get("thought").Bool() { + continue + } + + partRaw := part.Raw + if part.Get("functionCall").Exists() { + existingSig := part.Get("thoughtSignature").String() + if existingSig == "" || len(existingSig) < 50 { + updatedPart, errSet := sjson.Set(partRaw, "thoughtSignature", skipSentinel) + if errSet != nil { + log.WithError(errSet).Debug("failed to set thoughtSignature on functionCall part") + } else { + partRaw = updatedPart } } - return true - }) - - // Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices - for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- { - idx := thinkingIndicesToSkipSignature[i] - rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel) } + + newParts = append(newParts, gjson.Parse(partRaw).Value()) } + + rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts", contentIdx.Int()), newParts) return true }) }