Fix Claude cache-control guardrails and gzip error decoding

This commit is contained in:
edlsh
2026-02-28 22:32:33 -05:00
parent 1ae994b4aa
commit 444a47ae63
2 changed files with 465 additions and 9 deletions

View File

@@ -135,6 +135,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
body = ensureCacheControl(body)
}
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
// Cloaking and ensureCacheControl may push the total over 4 when the client
// (e.g. Amp CLI) already sends multiple cache_control blocks.
body = enforceCacheControlLimit(body, 4)
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
// A 1h-TTL block must not appear after a 5m-TTL block in evaluation order (tools→system→messages).
body = normalizeCacheControlTTL(body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
@@ -176,11 +185,18 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API)
errBody := httpResp.Body
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
if decoded, decErr := decodeResponseBody(httpResp.Body, ce); decErr == nil {
errBody = decoded
}
}
b, _ := io.ReadAll(errBody)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
if errClose := httpResp.Body.Close(); errClose != nil {
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return resp, err
@@ -276,6 +292,12 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
body = ensureCacheControl(body)
}
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
body = enforceCacheControlLimit(body, 4)
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
body = normalizeCacheControlTTL(body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
@@ -317,10 +339,17 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API)
errBody := httpResp.Body
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
if decoded, decErr := decodeResponseBody(httpResp.Body, ce); decErr == nil {
errBody = decoded
}
}
b, _ := io.ReadAll(errBody)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
@@ -425,6 +454,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
body = checkSystemInstructions(body)
}
// Keep count_tokens requests compatible with Anthropic cache-control constraints too.
body = enforceCacheControlLimit(body, 4)
body = normalizeCacheControlTTL(body)
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
@@ -464,9 +497,16 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
}
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
b, _ := io.ReadAll(resp.Body)
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API)
errBody := io.ReadCloser(resp.Body)
if ce := resp.Header.Get("Content-Encoding"); ce != "" {
if decoded, decErr := decodeResponseBody(resp.Body, ce); decErr == nil {
errBody = decoded
}
}
b, _ := io.ReadAll(errBody)
appendAPIResponseChunk(ctx, e.cfg, b)
if errClose := resp.Body.Close(); errClose != nil {
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)}
@@ -1083,7 +1123,12 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
billingText := generateBillingHeader(payload)
billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText)
agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK.","cache_control":{"type":"ephemeral","ttl":"1h"}}`
// No cache_control on the agent block. It is a cloaking artifact with zero cache
// value (the last system block is what actually triggers caching of all system content).
// Including any cache_control here creates an intra-system TTL ordering violation
// when the client's system blocks use ttl='1h' (prompt-caching-scope-2026-01-05 beta
// forbids 1h blocks after 5m blocks, and a no-TTL block defaults to 5m).
agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK."}`
if strictMode {
// Strict mode: billing header + agent identifier only
@@ -1103,11 +1148,12 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
if system.IsArray() {
system.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "text" {
// Add cache_control with ttl to user system messages if not present
// Add cache_control to user system messages if not present.
// Do NOT add ttl — let it inherit the default (5m) to avoid
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
partJSON := part.Raw
if !part.Get("cache_control").Exists() {
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
partJSON, _ = sjson.Set(partJSON, "cache_control.ttl", "1h")
}
result += "," + partJSON
}
@@ -1254,6 +1300,245 @@ func countCacheControls(payload []byte) int {
return count
}
// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
// appear after a 5m-TTL block anywhere in the evaluation order.
//
// Anthropic evaluates blocks in order: tools → system (index 0..N) → messages.
// Within each section, blocks are evaluated in array order. A 5m (default) block
// followed by a 1h block at ANY later position is an error — including within
// the same section (e.g. system[1]=5m then system[3]=1h).
//
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
func normalizeCacheControlTTL(payload []byte) []byte {
seen5m := false // once true, all subsequent 1h blocks must be downgraded
// Phase 1: tools (evaluated first)
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
idx := 0
tools.ForEach(func(_, tool gjson.Result) bool {
cc := tool.Get("cache_control")
if cc.Exists() {
ttl := cc.Get("ttl").String()
if ttl != "1h" {
seen5m = true
} else if seen5m {
payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("tools.%d.cache_control.ttl", idx))
}
}
idx++
return true
})
}
// Phase 2: system blocks (evaluated second, in array order)
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
idx := 0
system.ForEach(func(_, item gjson.Result) bool {
cc := item.Get("cache_control")
if cc.Exists() {
ttl := cc.Get("ttl").String()
if ttl != "1h" {
seen5m = true
} else if seen5m {
payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("system.%d.cache_control.ttl", idx))
}
}
idx++
return true
})
}
// Phase 3: message content blocks (evaluated last, in array order)
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
msgIdx := 0
messages.ForEach(func(_, msg gjson.Result) bool {
content := msg.Get("content")
if content.IsArray() {
contentIdx := 0
content.ForEach(func(_, item gjson.Result) bool {
cc := item.Get("cache_control")
if cc.Exists() {
ttl := cc.Get("ttl").String()
if ttl != "1h" {
seen5m = true
} else if seen5m {
payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("messages.%d.content.%d.cache_control.ttl", msgIdx, contentIdx))
}
}
contentIdx++
return true
})
}
msgIdx++
return true
})
}
return payload
}
// enforceCacheControlLimit removes excess cache_control blocks from a payload
// so the total does not exceed the Anthropic API limit (currently 4).
//
// Anthropic evaluates cache breakpoints in order: tools → system → messages.
// The most valuable breakpoints are:
// 1. Last tool — caches ALL tool definitions
// 2. Last system block — caches ALL system content
// 3. Recent messages — cache conversation context
//
// Removal priority (strip lowest-value first):
// Phase 1: system blocks earliest-first, preserving the last one.
// Phase 2: tool blocks earliest-first, preserving the last one.
// Phase 3: message content blocks earliest-first.
// Phase 4: remaining system blocks (last system).
// Phase 5: remaining tool blocks (last tool).
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
total := countCacheControls(payload)
if total <= maxBlocks {
return payload
}
excess := total - maxBlocks
// Phase 1: strip cache_control from system blocks earliest-first, but SKIP the last one.
// The last system cache_control is high-value because it caches all system content.
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
lastSysCCIdx := -1
sysIdx := 0
system.ForEach(func(_, item gjson.Result) bool {
if item.Get("cache_control").Exists() {
lastSysCCIdx = sysIdx
}
sysIdx++
return true
})
idx := 0
system.ForEach(func(_, item gjson.Result) bool {
if excess <= 0 {
return false
}
if item.Get("cache_control").Exists() && idx != lastSysCCIdx {
payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("system.%d.cache_control", idx))
excess--
}
idx++
return true
})
}
if excess <= 0 {
return payload
}
// Phase 2: strip cache_control from tools earliest-first, but SKIP the last one.
// Only the last tool cache_control is needed to cache all tool definitions.
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
lastToolCCIdx := -1
toolIdx := 0
tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("cache_control").Exists() {
lastToolCCIdx = toolIdx
}
toolIdx++
return true
})
idx := 0
tools.ForEach(func(_, tool gjson.Result) bool {
if excess <= 0 {
return false
}
if tool.Get("cache_control").Exists() && idx != lastToolCCIdx {
payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("tools.%d.cache_control", idx))
excess--
}
idx++
return true
})
}
if excess <= 0 {
return payload
}
// Phase 3: strip cache_control from message content blocks, earliest first.
// Older conversation turns are least likely to help immediate reuse.
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
msgIdx := 0
messages.ForEach(func(_, msg gjson.Result) bool {
if excess <= 0 {
return false
}
content := msg.Get("content")
if content.IsArray() {
contentIdx := 0
content.ForEach(func(_, item gjson.Result) bool {
if excess <= 0 {
return false
}
if item.Get("cache_control").Exists() {
payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, contentIdx))
excess--
}
contentIdx++
return true
})
}
msgIdx++
return true
})
}
if excess <= 0 {
return payload
}
// Phase 4: strip any remaining system cache_control blocks.
system = gjson.GetBytes(payload, "system")
if system.IsArray() {
idx := 0
system.ForEach(func(_, item gjson.Result) bool {
if excess <= 0 {
return false
}
if item.Get("cache_control").Exists() {
payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("system.%d.cache_control", idx))
excess--
}
idx++
return true
})
}
if excess <= 0 {
return payload
}
// Phase 5: strip any remaining tool cache_control blocks (including the last tool).
tools = gjson.GetBytes(payload, "tools")
if tools.IsArray() {
idx := 0
tools.ForEach(func(_, tool gjson.Result) bool {
if excess <= 0 {
return false
}
if tool.Get("cache_control").Exists() {
payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("tools.%d.cache_control", idx))
excess--
}
idx++
return true
})
}
return payload
}
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
// Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache."
// This enables caching of conversation history, which is especially beneficial for long multi-turn conversations.

View File

@@ -348,3 +348,174 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
}
}
func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
payload := []byte(`{
"tools": [{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
"messages": [{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
}`)
out := normalizeCacheControlTTL(payload)
if got := gjson.GetBytes(out, "tools.0.cache_control.ttl").String(); got != "1h" {
t.Fatalf("tools.0.cache_control.ttl = %q, want %q", got, "1h")
}
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
}
}
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
payload := []byte(`{
"tools": [
{"name":"t1","cache_control":{"type":"ephemeral"}},
{"name":"t2","cache_control":{"type":"ephemeral"}}
],
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
"messages": [
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}}]},
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}
]
}`)
out := enforceCacheControlLimit(payload, 4)
if got := countCacheControls(out); got != 4 {
t.Fatalf("cache_control count = %d, want 4", got)
}
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
}
if !gjson.GetBytes(out, "tools.1.cache_control").Exists() {
t.Fatalf("tools.1.cache_control (last tool) should be preserved")
}
if !gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists() || !gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists() {
t.Fatalf("message cache_control blocks should be preserved when non-last tool removal is enough")
}
}
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
payload := []byte(`{
"tools": [
{"name":"t1","cache_control":{"type":"ephemeral"}},
{"name":"t2","cache_control":{"type":"ephemeral"}},
{"name":"t3","cache_control":{"type":"ephemeral"}},
{"name":"t4","cache_control":{"type":"ephemeral"}},
{"name":"t5","cache_control":{"type":"ephemeral"}}
]
}`)
out := enforceCacheControlLimit(payload, 4)
if got := countCacheControls(out); got != 4 {
t.Fatalf("cache_control count = %d, want 4", got)
}
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
t.Fatalf("tools.0.cache_control should be removed to satisfy max=4")
}
if !gjson.GetBytes(out, "tools.4.cache_control").Exists() {
t.Fatalf("last tool cache_control should be preserved when possible")
}
}
func TestClaudeExecutor_CountTokens_AppliesCacheControlGuards(t *testing.T) {
var seenBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
seenBody = bytes.Clone(body)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"input_tokens":42}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{
"tools": [
{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}},
{"name":"t2","cache_control":{"type":"ephemeral"}}
],
"system": [
{"type":"text","text":"s1","cache_control":{"type":"ephemeral","ttl":"1h"}},
{"type":"text","text":"s2","cache_control":{"type":"ephemeral","ttl":"1h"}}
],
"messages": [
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral","ttl":"1h"}}]}
]
}`)
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-haiku-20241022",
Payload: payload,
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
if err != nil {
t.Fatalf("CountTokens error: %v", err)
}
if len(seenBody) == 0 {
t.Fatal("expected count_tokens request body to be captured")
}
if got := countCacheControls(seenBody); got > 4 {
t.Fatalf("count_tokens body has %d cache_control blocks, want <= 4", got)
}
if hasTTLOrderingViolation(seenBody) {
t.Fatalf("count_tokens body still has ttl ordering violations: %s", string(seenBody))
}
}
func hasTTLOrderingViolation(payload []byte) bool {
seen5m := false
violates := false
checkCC := func(cc gjson.Result) {
if !cc.Exists() || violates {
return
}
ttl := cc.Get("ttl").String()
if ttl != "1h" {
seen5m = true
return
}
if seen5m {
violates = true
}
}
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
tools.ForEach(func(_, tool gjson.Result) bool {
checkCC(tool.Get("cache_control"))
return !violates
})
}
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
system.ForEach(func(_, item gjson.Result) bool {
checkCC(item.Get("cache_control"))
return !violates
})
}
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
messages.ForEach(func(_, msg gjson.Result) bool {
content := msg.Get("content")
if content.IsArray() {
content.ForEach(func(_, item gjson.Result) bool {
checkCC(item.Get("cache_control"))
return !violates
})
}
return !violates
})
}
return violates
}