// Post-image of patch 76aa917882acb78eb98d08b32ce35354ba2f162d
// ("Optimize cache-control JSON mutations in Claude executor", edlsh,
// 2026-02-28) for internal/runtime/executor/claude_executor.go, with review
// fixes applied. Besides the code below, the patch adds the "encoding/json"
// import and re-indents the doc comments of checkSystemInstructionsWithMode
// and enforceCacheControlLimit; countCacheControls is left untouched.
//
// Imports required by this section: "bytes", "encoding/json", "io".

// parsePayloadObject decodes payload into a generic JSON object tree.
//
// It returns false when payload is empty, is not valid JSON, is not a JSON
// object, or is followed by trailing non-whitespace data. Numbers are decoded
// as json.Number rather than float64 so that values which are merely passed
// through (for example 64-bit integer IDs inside tool inputs) survive a
// parse/marshal round trip without losing precision.
func parsePayloadObject(payload []byte) (map[string]any, bool) {
	if len(payload) == 0 {
		return nil, false
	}
	dec := json.NewDecoder(bytes.NewReader(payload))
	dec.UseNumber()

	var root map[string]any
	if err := dec.Decode(&root); err != nil {
		return nil, false
	}
	// A Decoder stops after the first value; json.Unmarshal would have
	// rejected anything following it, so keep that strictness explicitly.
	if _, err := dec.Token(); err != io.EOF {
		return nil, false
	}
	return root, true
}

// marshalPayloadObject serialises root back to JSON. It returns original
// unchanged when root is nil or encoding fails.
//
// HTML escaping is disabled: the payload is sent to an API, not embedded in
// HTML, and escaping '<', '>' and '&' would needlessly inflate prompts that
// contain markup or code.
func marshalPayloadObject(original []byte, root map[string]any) []byte {
	if root == nil {
		return original
	}
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(root); err != nil {
		return original
	}
	// Encoder.Encode always appends a newline; json.Marshal did not.
	return bytes.TrimSuffix(buf.Bytes(), []byte("\n"))
}

// asObject reports whether v is a decoded JSON object and returns it.
func asObject(v any) (map[string]any, bool) {
	obj, ok := v.(map[string]any)
	return obj, ok
}

// asArray reports whether v is a decoded JSON array and returns it.
func asArray(v any) ([]any, bool) {
	arr, ok := v.([]any)
	return arr, ok
}

// cacheControlBlock returns v as an object together with true when v is a
// JSON object that carries a "cache_control" key.
func cacheControlBlock(v any) (map[string]any, bool) {
	obj, ok := asObject(v)
	if !ok {
		return nil, false
	}
	_, has := obj["cache_control"]
	return obj, has
}

// messageContentBlocks returns the "content" array of a message, or nil when
// msg is not an object or its content is not an array (e.g. a plain string).
func messageContentBlocks(msg any) []any {
	msgObj, ok := asObject(msg)
	if !ok {
		return nil
	}
	content, _ := asArray(msgObj["content"])
	return content
}

// countCacheControlBlocks counts the elements of blocks that are objects with
// a "cache_control" key.
func countCacheControlBlocks(blocks []any) int {
	n := 0
	for _, item := range blocks {
		if _, ok := cacheControlBlock(item); ok {
			n++
		}
	}
	return n
}

// countCacheControlsMap counts cache_control entries across the top-level
// "system" and "tools" arrays and every message's "content" array.
func countCacheControlsMap(root map[string]any) int {
	system, _ := asArray(root["system"])
	tools, _ := asArray(root["tools"])
	messages, _ := asArray(root["messages"])

	count := countCacheControlBlocks(system) + countCacheControlBlocks(tools)
	for _, msg := range messages {
		count += countCacheControlBlocks(messageContentBlocks(msg))
	}
	return count
}

// normalizeTTLForBlock applies the TTL ordering rule to a single block.
//
// A block whose cache_control is missing is ignored. A cache_control that is
// not an object, or whose ttl is absent, not a string, or anything other than
// "1h", counts as a 5m block and sets *seen5m. A "1h" block that follows a 5m
// block has its ttl removed. The return value reports whether obj was
// modified.
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
	ccRaw, exists := obj["cache_control"]
	if !exists {
		return false
	}
	cc, ok := asObject(ccRaw)
	if !ok {
		*seen5m = true
		return false
	}
	if ttl, isString := cc["ttl"].(string); !isString || ttl != "1h" {
		*seen5m = true
		return false
	}
	if !*seen5m {
		return false
	}
	delete(cc, "ttl")
	return true
}

// findLastCacheControlIndex returns the index of the last element of arr that
// carries cache_control, or -1 when there is none.
func findLastCacheControlIndex(arr []any) int {
	last := -1
	for idx, item := range arr {
		if _, ok := cacheControlBlock(item); ok {
			last = idx
		}
	}
	return last
}

// stripCacheControlExceptIndex removes cache_control from elements of arr in
// order, skipping preserveIdx, decrementing *excess for each removal and
// stopping as soon as *excess reaches zero.
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
	for idx, item := range arr {
		if *excess <= 0 {
			return
		}
		if idx == preserveIdx {
			continue
		}
		if obj, ok := cacheControlBlock(item); ok {
			delete(obj, "cache_control")
			*excess--
		}
	}
}

// stripAllCacheControl removes cache_control from elements of arr in order
// until *excess reaches zero.
func stripAllCacheControl(arr []any, excess *int) {
	// -1 never matches a slice index, so nothing is preserved.
	stripCacheControlExceptIndex(arr, -1, excess)
}

// stripMessageCacheControl removes cache_control from message content blocks,
// earliest message first, until *excess reaches zero.
func stripMessageCacheControl(messages []any, excess *int) {
	for _, msg := range messages {
		if *excess <= 0 {
			return
		}
		stripAllCacheControl(messageContentBlocks(msg), excess)
	}
}

// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
// appear after a 5m-TTL block anywhere in the evaluation order
// (tools, then system, then message content).
//
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
//
// The payload is returned untouched when it cannot be parsed as a JSON object
// or when no block needed downgrading, so a no-op call never rewrites bytes.
func normalizeCacheControlTTL(payload []byte) []byte {
	root, ok := parsePayloadObject(payload)
	if !ok {
		return payload
	}

	seen5m := false // once true, all subsequent 1h blocks must be downgraded
	changed := false
	normalize := func(blocks []any) {
		for _, item := range blocks {
			if obj, isObj := asObject(item); isObj && normalizeTTLForBlock(obj, &seen5m) {
				changed = true
			}
		}
	}

	tools, _ := asArray(root["tools"])
	normalize(tools)

	system, _ := asArray(root["system"])
	normalize(system)

	messages, _ := asArray(root["messages"])
	for _, msg := range messages {
		normalize(messageContentBlocks(msg))
	}

	if !changed {
		return payload
	}
	return marshalPayloadObject(payload, root)
}

// enforceCacheControlLimit removes excess cache_control blocks from a payload
// so that at most maxBlocks remain. The payload is returned untouched when it
// cannot be parsed as a JSON object or is already within the limit.
//
// Anthropic evaluates cache breakpoints in order: tools → system → messages.
// The most valuable breakpoints are:
//  1. Last tool         — caches ALL tool definitions
//  2. Last system block — caches ALL system content
//  3. Recent messages   — cache conversation context
//
// Removal priority (strip lowest-value first):
//
//	Phase 1: system blocks earliest-first, preserving the last one.
//	Phase 2: tool blocks earliest-first, preserving the last one.
//	Phase 3: message content blocks earliest-first.
//	Phase 4: remaining system blocks (last system).
//	Phase 5: remaining tool blocks (last tool).
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
	root, ok := parsePayloadObject(payload)
	if !ok {
		return payload
	}

	excess := countCacheControlsMap(root) - maxBlocks
	if excess <= 0 {
		return payload
	}

	system, _ := asArray(root["system"])
	tools, _ := asArray(root["tools"])
	messages, _ := asArray(root["messages"])

	// Every helper returns immediately once excess hits zero, so the phases
	// can simply run in priority order without re-checking between them.
	stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess)
	stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess)
	stripMessageCacheControl(messages, &excess)
	stripAllCacheControl(system, &excess)
	stripAllCacheControl(tools, &excess)

	return marshalPayloadObject(payload, root)
}

// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.