diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 97dd0c9d..e4e0f8a6 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -253,6 +253,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) rewriter := NewResponseRewriter(c.Writer, modelName) + rewriter.suppressThinking = true c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) @@ -267,6 +268,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // proxies (e.g. NewAPI) may return a different model name and lack // Amp-required fields like thinking.signature. rewriter := NewResponseRewriter(c.Writer, modelName) + rewriter.suppressThinking = providerName != "claude" c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index a38195d7..324e893b 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -17,19 +17,18 @@ import ( // and to keep Amp-compatible response shapes. type ResponseRewriter struct { gin.ResponseWriter - body *bytes.Buffer - originalModel string - isStreaming bool - suppressedContentBlock map[int]struct{} + body *bytes.Buffer + originalModel string + isStreaming bool + suppressThinking bool } // NewResponseRewriter creates a new response rewriter for model name substitution. func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter { return &ResponseRewriter{ - ResponseWriter: w, - body: &bytes.Buffer{}, - originalModel: originalModel, - suppressedContentBlock: make(map[int]struct{}), + ResponseWriter: w, + body: &bytes.Buffer{}, + originalModel: originalModel, } } @@ -99,7 +98,8 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) { } if rw.isStreaming { - n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)) + rewritten := rw.rewriteStreamChunk(data) + n, err := rw.ResponseWriter.Write(rewritten) if err == nil { if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { flusher.Flush() @@ -162,19 +162,10 @@ func ensureAmpSignature(data []byte) []byte { return data } -func (rw *ResponseRewriter) markSuppressedContentBlock(index int) { - if rw.suppressedContentBlock == nil { - rw.suppressedContentBlock = make(map[int]struct{}) - } - rw.suppressedContentBlock[index] = struct{}{} -} - -func (rw *ResponseRewriter) isSuppressedContentBlock(index int) bool { - _, ok := rw.suppressedContentBlock[index] - return ok -} - func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte { + if !rw.suppressThinking { + return data + } if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() { filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`) if filtered.Exists() { @@ -185,33 +176,11 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte { data, err = sjson.SetBytes(data, "content", filtered.Value()) if err != nil { log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err) - } else { - log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount) } } } } - eventType := gjson.GetBytes(data, "type").String() - indexResult := gjson.GetBytes(data, "index") - if eventType == "content_block_start" && gjson.GetBytes(data, "content_block.type").String() == "thinking" && indexResult.Exists() { - rw.markSuppressedContentBlock(int(indexResult.Int())) - return nil - } - if gjson.GetBytes(data, "delta.type").String() == "thinking_delta" { - if indexResult.Exists() { - rw.markSuppressedContentBlock(int(indexResult.Int())) - } - return nil - } - if eventType == "content_block_stop" && indexResult.Exists() { - index := int(indexResult.Int()) - if rw.isSuppressedContentBlock(index) { - delete(rw.suppressedContentBlock, index) - return nil - } - } - return data } @@ -262,6 +231,10 @@ func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: ")) if len(jsonData) > 0 && jsonData[0] == '{' { rewritten := rw.rewriteStreamEvent(jsonData) + if rewritten == nil { + i = dataIdx + 1 + continue + } // Emit event line out = append(out, line) // Emit blank lines between event and data @@ -287,7 +260,9 @@ func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { jsonData := bytes.TrimPrefix(trimmed, []byte("data: ")) if len(jsonData) > 0 && jsonData[0] == '{' { rewritten := rw.rewriteStreamEvent(jsonData) - out = append(out, append([]byte("data: "), rewritten...)) + if rewritten != nil { + out = append(out, append([]byte("data: "), rewritten...)) + } i++ continue } diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go index 50712cf9..31ba56bd 100644 --- a/internal/api/modules/amp/response_rewriter_test.go +++ b/internal/api/modules/amp/response_rewriter_test.go @@ -102,7 +102,7 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) { } func TestRewriteStreamChunk_PreservesThinkingWithSignatureInjection(t *testing.T) { - rw := &ResponseRewriter{suppressedContentBlock: make(map[int]struct{})} + rw := &ResponseRewriter{} chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n") result := rw.rewriteStreamChunk(chunk) diff --git a/internal/api/server.go b/internal/api/server.go index 95327eef..10f88931 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -323,6 +323,10 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk // setupRoutes configures the API routes for the server. // It defines the endpoints and associates them with their respective handlers. func (s *Server) setupRoutes() { + s.engine.GET("/healthz", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) + s.engine.GET("/management.html", s.serveManagementControlPanel) openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers) geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers) diff --git a/internal/api/server_test.go b/internal/api/server_test.go index 7ce38b8f..dbc2cd5a 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -1,6 +1,7 @@ package api import ( + "encoding/json" "net/http" "net/http/httptest" "os" @@ -46,6 +47,28 @@ func newTestServer(t *testing.T) *Server { return NewServer(cfg, authManager, accessManager, configPath) } +func TestHealthz(t *testing.T) { + server := newTestServer(t) + + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + rr := httptest.NewRecorder() + server.engine.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("unexpected status code: got %d want %d; body=%s", rr.Code, http.StatusOK, rr.Body.String()) + } + + var resp struct { + Status string `json:"status"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response JSON: %v; body=%s", err, rr.Body.String()) + } + if resp.Status != "ok" { + t.Fatalf("unexpected response status: got %q want %q", resp.Status, "ok") + } +} + func TestAmpProviderModelRoutes(t *testing.T) { testCases := []struct { name string diff --git a/internal/translator/gemini/claude/gemini_claude_request.go b/internal/translator/gemini/claude/gemini_claude_request.go index 4a52d4a8..e230f5fd 100644 --- a/internal/translator/gemini/claude/gemini_claude_request.go +++ b/internal/translator/gemini/claude/gemini_claude_request.go @@ -6,7 +6,7 @@ package claude import ( - "bytes" + "fmt" "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" @@ -31,8 +31,6 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator" // - []byte: The transformed request in Gemini CLI format. func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte { rawJSON := inputRawJSON - rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) - // Build output Gemini CLI request JSON out := []byte(`{"contents":[]}`) out, _ = sjson.SetBytes(out, "model", modelName) @@ -146,13 +144,37 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) }) } + // strip trailing model turn with unanswered function calls — + // Gemini returns empty responses when the last turn is a model + // functionCall with no corresponding user functionResponse. + contents := gjson.GetBytes(out, "contents") + if contents.Exists() && contents.IsArray() { + arr := contents.Array() + if len(arr) > 0 { + last := arr[len(arr)-1] + if last.Get("role").String() == "model" { + hasFC := false + last.Get("parts").ForEach(func(_, part gjson.Result) bool { + if part.Get("functionCall").Exists() { + hasFC = true + return false + } + return true + }) + if hasFC { + out, _ = sjson.DeleteBytes(out, fmt.Sprintf("contents.%d", len(arr)-1)) + } + } + } + } + // tools if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() { hasTools := false toolsResult.ForEach(func(_, toolResult gjson.Result) bool { inputSchemaResult := toolResult.Get("input_schema") if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { - inputSchema := inputSchemaResult.Raw + inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw) tool := []byte(toolResult.Raw) var err error tool, err = sjson.DeleteBytes(tool, "input_schema") @@ -168,6 +190,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) tool, _ = sjson.DeleteBytes(tool, "type") tool, _ = sjson.DeleteBytes(tool, "cache_control") tool, _ = sjson.DeleteBytes(tool, "defer_loading") + tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming") tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String())) if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() { if !hasTools {