diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index 390f301d..707fe576 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -2,6 +2,7 @@ package amp import ( "bytes" + "encoding/json" "fmt" "net/http" "strings" @@ -290,8 +291,10 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte { } // SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures -// from the messages array in a request body before forwarding to the upstream API. -// This prevents 400 errors from the API which requires valid signatures on thinking blocks. +// and strips the proxy-injected "signature" field from tool_use blocks in the messages +// array before forwarding to the upstream API. +// This prevents 400 errors from the API which requires valid signatures on thinking +// blocks and does not accept a signature field on tool_use blocks. func SanitizeAmpRequestBody(body []byte) []byte { messages := gjson.GetBytes(body, "messages") if !messages.Exists() || !messages.IsArray() { @@ -309,21 +312,30 @@ func SanitizeAmpRequestBody(body []byte) []byte { } var keepBlocks []interface{} - removedCount := 0 + contentModified := false for _, block := range content.Array() { blockType := block.Get("type").String() if blockType == "thinking" { sig := block.Get("signature") if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" { - removedCount++ + contentModified = true continue } } - keepBlocks = append(keepBlocks, block.Value()) + + // Use raw JSON to prevent float64 rounding of large integers in tool_use inputs + blockRaw := []byte(block.Raw) + if blockType == "tool_use" && block.Get("signature").Exists() { + blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature") + contentModified = true + } + + // sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw + keepBlocks = append(keepBlocks, json.RawMessage(blockRaw)) } - if removedCount > 0 { + if contentModified { contentPath := fmt.Sprintf("messages.%d.content", msgIdx) var err error if len(keepBlocks) == 0 { @@ -332,11 +344,10 @@ func SanitizeAmpRequestBody(body []byte) []byte { body, err = sjson.SetBytes(body, contentPath, keepBlocks) } if err != nil { - log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err) + log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err) continue } modified = true - log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx) } } diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go index 31ba56bd..ac95dfc6 100644 --- a/internal/api/modules/amp/response_rewriter_test.go +++ b/internal/api/modules/amp/response_rewriter_test.go @@ -145,6 +145,36 @@ func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testi } } +func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) { + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`) + result := SanitizeAmpRequestBody(input) + + if contains(result, []byte(`"signature":""`)) { + t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result)) + } + if !contains(result, []byte(`"valid-sig"`)) { + t.Fatalf("expected thinking signature to remain, got %s", string(result)) + } + if !contains(result, []byte(`"tool_use"`)) { + t.Fatalf("expected tool_use block to remain, got %s", string(result)) + } +} + +func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) { + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`) + result := SanitizeAmpRequestBody(input) + + if contains(result, []byte("drop-me")) { + t.Fatalf("expected invalid thinking block to be removed, got %s", string(result)) + } + if contains(result, []byte(`"signature"`)) { + t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result)) + } + if !contains(result, []byte(`"tool_use"`)) { + t.Fatalf("expected tool_use block to remain, got %s", string(result)) + } +} + func contains(data, substr []byte) bool { for i := 0; i <= len(data)-len(substr); i++ { if string(data[i:i+len(substr)]) == string(substr) {