From 04ba8c8bc358650c3436bba9e6ab9c832a142fa7 Mon Sep 17 00:00:00 2001 From: CharTyr Date: Sun, 29 Mar 2026 22:23:18 -0400 Subject: [PATCH 1/2] feat(amp): sanitize signatures and handle stream suppression for Amp compatibility --- internal/api/modules/amp/fallback_handlers.go | 10 + internal/api/modules/amp/response_rewriter.go | 315 ++++++++++++++++-- .../api/modules/amp/response_rewriter_test.go | 38 +++ 3 files changed, 328 insertions(+), 35 deletions(-) diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 7d7f7f5f..97dd0c9d 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc return } + // Sanitize request body: remove thinking blocks with invalid signatures + // to prevent upstream API 400 errors + bodyBytes = SanitizeAmpRequestBody(bodyBytes) + // Restore the body for the handler to read c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) @@ -259,10 +263,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } else if len(providers) > 0 { // Log: Using local provider (free) logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) + // Wrap with ResponseRewriter for local providers too, because upstream + // proxies (e.g. NewAPI) may return a different model name and lack + // Amp-required fields like thinking.signature. + rewriter := NewResponseRewriter(c.Writer, modelName) + c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) handler(c) + rewriter.Flush() } else { // No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index 715034f1..fa83f7b9 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -2,6 +2,7 @@ package amp import ( "bytes" + "fmt" "net/http" "strings" @@ -12,32 +13,83 @@ import ( ) // ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body -// It's used to rewrite model names in responses when model mapping is used +// It is used to rewrite model names in responses when model mapping is used +// and to keep Amp-compatible response shapes. type ResponseRewriter struct { gin.ResponseWriter - body *bytes.Buffer - originalModel string - isStreaming bool + body *bytes.Buffer + originalModel string + isStreaming bool + suppressedContentBlock map[int]struct{} } -// NewResponseRewriter creates a new response rewriter for model name substitution +// NewResponseRewriter creates a new response rewriter for model name substitution. func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter { return &ResponseRewriter{ - ResponseWriter: w, - body: &bytes.Buffer{}, - originalModel: originalModel, + ResponseWriter: w, + body: &bytes.Buffer{}, + originalModel: originalModel, + suppressedContentBlock: make(map[int]struct{}), } } -// Write intercepts response writes and buffers them for model name replacement +const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap + +func looksLikeSSEChunk(data []byte) bool { + return bytes.Contains(data, []byte("data:")) || + bytes.Contains(data, []byte("event:")) || + bytes.Contains(data, []byte("message_start")) || + bytes.Contains(data, []byte("message_delta")) || + bytes.Contains(data, []byte("content_block_start")) || + bytes.Contains(data, []byte("content_block_delta")) || + bytes.Contains(data, []byte("content_block_stop")) || + bytes.Contains(data, []byte("\n\n")) +} + +func (rw *ResponseRewriter) enableStreaming(reason string) error { + if rw.isStreaming { + return nil + } + rw.isStreaming = true + + if rw.body != nil && rw.body.Len() > 0 { + buf := rw.body.Bytes() + toFlush := make([]byte, len(buf)) + copy(toFlush, buf) + rw.body.Reset() + + if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil { + return err + } + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } + } + + log.Debugf("amp response rewriter: switched to streaming (%s)", reason) + return nil +} + func (rw *ResponseRewriter) Write(data []byte) (int, error) { - // Detect streaming on first write - if rw.body.Len() == 0 && !rw.isStreaming { + if !rw.isStreaming && rw.body.Len() == 0 { contentType := rw.Header().Get("Content-Type") rw.isStreaming = strings.Contains(contentType, "text/event-stream") || strings.Contains(contentType, "stream") } + if !rw.isStreaming { + if looksLikeSSEChunk(data) { + if err := rw.enableStreaming("sse heuristic"); err != nil { + return 0, err + } + } else if rw.body.Len()+len(data) > maxBufferedResponseBytes { + log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes) + if err := rw.enableStreaming("buffer limit"); err != nil { + return 0, err + } + } + } + if rw.isStreaming { n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)) if err == nil { @@ -50,7 +102,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) { return rw.body.Write(data) } -// Flush writes the buffered response with model names rewritten func (rw *ResponseRewriter) Flush() { if rw.isStreaming { if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { @@ -59,26 +110,68 @@ func (rw *ResponseRewriter) Flush() { return } if rw.body.Len() > 0 { - if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil { + rewritten := rw.rewriteModelInResponse(rw.body.Bytes()) + // Update Content-Length to match the rewritten body size, since + // signature injection and model name changes alter the payload length. + rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten))) + if _, err := rw.ResponseWriter.Write(rewritten); err != nil { log.Warnf("amp response rewriter: failed to write rewritten response: %v", err) } } } -// modelFieldPaths lists all JSON paths where model name may appear var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"} -// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON -// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility -func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { - // 1. Amp Compatibility: Suppress thinking blocks if tool use is detected - // The Amp client struggles when both thinking and tool_use blocks are present +// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks +// in API responses so that the Amp TUI does not crash on P.signature.length. +func ensureAmpSignature(data []byte) []byte { + for index, block := range gjson.GetBytes(data, "content").Array() { + blockType := block.Get("type").String() + if blockType != "tool_use" && blockType != "thinking" { + continue + } + signaturePath := fmt.Sprintf("content.%d.signature", index) + if gjson.GetBytes(data, signaturePath).Exists() { + continue + } + var err error + data, err = sjson.SetBytes(data, signaturePath, "") + if err != nil { + log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err) + break + } + } + + contentBlockType := gjson.GetBytes(data, "content_block.type").String() + if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() { + var err error + data, err = sjson.SetBytes(data, "content_block.signature", "") + if err != nil { + log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err) + } + } + + return data +} + +func (rw *ResponseRewriter) markSuppressedContentBlock(index int) { + if rw.suppressedContentBlock == nil { + rw.suppressedContentBlock = make(map[int]struct{}) + } + rw.suppressedContentBlock[index] = struct{}{} +} + +func (rw *ResponseRewriter) isSuppressedContentBlock(index int) bool { + _, ok := rw.suppressedContentBlock[index] + return ok +} + +func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte { if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() { filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`) if filtered.Exists() { originalCount := gjson.GetBytes(data, "content.#").Int() filteredCount := filtered.Get("#").Int() - if originalCount > filteredCount { var err error data, err = sjson.SetBytes(data, "content", filtered.Value()) @@ -86,13 +179,41 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err) } else { log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount) - // Log the result for verification - log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String()) } } } } + eventType := gjson.GetBytes(data, "type").String() + indexResult := gjson.GetBytes(data, "index") + if eventType == "content_block_start" && gjson.GetBytes(data, "content_block.type").String() == "thinking" && indexResult.Exists() { + rw.markSuppressedContentBlock(int(indexResult.Int())) + return nil + } + if gjson.GetBytes(data, "delta.type").String() == "thinking_delta" { + if indexResult.Exists() { + rw.markSuppressedContentBlock(int(indexResult.Int())) + } + return nil + } + if eventType == "content_block_stop" && indexResult.Exists() { + index := int(indexResult.Int()) + if rw.isSuppressedContentBlock(index) { + delete(rw.suppressedContentBlock, index) + return nil + } + } + + return data +} + +func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { + data = ensureAmpSignature(data) + data = rw.suppressAmpThinking(data) + if len(data) == 0 { + return data + } + if rw.originalModel == "" { return data } @@ -104,24 +225,148 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { return data } -// rewriteStreamChunk rewrites model names in SSE stream chunks func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { - if rw.originalModel == "" { - return chunk + lines := bytes.Split(chunk, []byte("\n")) + var out [][]byte + + i := 0 + for i < len(lines) { + line := lines[i] + trimmed := bytes.TrimSpace(line) + + // Case 1: "event:" line - look ahead for its "data:" line + if bytes.HasPrefix(trimmed, []byte("event: ")) { + // Scan forward past blank lines to find the data: line + dataIdx := -1 + for j := i + 1; j < len(lines); j++ { + t := bytes.TrimSpace(lines[j]) + if len(t) == 0 { + continue + } + if bytes.HasPrefix(t, []byte("data: ")) { + dataIdx = j + } + break + } + + if dataIdx >= 0 { + // Found event+data pair - process through model rewriter only + // (no thinking suppression for streaming) + jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: ")) + if len(jsonData) > 0 && jsonData[0] == '{' { + rewritten := rw.rewriteStreamEvent(jsonData) + // Emit event line + out = append(out, line) + // Emit blank lines between event and data + for k := i + 1; k < dataIdx; k++ { + out = append(out, lines[k]) + } + // Emit rewritten data + out = append(out, append([]byte("data: "), rewritten...)) + i = dataIdx + 1 + continue + } + } + + // No data line found (orphan event from cross-chunk split) + // Pass it through as-is - the data will arrive in the next chunk + out = append(out, line) + i++ + continue + } + + // Case 2: standalone "data:" line (no preceding event: in this chunk) + if bytes.HasPrefix(trimmed, []byte("data: ")) { + jsonData := bytes.TrimPrefix(trimmed, []byte("data: ")) + if len(jsonData) > 0 && jsonData[0] == '{' { + rewritten := rw.rewriteStreamEvent(jsonData) + out = append(out, append([]byte("data: "), rewritten...)) + i++ + continue + } + } + + // Case 3: everything else + out = append(out, line) + i++ } - // SSE format: "data: {json}\n\n" - lines := bytes.Split(chunk, []byte("\n")) - for i, line := range lines { - if bytes.HasPrefix(line, []byte("data: ")) { - jsonData := bytes.TrimPrefix(line, []byte("data: ")) - if len(jsonData) > 0 && jsonData[0] == '{' { - // Rewrite JSON in the data line - rewritten := rw.rewriteModelInResponse(jsonData) - lines[i] = append([]byte("data: "), rewritten...) + return bytes.Join(out, []byte("\n")) +} + +// rewriteStreamEvent processes a single JSON event in the SSE stream. +// It rewrites model names and ensures signature fields exist. +// Unlike rewriteModelInResponse, it does NOT suppress thinking blocks +// in streaming mode - they are passed through with signature injection. +func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte { + // Inject empty signature where needed + data = ensureAmpSignature(data) + + // Rewrite model name + if rw.originalModel != "" { + for _, path := range modelFieldPaths { + if gjson.GetBytes(data, path).Exists() { + data, _ = sjson.SetBytes(data, path, rw.originalModel) } } } - return bytes.Join(lines, []byte("\n")) + return data +} + +// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures +// from the messages array in a request body before forwarding to the upstream API. +// This prevents 400 errors from the API which requires valid signatures on thinking blocks. +func SanitizeAmpRequestBody(body []byte) []byte { + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return body + } + + modified := false + for msgIdx, msg := range messages.Array() { + if msg.Get("role").String() != "assistant" { + continue + } + content := msg.Get("content") + if !content.Exists() || !content.IsArray() { + continue + } + + var keepBlocks []interface{} + removedCount := 0 + + for _, block := range content.Array() { + blockType := block.Get("type").String() + if blockType == "thinking" { + sig := block.Get("signature") + if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" { + removedCount++ + continue + } + } + keepBlocks = append(keepBlocks, block.Value()) + } + + if removedCount > 0 { + contentPath := fmt.Sprintf("messages.%d.content", msgIdx) + var err error + if len(keepBlocks) == 0 { + body, err = sjson.SetBytes(body, contentPath, []interface{}{}) + } else { + body, err = sjson.SetBytes(body, contentPath, keepBlocks) + } + if err != nil { + log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err) + continue + } + modified = true + log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx) + } + } + + if modified { + log.Debugf("Amp RequestSanitizer: sanitized request body") + } + return body } diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go index 114a9516..ca477d4e 100644 --- a/internal/api/modules/amp/response_rewriter_test.go +++ b/internal/api/modules/amp/response_rewriter_test.go @@ -100,6 +100,44 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) { } } +func TestRewriteStreamChunk_SuppressesThinkingContentBlockFrames(t *testing.T) { + rw := &ResponseRewriter{} + + chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n") + result := rw.rewriteStreamChunk(chunk) + + if contains(result, []byte("\"thinking\"")) || contains(result, []byte("\"thinking_delta\"")) { + t.Fatalf("expected thinking content_block frames to be suppressed, got %s", string(result)) + } + if contains(result, []byte("content_block_stop")) { + t.Fatalf("expected suppressed thinking content_block_stop to be removed, got %s", string(result)) + } + if !contains(result, []byte("\"tool_use\"")) { + t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result)) + } + if !contains(result, []byte("\"signature\":\"\"")) { + t.Fatalf("expected tool_use content_block signature injection, got %s", string(result)) + } +} + +func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) { + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`) + result := SanitizeAmpRequestBody(input) + + if contains(result, []byte("drop-whitespace")) { + t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result)) + } + if contains(result, []byte("drop-number")) { + t.Fatalf("expected non-string signature block to be removed, got %s", string(result)) + } + if !contains(result, []byte("keep-valid")) { + t.Fatalf("expected valid thinking block to remain, got %s", string(result)) + } + if !contains(result, []byte("keep-text")) { + t.Fatalf("expected non-thinking content to remain, got %s", string(result)) + } +} + func contains(data, substr []byte) bool { for i := 0; i <= len(data)-len(substr); i++ { if string(data[i:i+len(substr)]) == string(substr) { From b15453c369897df02b016d1dbb2d879fe9c1c68c Mon Sep 17 00:00:00 2001 From: CharTyr Date: Mon, 30 Mar 2026 00:42:04 -0400 Subject: [PATCH 2/2] fix(amp): address PR review - stream thinking suppression, SSE detection, test init - Call suppressAmpThinking in rewriteStreamEvent for streaming path - Handle nil return from suppressAmpThinking to skip suppressed events - Narrow looksLikeSSEChunk to line-prefix detection (HasPrefix vs Contains) - Initialize suppressedContentBlock map in test --- internal/api/modules/amp/response_rewriter.go | 36 ++++++++++++------- .../api/modules/amp/response_rewriter_test.go | 2 +- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index fa83f7b9..64757963 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -36,14 +36,14 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap func looksLikeSSEChunk(data []byte) bool { - return bytes.Contains(data, []byte("data:")) || - bytes.Contains(data, []byte("event:")) || - bytes.Contains(data, []byte("message_start")) || - bytes.Contains(data, []byte("message_delta")) || - bytes.Contains(data, []byte("content_block_start")) || - bytes.Contains(data, []byte("content_block_delta")) || - bytes.Contains(data, []byte("content_block_stop")) || - bytes.Contains(data, []byte("\n\n")) + for _, line := range bytes.Split(data, []byte("\n")) { + trimmed := bytes.TrimSpace(line) + if bytes.HasPrefix(trimmed, []byte("data:")) || + bytes.HasPrefix(trimmed, []byte("event:")) { + return true + } + } + return false } func (rw *ResponseRewriter) enableStreaming(reason string) error { @@ -250,11 +250,15 @@ func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { } if dataIdx >= 0 { - // Found event+data pair - process through model rewriter only - // (no thinking suppression for streaming) + // Found event+data pair - process through rewriter jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: ")) if len(jsonData) > 0 && jsonData[0] == '{' { rewritten := rw.rewriteStreamEvent(jsonData) + if rewritten == nil { + // Event suppressed (e.g. thinking block), skip event+data pair + i = dataIdx + 1 + continue + } // Emit event line out = append(out, line) // Emit blank lines between event and data @@ -280,7 +284,9 @@ func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { jsonData := bytes.TrimPrefix(trimmed, []byte("data: ")) if len(jsonData) > 0 && jsonData[0] == '{' { rewritten := rw.rewriteStreamEvent(jsonData) - out = append(out, append([]byte("data: "), rewritten...)) + if rewritten != nil { + out = append(out, append([]byte("data: "), rewritten...)) + } i++ continue } @@ -296,9 +302,13 @@ func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { // rewriteStreamEvent processes a single JSON event in the SSE stream. // It rewrites model names and ensures signature fields exist. -// Unlike rewriteModelInResponse, it does NOT suppress thinking blocks -// in streaming mode - they are passed through with signature injection. func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte { + // Suppress thinking blocks before any other processing. + data = rw.suppressAmpThinking(data) + if len(data) == 0 { + return nil + } + // Inject empty signature where needed data = ensureAmpSignature(data) diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go index ca477d4e..2f23d74d 100644 --- a/internal/api/modules/amp/response_rewriter_test.go +++ b/internal/api/modules/amp/response_rewriter_test.go @@ -101,7 +101,7 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) { } func TestRewriteStreamChunk_SuppressesThinkingContentBlockFrames(t *testing.T) { - rw := &ResponseRewriter{} + rw := &ResponseRewriter{suppressedContentBlock: make(map[int]struct{})} chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n") result := rw.rewriteStreamChunk(chunk)