From f033d3a6df8ebd94a8c4d73ff7e0641b92f45164 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sun, 29 Mar 2026 13:00:43 +0800 Subject: [PATCH] fix(claude): enhance ensureModelMaxTokens to use registered max_completion_tokens and fallback to default --- internal/runtime/executor/claude_executor.go | 46 ++++++----- .../runtime/executor/claude_executor_test.go | 78 +++++++++++++++++++ 2 files changed, 103 insertions(+), 21 deletions(-) diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index bc5d2065..cc88dd77 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -45,33 +45,14 @@ type ClaudeExecutor struct { // Previously "proxy_" was used but this is a detectable fingerprint difference. const claudeToolPrefix = "" -// Anthropic-compatible upstreams may reject or even crash when dynamically -// registered Claude models omit max_tokens. Use a conservative default. +// Anthropic-compatible upstreams may reject or even crash when Claude models +// omit max_tokens. Prefer registered model metadata before using a fallback. const defaultModelMaxTokens = 1024 func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} } func (e *ClaudeExecutor) Identifier() string { return "claude" } -func ensureModelMaxTokens(body []byte, modelID string) []byte { - if len(body) == 0 || !gjson.ValidBytes(body) { - return body - } - - if maxTokens := gjson.GetBytes(body, "max_tokens"); maxTokens.Exists() { - return body - } - - for _, provider := range registry.GetGlobalRegistry().GetModelProviders(strings.TrimSpace(modelID)) { - if strings.EqualFold(provider, "claude") { - body, _ = sjson.SetBytes(body, "max_tokens", defaultModelMaxTokens) - return body - } - } - - return body -} - // PrepareRequest injects Claude credentials into the outgoing HTTP request. func (e *ClaudeExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { if req == nil { @@ -1906,3 +1887,26 @@ func injectSystemCacheControl(payload []byte) []byte { return payload } + +func ensureModelMaxTokens(body []byte, modelID string) []byte { + if len(body) == 0 || !gjson.ValidBytes(body) { + return body + } + + if maxTokens := gjson.GetBytes(body, "max_tokens"); maxTokens.Exists() { + return body + } + + for _, provider := range registry.GetGlobalRegistry().GetModelProviders(strings.TrimSpace(modelID)) { + if strings.EqualFold(provider, "claude") { + maxTokens := defaultModelMaxTokens + if info := registry.GetGlobalRegistry().GetModelInfo(strings.TrimSpace(modelID), "claude"); info != nil && info.MaxCompletionTokens > 0 { + maxTokens = info.MaxCompletionTokens + } + body, _ = sjson.SetBytes(body, "max_tokens", maxTokens) + return body + } + } + + return body +} diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index c163d7ea..ee8e9025 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -15,6 +15,7 @@ import ( "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -1183,6 +1184,83 @@ func testClaudeExecutorInvalidCompressedErrorBody( } } +func TestEnsureModelMaxTokens_UsesRegisteredMaxCompletionTokens(t *testing.T) { + reg := registry.GetGlobalRegistry() + clientID := "test-claude-max-completion-tokens-client" + modelID := "test-claude-max-completion-tokens-model" + reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ + ID: modelID, + Type: "claude", + OwnedBy: "anthropic", + Object: "model", + Created: time.Now().Unix(), + MaxCompletionTokens: 4096, + UserDefined: true, + }}) + defer reg.UnregisterClient(clientID) + + input := []byte(`{"model":"test-claude-max-completion-tokens-model","messages":[{"role":"user","content":"hi"}]}`) + out := ensureModelMaxTokens(input, modelID) + + if got := gjson.GetBytes(out, "max_tokens").Int(); got != 4096 { + t.Fatalf("max_tokens = %d, want %d", got, 4096) + } +} + +func TestEnsureModelMaxTokens_DefaultsMissingValue(t *testing.T) { + reg := registry.GetGlobalRegistry() + clientID := "test-claude-default-max-tokens-client" + modelID := "test-claude-default-max-tokens-model" + reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ + ID: modelID, + Type: "claude", + OwnedBy: "anthropic", + Object: "model", + Created: time.Now().Unix(), + UserDefined: true, + }}) + defer reg.UnregisterClient(clientID) + + input := []byte(`{"model":"test-claude-default-max-tokens-model","messages":[{"role":"user","content":"hi"}]}`) + out := ensureModelMaxTokens(input, modelID) + + if got := gjson.GetBytes(out, "max_tokens").Int(); got != defaultModelMaxTokens { + t.Fatalf("max_tokens = %d, want %d", got, defaultModelMaxTokens) + } +} + +func TestEnsureModelMaxTokens_PreservesExplicitValue(t *testing.T) { + reg := registry.GetGlobalRegistry() + clientID := "test-claude-preserve-max-tokens-client" + modelID := "test-claude-preserve-max-tokens-model" + reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ + ID: modelID, + Type: "claude", + OwnedBy: "anthropic", + Object: "model", + Created: time.Now().Unix(), + MaxCompletionTokens: 4096, + UserDefined: true, + }}) + defer reg.UnregisterClient(clientID) + + input := []byte(`{"model":"test-claude-preserve-max-tokens-model","max_tokens":2048,"messages":[{"role":"user","content":"hi"}]}`) + out := ensureModelMaxTokens(input, modelID) + + if got := gjson.GetBytes(out, "max_tokens").Int(); got != 2048 { + t.Fatalf("max_tokens = %d, want %d", got, 2048) + } +} + +func TestEnsureModelMaxTokens_SkipsUnregisteredModel(t *testing.T) { + input := []byte(`{"model":"test-claude-unregistered-model","messages":[{"role":"user","content":"hi"}]}`) + out := ensureModelMaxTokens(input, "test-claude-unregistered-model") + + if gjson.GetBytes(out, "max_tokens").Exists() { + t.Fatalf("max_tokens should remain unset, got %s", gjson.GetBytes(out, "max_tokens").Raw) + } +} + // TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming // requests use Accept-Encoding: identity so the upstream cannot respond with a // compressed SSE body that would silently break the line scanner.