diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index fcb3a9c9..8826b061 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -135,6 +135,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r body = ensureCacheControl(body) } + // Enforce Anthropic's cache_control block limit (max 4 breakpoints per request). + // Cloaking and ensureCacheControl may push the total over 4 when the client + // (e.g. Amp CLI) already sends multiple cache_control blocks. + body = enforceCacheControlLimit(body, 4) + + // Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05. + // A 1h-TTL block must not appear after a 5m-TTL block in evaluation order (tools→system→messages). + body = normalizeCacheControlTTL(body) + // Extract betas from body and convert to header var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) @@ -176,11 +185,18 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) + // Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API) + errBody := httpResp.Body + if ce := httpResp.Header.Get("Content-Encoding"); ce != "" { + if decoded, decErr := decodeResponseBody(httpResp.Body, ce); decErr == nil { + errBody = decoded + } + } + b, _ := io.ReadAll(errBody) appendAPIResponseChunk(ctx, e.cfg, b) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) err = statusErr{code: httpResp.StatusCode, msg: string(b)} - if errClose := httpResp.Body.Close(); errClose != nil { + if errClose := errBody.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) } return resp, err @@ -276,6 +292,12 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A body = ensureCacheControl(body) } + // Enforce Anthropic's cache_control block limit (max 4 breakpoints per request). + body = enforceCacheControlLimit(body, 4) + + // Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05. + body = normalizeCacheControlTTL(body) + // Extract betas from body and convert to header var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) @@ -317,10 +339,17 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) + // Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API) + errBody := httpResp.Body + if ce := httpResp.Header.Get("Content-Encoding"); ce != "" { + if decoded, decErr := decodeResponseBody(httpResp.Body, ce); decErr == nil { + errBody = decoded + } + } + b, _ := io.ReadAll(errBody) appendAPIResponseChunk(ctx, e.cfg, b) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - if errClose := httpResp.Body.Close(); errClose != nil { + if errClose := errBody.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) } err = statusErr{code: httpResp.StatusCode, msg: string(b)} @@ -425,6 +454,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut body = checkSystemInstructions(body) } + // Keep count_tokens requests compatible with Anthropic cache-control constraints too. + body = enforceCacheControlLimit(body, 4) + body = normalizeCacheControlTTL(body) + // Extract betas from body and convert to header (for count_tokens too) var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) @@ -464,9 +497,16 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut } recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone()) if resp.StatusCode < 200 || resp.StatusCode >= 300 { - b, _ := io.ReadAll(resp.Body) + // Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API) + errBody := io.ReadCloser(resp.Body) + if ce := resp.Header.Get("Content-Encoding"); ce != "" { + if decoded, decErr := decodeResponseBody(resp.Body, ce); decErr == nil { + errBody = decoded + } + } + b, _ := io.ReadAll(errBody) appendAPIResponseChunk(ctx, e.cfg, b) - if errClose := resp.Body.Close(); errClose != nil { + if errClose := errBody.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) } return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)} @@ -1083,7 +1123,12 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte { billingText := generateBillingHeader(payload) billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText) - agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK.","cache_control":{"type":"ephemeral","ttl":"1h"}}` + // No cache_control on the agent block. It is a cloaking artifact with zero cache + // value (the last system block is what actually triggers caching of all system content). + // Including any cache_control here creates an intra-system TTL ordering violation + // when the client's system blocks use ttl='1h' (prompt-caching-scope-2026-01-05 beta + // forbids 1h blocks after 5m blocks, and a no-TTL block defaults to 5m). + agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK."}` if strictMode { // Strict mode: billing header + agent identifier only @@ -1103,11 +1148,12 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte { if system.IsArray() { system.ForEach(func(_, part gjson.Result) bool { if part.Get("type").String() == "text" { - // Add cache_control with ttl to user system messages if not present + // Add cache_control to user system messages if not present. + // Do NOT add ttl — let it inherit the default (5m) to avoid + // TTL ordering violations with the prompt-caching-scope-2026-01-05 beta. partJSON := part.Raw if !part.Get("cache_control").Exists() { partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral") - partJSON, _ = sjson.Set(partJSON, "cache_control.ttl", "1h") } result += "," + partJSON } @@ -1254,6 +1300,245 @@ func countCacheControls(payload []byte) int { return count } +// normalizeCacheControlTTL ensures cache_control TTL values don't violate the +// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not +// appear after a 5m-TTL block anywhere in the evaluation order. +// +// Anthropic evaluates blocks in order: tools → system (index 0..N) → messages. +// Within each section, blocks are evaluated in array order. A 5m (default) block +// followed by a 1h block at ANY later position is an error — including within +// the same section (e.g. system[1]=5m then system[3]=1h). +// +// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block +// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m). +func normalizeCacheControlTTL(payload []byte) []byte { + seen5m := false // once true, all subsequent 1h blocks must be downgraded + + // Phase 1: tools (evaluated first) + tools := gjson.GetBytes(payload, "tools") + if tools.IsArray() { + idx := 0 + tools.ForEach(func(_, tool gjson.Result) bool { + cc := tool.Get("cache_control") + if cc.Exists() { + ttl := cc.Get("ttl").String() + if ttl != "1h" { + seen5m = true + } else if seen5m { + payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("tools.%d.cache_control.ttl", idx)) + } + } + idx++ + return true + }) + } + + // Phase 2: system blocks (evaluated second, in array order) + system := gjson.GetBytes(payload, "system") + if system.IsArray() { + idx := 0 + system.ForEach(func(_, item gjson.Result) bool { + cc := item.Get("cache_control") + if cc.Exists() { + ttl := cc.Get("ttl").String() + if ttl != "1h" { + seen5m = true + } else if seen5m { + payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("system.%d.cache_control.ttl", idx)) + } + } + idx++ + return true + }) + } + + // Phase 3: message content blocks (evaluated last, in array order) + messages := gjson.GetBytes(payload, "messages") + if messages.IsArray() { + msgIdx := 0 + messages.ForEach(func(_, msg gjson.Result) bool { + content := msg.Get("content") + if content.IsArray() { + contentIdx := 0 + content.ForEach(func(_, item gjson.Result) bool { + cc := item.Get("cache_control") + if cc.Exists() { + ttl := cc.Get("ttl").String() + if ttl != "1h" { + seen5m = true + } else if seen5m { + payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("messages.%d.content.%d.cache_control.ttl", msgIdx, contentIdx)) + } + } + contentIdx++ + return true + }) + } + msgIdx++ + return true + }) + } + + return payload +} + +// enforceCacheControlLimit removes excess cache_control blocks from a payload +// so the total does not exceed the Anthropic API limit (currently 4). +// +// Anthropic evaluates cache breakpoints in order: tools → system → messages. +// The most valuable breakpoints are: +// 1. Last tool — caches ALL tool definitions +// 2. Last system block — caches ALL system content +// 3. Recent messages — cache conversation context +// +// Removal priority (strip lowest-value first): +// Phase 1: system blocks earliest-first, preserving the last one. +// Phase 2: tool blocks earliest-first, preserving the last one. +// Phase 3: message content blocks earliest-first. +// Phase 4: remaining system blocks (last system). +// Phase 5: remaining tool blocks (last tool). +func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte { + total := countCacheControls(payload) + if total <= maxBlocks { + return payload + } + + excess := total - maxBlocks + + // Phase 1: strip cache_control from system blocks earliest-first, but SKIP the last one. + // The last system cache_control is high-value because it caches all system content. + system := gjson.GetBytes(payload, "system") + if system.IsArray() { + lastSysCCIdx := -1 + sysIdx := 0 + system.ForEach(func(_, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + lastSysCCIdx = sysIdx + } + sysIdx++ + return true + }) + + idx := 0 + system.ForEach(func(_, item gjson.Result) bool { + if excess <= 0 { + return false + } + if item.Get("cache_control").Exists() && idx != lastSysCCIdx { + payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("system.%d.cache_control", idx)) + excess-- + } + idx++ + return true + }) + } + if excess <= 0 { + return payload + } + + // Phase 2: strip cache_control from tools earliest-first, but SKIP the last one. + // Only the last tool cache_control is needed to cache all tool definitions. + tools := gjson.GetBytes(payload, "tools") + if tools.IsArray() { + lastToolCCIdx := -1 + toolIdx := 0 + tools.ForEach(func(_, tool gjson.Result) bool { + if tool.Get("cache_control").Exists() { + lastToolCCIdx = toolIdx + } + toolIdx++ + return true + }) + + idx := 0 + tools.ForEach(func(_, tool gjson.Result) bool { + if excess <= 0 { + return false + } + if tool.Get("cache_control").Exists() && idx != lastToolCCIdx { + payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("tools.%d.cache_control", idx)) + excess-- + } + idx++ + return true + }) + } + if excess <= 0 { + return payload + } + + // Phase 3: strip cache_control from message content blocks, earliest first. + // Older conversation turns are least likely to help immediate reuse. + messages := gjson.GetBytes(payload, "messages") + if messages.IsArray() { + msgIdx := 0 + messages.ForEach(func(_, msg gjson.Result) bool { + if excess <= 0 { + return false + } + content := msg.Get("content") + if content.IsArray() { + contentIdx := 0 + content.ForEach(func(_, item gjson.Result) bool { + if excess <= 0 { + return false + } + if item.Get("cache_control").Exists() { + payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("messages.%d.content.%d.cache_control", msgIdx, contentIdx)) + excess-- + } + contentIdx++ + return true + }) + } + msgIdx++ + return true + }) + } + if excess <= 0 { + return payload + } + + // Phase 4: strip any remaining system cache_control blocks. + system = gjson.GetBytes(payload, "system") + if system.IsArray() { + idx := 0 + system.ForEach(func(_, item gjson.Result) bool { + if excess <= 0 { + return false + } + if item.Get("cache_control").Exists() { + payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("system.%d.cache_control", idx)) + excess-- + } + idx++ + return true + }) + } + if excess <= 0 { + return payload + } + + // Phase 5: strip any remaining tool cache_control blocks (including the last tool). + tools = gjson.GetBytes(payload, "tools") + if tools.IsArray() { + idx := 0 + tools.ForEach(func(_, tool gjson.Result) bool { + if excess <= 0 { + return false + } + if tool.Get("cache_control").Exists() { + payload, _ = sjson.DeleteBytes(payload, fmt.Sprintf("tools.%d.cache_control", idx)) + excess-- + } + idx++ + return true + }) + } + + return payload +} + // injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching. // Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache." // This enables caching of conversation history, which is especially beneficial for long multi-turn conversations. diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index dd29ed8a..d90076b6 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -348,3 +348,174 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) { t.Fatalf("built-in tool_reference should not be prefixed, got %q", got) } } + +func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) { + payload := []byte(`{ + "tools": [{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}], + "system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}], + "messages": [{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}] + }`) + + out := normalizeCacheControlTTL(payload) + + if got := gjson.GetBytes(out, "tools.0.cache_control.ttl").String(); got != "1h" { + t.Fatalf("tools.0.cache_control.ttl = %q, want %q", got, "1h") + } + if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() { + t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block") + } +} + +func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) { + payload := []byte(`{ + "tools": [ + {"name":"t1","cache_control":{"type":"ephemeral"}}, + {"name":"t2","cache_control":{"type":"ephemeral"}} + ], + "system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}], + "messages": [ + {"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}}]}, + {"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]} + ] + }`) + + out := enforceCacheControlLimit(payload, 4) + + if got := countCacheControls(out); got != 4 { + t.Fatalf("cache_control count = %d, want 4", got) + } + if gjson.GetBytes(out, "tools.0.cache_control").Exists() { + t.Fatalf("tools.0.cache_control should be removed first (non-last tool)") + } + if !gjson.GetBytes(out, "tools.1.cache_control").Exists() { + t.Fatalf("tools.1.cache_control (last tool) should be preserved") + } + if !gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists() || !gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists() { + t.Fatalf("message cache_control blocks should be preserved when non-last tool removal is enough") + } +} + +func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) { + payload := []byte(`{ + "tools": [ + {"name":"t1","cache_control":{"type":"ephemeral"}}, + {"name":"t2","cache_control":{"type":"ephemeral"}}, + {"name":"t3","cache_control":{"type":"ephemeral"}}, + {"name":"t4","cache_control":{"type":"ephemeral"}}, + {"name":"t5","cache_control":{"type":"ephemeral"}} + ] + }`) + + out := enforceCacheControlLimit(payload, 4) + + if got := countCacheControls(out); got != 4 { + t.Fatalf("cache_control count = %d, want 4", got) + } + if gjson.GetBytes(out, "tools.0.cache_control").Exists() { + t.Fatalf("tools.0.cache_control should be removed to satisfy max=4") + } + if !gjson.GetBytes(out, "tools.4.cache_control").Exists() { + t.Fatalf("last tool cache_control should be preserved when possible") + } +} + +func TestClaudeExecutor_CountTokens_AppliesCacheControlGuards(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"input_tokens":42}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + + payload := []byte(`{ + "tools": [ + {"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}, + {"name":"t2","cache_control":{"type":"ephemeral"}} + ], + "system": [ + {"type":"text","text":"s1","cache_control":{"type":"ephemeral","ttl":"1h"}}, + {"type":"text","text":"s2","cache_control":{"type":"ephemeral","ttl":"1h"}} + ], + "messages": [ + {"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}, + {"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral","ttl":"1h"}}]} + ] + }`) + + _, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-haiku-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("CountTokens error: %v", err) + } + + if len(seenBody) == 0 { + t.Fatal("expected count_tokens request body to be captured") + } + if got := countCacheControls(seenBody); got > 4 { + t.Fatalf("count_tokens body has %d cache_control blocks, want <= 4", got) + } + if hasTTLOrderingViolation(seenBody) { + t.Fatalf("count_tokens body still has ttl ordering violations: %s", string(seenBody)) + } +} + +func hasTTLOrderingViolation(payload []byte) bool { + seen5m := false + violates := false + + checkCC := func(cc gjson.Result) { + if !cc.Exists() || violates { + return + } + ttl := cc.Get("ttl").String() + if ttl != "1h" { + seen5m = true + return + } + if seen5m { + violates = true + } + } + + tools := gjson.GetBytes(payload, "tools") + if tools.IsArray() { + tools.ForEach(func(_, tool gjson.Result) bool { + checkCC(tool.Get("cache_control")) + return !violates + }) + } + + system := gjson.GetBytes(payload, "system") + if system.IsArray() { + system.ForEach(func(_, item gjson.Result) bool { + checkCC(item.Get("cache_control")) + return !violates + }) + } + + messages := gjson.GetBytes(payload, "messages") + if messages.IsArray() { + messages.ForEach(func(_, msg gjson.Result) bool { + content := msg.Get("content") + if content.IsArray() { + content.ForEach(func(_, item gjson.Result) bool { + checkCC(item.Get("cache_control")) + return !violates + }) + } + return !violates + }) + } + + return violates +}