Merge pull request #2403 from CharTyr/clean-pr

fix(amp): 修复Amp CLI 集成 缺失/无效 signature 导致的 TUI 崩溃与上游 400 问题
This commit is contained in:
Luis Pater
2026-03-30 12:54:15 +08:00
committed by GitHub
3 changed files with 338 additions and 35 deletions

View File

@@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
return
}
// Sanitize request body: remove thinking blocks with invalid signatures
// to prevent upstream API 400 errors
bodyBytes = SanitizeAmpRequestBody(bodyBytes)
// Restore the body for the handler to read
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
@@ -259,10 +263,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
} else if len(providers) > 0 {
// Log: Using local provider (free)
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
// Wrap with ResponseRewriter for local providers too, because upstream
// proxies (e.g. NewAPI) may return a different model name and lack
// Amp-required fields like thinking.signature.
rewriter := NewResponseRewriter(c.Writer, modelName)
c.Writer = rewriter
// Filter Anthropic-Beta header only for local handling paths
filterAntropicBetaHeader(c)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
rewriter.Flush()
} else {
// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

View File

@@ -2,6 +2,7 @@ package amp
import (
"bytes"
"fmt"
"net/http"
"strings"
@@ -12,32 +13,83 @@ import (
)
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
// It's used to rewrite model names in responses when model mapping is used
// It is used to rewrite model names in responses when model mapping is used
// and to keep Amp-compatible response shapes.
type ResponseRewriter struct {
gin.ResponseWriter
body *bytes.Buffer
originalModel string
isStreaming bool
body *bytes.Buffer
originalModel string
isStreaming bool
suppressedContentBlock map[int]struct{}
}
// NewResponseRewriter creates a new response rewriter for model name substitution
// NewResponseRewriter creates a new response rewriter for model name substitution.
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
return &ResponseRewriter{
ResponseWriter: w,
body: &bytes.Buffer{},
originalModel: originalModel,
ResponseWriter: w,
body: &bytes.Buffer{},
originalModel: originalModel,
suppressedContentBlock: make(map[int]struct{}),
}
}
// Write intercepts response writes and buffers them for model name replacement
const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
func looksLikeSSEChunk(data []byte) bool {
for _, line := range bytes.Split(data, []byte("\n")) {
trimmed := bytes.TrimSpace(line)
if bytes.HasPrefix(trimmed, []byte("data:")) ||
bytes.HasPrefix(trimmed, []byte("event:")) {
return true
}
}
return false
}
func (rw *ResponseRewriter) enableStreaming(reason string) error {
if rw.isStreaming {
return nil
}
rw.isStreaming = true
if rw.body != nil && rw.body.Len() > 0 {
buf := rw.body.Bytes()
toFlush := make([]byte, len(buf))
copy(toFlush, buf)
rw.body.Reset()
if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
return err
}
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
return nil
}
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
// Detect streaming on first write
if rw.body.Len() == 0 && !rw.isStreaming {
if !rw.isStreaming && rw.body.Len() == 0 {
contentType := rw.Header().Get("Content-Type")
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
strings.Contains(contentType, "stream")
}
if !rw.isStreaming {
if looksLikeSSEChunk(data) {
if err := rw.enableStreaming("sse heuristic"); err != nil {
return 0, err
}
} else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
if err := rw.enableStreaming("buffer limit"); err != nil {
return 0, err
}
}
}
if rw.isStreaming {
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
if err == nil {
@@ -50,7 +102,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
return rw.body.Write(data)
}
// Flush writes the buffered response with model names rewritten
func (rw *ResponseRewriter) Flush() {
if rw.isStreaming {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
@@ -59,26 +110,68 @@ func (rw *ResponseRewriter) Flush() {
return
}
if rw.body.Len() > 0 {
if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
// Update Content-Length to match the rewritten body size, since
// signature injection and model name changes alter the payload length.
rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten)))
if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
}
}
}
// modelFieldPaths lists all JSON paths where model name may appear
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
// The Amp client struggles when both thinking and tool_use blocks are present
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// in API responses so that the Amp TUI does not crash on P.signature.length.
func ensureAmpSignature(data []byte) []byte {
for index, block := range gjson.GetBytes(data, "content").Array() {
blockType := block.Get("type").String()
if blockType != "tool_use" && blockType != "thinking" {
continue
}
signaturePath := fmt.Sprintf("content.%d.signature", index)
if gjson.GetBytes(data, signaturePath).Exists() {
continue
}
var err error
data, err = sjson.SetBytes(data, signaturePath, "")
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err)
break
}
}
contentBlockType := gjson.GetBytes(data, "content_block.type").String()
if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() {
var err error
data, err = sjson.SetBytes(data, "content_block.signature", "")
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err)
}
}
return data
}
func (rw *ResponseRewriter) markSuppressedContentBlock(index int) {
if rw.suppressedContentBlock == nil {
rw.suppressedContentBlock = make(map[int]struct{})
}
rw.suppressedContentBlock[index] = struct{}{}
}
func (rw *ResponseRewriter) isSuppressedContentBlock(index int) bool {
_, ok := rw.suppressedContentBlock[index]
return ok
}
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
if filtered.Exists() {
originalCount := gjson.GetBytes(data, "content.#").Int()
filteredCount := filtered.Get("#").Int()
if originalCount > filteredCount {
var err error
data, err = sjson.SetBytes(data, "content", filtered.Value())
@@ -86,13 +179,41 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
} else {
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
// Log the result for verification
log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
}
}
}
}
eventType := gjson.GetBytes(data, "type").String()
indexResult := gjson.GetBytes(data, "index")
if eventType == "content_block_start" && gjson.GetBytes(data, "content_block.type").String() == "thinking" && indexResult.Exists() {
rw.markSuppressedContentBlock(int(indexResult.Int()))
return nil
}
if gjson.GetBytes(data, "delta.type").String() == "thinking_delta" {
if indexResult.Exists() {
rw.markSuppressedContentBlock(int(indexResult.Int()))
}
return nil
}
if eventType == "content_block_stop" && indexResult.Exists() {
index := int(indexResult.Int())
if rw.isSuppressedContentBlock(index) {
delete(rw.suppressedContentBlock, index)
return nil
}
}
return data
}
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
}
if rw.originalModel == "" {
return data
}
@@ -104,24 +225,158 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
return data
}
// rewriteStreamChunk rewrites model names in SSE stream chunks
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
if rw.originalModel == "" {
return chunk
lines := bytes.Split(chunk, []byte("\n"))
var out [][]byte
i := 0
for i < len(lines) {
line := lines[i]
trimmed := bytes.TrimSpace(line)
// Case 1: "event:" line - look ahead for its "data:" line
if bytes.HasPrefix(trimmed, []byte("event: ")) {
// Scan forward past blank lines to find the data: line
dataIdx := -1
for j := i + 1; j < len(lines); j++ {
t := bytes.TrimSpace(lines[j])
if len(t) == 0 {
continue
}
if bytes.HasPrefix(t, []byte("data: ")) {
dataIdx = j
}
break
}
if dataIdx >= 0 {
// Found event+data pair - process through rewriter
jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
rewritten := rw.rewriteStreamEvent(jsonData)
if rewritten == nil {
// Event suppressed (e.g. thinking block), skip event+data pair
i = dataIdx + 1
continue
}
// Emit event line
out = append(out, line)
// Emit blank lines between event and data
for k := i + 1; k < dataIdx; k++ {
out = append(out, lines[k])
}
// Emit rewritten data
out = append(out, append([]byte("data: "), rewritten...))
i = dataIdx + 1
continue
}
}
// No data line found (orphan event from cross-chunk split)
// Pass it through as-is - the data will arrive in the next chunk
out = append(out, line)
i++
continue
}
// Case 2: standalone "data:" line (no preceding event: in this chunk)
if bytes.HasPrefix(trimmed, []byte("data: ")) {
jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
rewritten := rw.rewriteStreamEvent(jsonData)
if rewritten != nil {
out = append(out, append([]byte("data: "), rewritten...))
}
i++
continue
}
}
// Case 3: everything else
out = append(out, line)
i++
}
// SSE format: "data: {json}\n\n"
lines := bytes.Split(chunk, []byte("\n"))
for i, line := range lines {
if bytes.HasPrefix(line, []byte("data: ")) {
jsonData := bytes.TrimPrefix(line, []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
// Rewrite JSON in the data line
rewritten := rw.rewriteModelInResponse(jsonData)
lines[i] = append([]byte("data: "), rewritten...)
return bytes.Join(out, []byte("\n"))
}
// rewriteStreamEvent processes a single JSON event in the SSE stream.
// It rewrites model names and ensures signature fields exist.
func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
// Suppress thinking blocks before any other processing.
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return nil
}
// Inject empty signature where needed
data = ensureAmpSignature(data)
// Rewrite model name
if rw.originalModel != "" {
for _, path := range modelFieldPaths {
if gjson.GetBytes(data, path).Exists() {
data, _ = sjson.SetBytes(data, path, rw.originalModel)
}
}
}
return bytes.Join(lines, []byte("\n"))
return data
}
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
// from the messages array in a request body before forwarding to the upstream API.
// This prevents 400 errors from the API which requires valid signatures on thinking blocks.
func SanitizeAmpRequestBody(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
modified := false
for msgIdx, msg := range messages.Array() {
if msg.Get("role").String() != "assistant" {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
var keepBlocks []interface{}
removedCount := 0
for _, block := range content.Array() {
blockType := block.Get("type").String()
if blockType == "thinking" {
sig := block.Get("signature")
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
removedCount++
continue
}
}
keepBlocks = append(keepBlocks, block.Value())
}
if removedCount > 0 {
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
var err error
if len(keepBlocks) == 0 {
body, err = sjson.SetBytes(body, contentPath, []interface{}{})
} else {
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
}
if err != nil {
log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err)
continue
}
modified = true
log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx)
}
}
if modified {
log.Debugf("Amp RequestSanitizer: sanitized request body")
}
return body
}

View File

@@ -100,6 +100,44 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) {
}
}
func TestRewriteStreamChunk_SuppressesThinkingContentBlockFrames(t *testing.T) {
rw := &ResponseRewriter{suppressedContentBlock: make(map[int]struct{})}
chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n")
result := rw.rewriteStreamChunk(chunk)
if contains(result, []byte("\"thinking\"")) || contains(result, []byte("\"thinking_delta\"")) {
t.Fatalf("expected thinking content_block frames to be suppressed, got %s", string(result))
}
if contains(result, []byte("content_block_stop")) {
t.Fatalf("expected suppressed thinking content_block_stop to be removed, got %s", string(result))
}
if !contains(result, []byte("\"tool_use\"")) {
t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result))
}
if !contains(result, []byte("\"signature\":\"\"")) {
t.Fatalf("expected tool_use content_block signature injection, got %s", string(result))
}
}
func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) {
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`)
result := SanitizeAmpRequestBody(input)
if contains(result, []byte("drop-whitespace")) {
t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result))
}
if contains(result, []byte("drop-number")) {
t.Fatalf("expected non-string signature block to be removed, got %s", string(result))
}
if !contains(result, []byte("keep-valid")) {
t.Fatalf("expected valid thinking block to remain, got %s", string(result))
}
if !contains(result, []byte("keep-text")) {
t.Fatalf("expected non-thinking content to remain, got %s", string(result))
}
}
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {