mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-18 23:11:35 +00:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bf9b2c49df | ||
|
|
77b42c6165 | ||
|
|
446150a747 | ||
|
|
1cbc4834e1 | ||
|
|
a8a5d03c33 | ||
|
|
76aa917882 | ||
|
|
6ac9b31e4e | ||
|
|
0ad3e8457f | ||
|
|
444a47ae63 | ||
|
|
725f4fdff4 | ||
|
|
c23e46f45d | ||
|
|
b148820c35 | ||
|
|
134f41496d | ||
|
|
c5838dd58d | ||
|
|
b6ca5ef7ce | ||
|
|
1ae994b4aa | ||
|
|
84e9793e61 | ||
|
|
32e64dacfd | ||
|
|
cc1d8f6629 | ||
|
|
5446cd2b02 | ||
|
|
8de0885b7d | ||
|
|
16243f18fd | ||
|
|
a6ce5f36e6 | ||
|
|
e73cf42e28 | ||
|
|
b45343e812 | ||
|
|
8599b1560e | ||
|
|
8bde8c37c0 |
@@ -31,6 +31,7 @@ bin/*
|
||||
.agent/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -44,6 +44,7 @@ GEMINI.md
|
||||
.agents/*
|
||||
.agents/*
|
||||
.opencode/*
|
||||
.idea/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
_bmad-output/*
|
||||
|
||||
@@ -80,6 +80,10 @@ passthrough-headers: false
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
# Maximum number of different credentials to try for one failed request.
|
||||
# Set to 0 to keep legacy behavior (try all available credentials); negative values are treated as 0.
|
||||
max-retry-credentials: 0
|
||||
|
||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||
max-retry-interval: 30
|
||||
|
||||
|
||||
@@ -60,10 +60,8 @@ type ServerOption func(*serverOptionConfig)
|
||||
|
||||
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
||||
configDir := filepath.Dir(configPath)
|
||||
if base := util.WritablePath(); base != "" {
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir, cfg.ErrorLogsMaxFiles)
|
||||
}
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir, cfg.ErrorLogsMaxFiles)
|
||||
logsDir := logging.ResolveLogDirectory(cfg)
|
||||
return logging.NewFileRequestLogger(cfg.RequestLog, logsDir, configDir, cfg.ErrorLogsMaxFiles)
|
||||
}
|
||||
|
||||
// WithMiddleware appends additional Gin middleware during server construction.
|
||||
@@ -260,7 +258,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
s.applyAccessConfig(nil, cfg)
|
||||
if authManager != nil {
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||
}
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
@@ -946,7 +944,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
|
||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
|
||||
}
|
||||
|
||||
// Update log level dynamically when debug flag changes
|
||||
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gin "github.com/gin-gonic/gin"
|
||||
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
@@ -109,3 +111,100 @@ func TestAmpProviderModelRoutes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
|
||||
t.Setenv("WRITABLE_PATH", "")
|
||||
t.Setenv("writable_path", "")
|
||||
|
||||
originalWD, errGetwd := os.Getwd()
|
||||
if errGetwd != nil {
|
||||
t.Fatalf("failed to get current working directory: %v", errGetwd)
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
if errChdir := os.Chdir(tmpDir); errChdir != nil {
|
||||
t.Fatalf("failed to switch working directory: %v", errChdir)
|
||||
}
|
||||
defer func() {
|
||||
if errChdirBack := os.Chdir(originalWD); errChdirBack != nil {
|
||||
t.Fatalf("failed to restore working directory: %v", errChdirBack)
|
||||
}
|
||||
}()
|
||||
|
||||
// Force ResolveLogDirectory to fallback to auth-dir/logs by making ./logs not a writable directory.
|
||||
if errWriteFile := os.WriteFile(filepath.Join(tmpDir, "logs"), []byte("not-a-directory"), 0o644); errWriteFile != nil {
|
||||
t.Fatalf("failed to create blocking logs file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(tmpDir, "config")
|
||||
if errMkdirConfig := os.MkdirAll(configDir, 0o755); errMkdirConfig != nil {
|
||||
t.Fatalf("failed to create config dir: %v", errMkdirConfig)
|
||||
}
|
||||
configPath := filepath.Join(configDir, "config.yaml")
|
||||
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
|
||||
}
|
||||
|
||||
cfg := &proxyconfig.Config{
|
||||
SDKConfig: proxyconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
},
|
||||
AuthDir: authDir,
|
||||
ErrorLogsMaxFiles: 10,
|
||||
}
|
||||
|
||||
logger := defaultRequestLoggerFactory(cfg, configPath)
|
||||
fileLogger, ok := logger.(*internallogging.FileRequestLogger)
|
||||
if !ok {
|
||||
t.Fatalf("expected *FileRequestLogger, got %T", logger)
|
||||
}
|
||||
|
||||
errLog := fileLogger.LogRequestWithOptions(
|
||||
"/v1/chat/completions",
|
||||
http.MethodPost,
|
||||
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||
[]byte(`{"input":"hello"}`),
|
||||
http.StatusBadGateway,
|
||||
map[string][]string{"Content-Type": []string{"application/json"}},
|
||||
[]byte(`{"error":"upstream failure"}`),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
"issue-1711",
|
||||
time.Now(),
|
||||
time.Now(),
|
||||
)
|
||||
if errLog != nil {
|
||||
t.Fatalf("failed to write forced error request log: %v", errLog)
|
||||
}
|
||||
|
||||
authLogsDir := filepath.Join(authDir, "logs")
|
||||
authEntries, errReadAuthDir := os.ReadDir(authLogsDir)
|
||||
if errReadAuthDir != nil {
|
||||
t.Fatalf("failed to read auth logs dir %s: %v", authLogsDir, errReadAuthDir)
|
||||
}
|
||||
foundErrorLogInAuthDir := false
|
||||
for _, entry := range authEntries {
|
||||
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||
foundErrorLogInAuthDir = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundErrorLogInAuthDir {
|
||||
t.Fatalf("expected forced error log in auth fallback dir %s, got entries: %+v", authLogsDir, authEntries)
|
||||
}
|
||||
|
||||
configLogsDir := filepath.Join(configDir, "logs")
|
||||
configEntries, errReadConfigDir := os.ReadDir(configLogsDir)
|
||||
if errReadConfigDir != nil && !os.IsNotExist(errReadConfigDir) {
|
||||
t.Fatalf("failed to inspect config logs dir %s: %v", configLogsDir, errReadConfigDir)
|
||||
}
|
||||
for _, entry := range configEntries {
|
||||
if strings.HasPrefix(entry.Name(), "error-") && strings.HasSuffix(entry.Name(), ".log") {
|
||||
t.Fatalf("unexpected forced error log in config dir %s", configLogsDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// utlsRoundTripper implements http.RoundTripper using utls with Firefox fingerprint
|
||||
// utlsRoundTripper implements http.RoundTripper using utls with Chrome fingerprint
|
||||
// to bypass Cloudflare's TLS fingerprinting on Anthropic domains.
|
||||
type utlsRoundTripper struct {
|
||||
// mu protects the connections map and pending map
|
||||
@@ -100,7 +100,9 @@ func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.Clie
|
||||
return h2Conn, nil
|
||||
}
|
||||
|
||||
// createConnection creates a new HTTP/2 connection with Firefox TLS fingerprint
|
||||
// createConnection creates a new HTTP/2 connection with Chrome TLS fingerprint.
|
||||
// Chrome's TLS fingerprint is closer to Node.js/OpenSSL (which real Claude Code uses)
|
||||
// than Firefox, reducing the mismatch between TLS layer and HTTP headers.
|
||||
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
conn, err := t.dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
@@ -108,7 +110,7 @@ func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientCon
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{ServerName: host}
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloFirefox_Auto)
|
||||
tlsConn := tls.UClient(conn, tlsConfig, tls.HelloChrome_Auto)
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
@@ -156,7 +158,7 @@ func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// NewAnthropicHttpClient creates an HTTP client that bypasses TLS fingerprinting
|
||||
// for Anthropic domains by using utls with Firefox fingerprint.
|
||||
// for Anthropic domains by using utls with Chrome fingerprint.
|
||||
// It accepts optional SDK configuration for proxy settings.
|
||||
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
|
||||
return &http.Client{
|
||||
|
||||
@@ -69,6 +69,9 @@ type Config struct {
|
||||
|
||||
// RequestRetry defines the retry times when the request failed.
|
||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
|
||||
// Set to 0 or a negative value to keep trying all available credentials (legacy behavior).
|
||||
MaxRetryCredentials int `yaml:"max-retry-credentials" json:"max-retry-credentials"`
|
||||
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
|
||||
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
|
||||
|
||||
@@ -673,6 +676,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
}
|
||||
|
||||
if cfg.MaxRetryCredentials < 0 {
|
||||
cfg.MaxRetryCredentials = 0
|
||||
}
|
||||
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude.","cache_control":{"type":"ephemeral"}}]
|
||||
[{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK.","cache_control":{"type":"ephemeral","ttl":"1h"}}]
|
||||
@@ -959,22 +959,17 @@ type AntigravityModelConfig struct {
|
||||
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
return map[string]*AntigravityModelConfig{
|
||||
// "rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1152,7 +1152,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
continue
|
||||
}
|
||||
switch modelID {
|
||||
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
|
||||
case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
|
||||
continue
|
||||
}
|
||||
modelCfg := modelConfig[modelID]
|
||||
|
||||
@@ -6,9 +6,14 @@ import (
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -36,7 +41,9 @@ type ClaudeExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
const claudeToolPrefix = "proxy_"
|
||||
// claudeToolPrefix is empty to match real Claude Code behavior (no tool name prefix).
|
||||
// Previously "proxy_" was used but this is a detectable fingerprint difference.
|
||||
const claudeToolPrefix = ""
|
||||
|
||||
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
|
||||
|
||||
@@ -130,6 +137,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
body = ensureCacheControl(body)
|
||||
}
|
||||
|
||||
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
|
||||
// Cloaking and ensureCacheControl may push the total over 4 when the client
|
||||
// (e.g. Amp CLI) already sends multiple cache_control blocks.
|
||||
body = enforceCacheControlLimit(body, 4)
|
||||
|
||||
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
|
||||
// A 1h-TTL block must not appear after a 5m-TTL block in evaluation order (tools→system→messages).
|
||||
body = normalizeCacheControlTTL(body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
@@ -171,11 +187,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
|
||||
errBody := httpResp.Body
|
||||
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
|
||||
var decErr error
|
||||
errBody, decErr = decodeResponseBody(httpResp.Body, ce)
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
b = []byte(msg)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
if errClose := errBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
return resp, err
|
||||
@@ -271,6 +305,12 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
body = ensureCacheControl(body)
|
||||
}
|
||||
|
||||
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
|
||||
body = enforceCacheControlLimit(body, 4)
|
||||
|
||||
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
|
||||
body = normalizeCacheControlTTL(body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
@@ -312,10 +352,28 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
|
||||
errBody := httpResp.Body
|
||||
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
|
||||
var decErr error
|
||||
errBody, decErr = decodeResponseBody(httpResp.Body, ce)
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
b = []byte(msg)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
if errClose := errBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
@@ -420,6 +478,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
|
||||
// Keep count_tokens requests compatible with Anthropic cache-control constraints too.
|
||||
body = enforceCacheControlLimit(body, 4)
|
||||
body = normalizeCacheControlTTL(body)
|
||||
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
@@ -459,9 +521,27 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
|
||||
errBody := resp.Body
|
||||
if ce := resp.Header.Get("Content-Encoding"); ce != "" {
|
||||
var decErr error
|
||||
errBody, decErr = decodeResponseBody(resp.Body, ce)
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
|
||||
}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
b = []byte(msg)
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
if errClose := errBody.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)}
|
||||
@@ -696,23 +776,29 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
promptCachingBeta := "prompt-caching-2024-07-31"
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14," + promptCachingBeta
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
|
||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||
baseBetas = val
|
||||
if !strings.Contains(val, "oauth") {
|
||||
baseBetas += ",oauth-2025-04-20"
|
||||
}
|
||||
}
|
||||
if !strings.Contains(baseBetas, promptCachingBeta) {
|
||||
baseBetas += "," + promptCachingBeta
|
||||
|
||||
hasClaude1MHeader := false
|
||||
if ginHeaders != nil {
|
||||
if _, ok := ginHeaders[textproto.CanonicalMIMEHeaderKey("X-CPA-CLAUDE-1M")]; ok {
|
||||
hasClaude1MHeader = true
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra betas from request body
|
||||
if len(extraBetas) > 0 {
|
||||
// Merge extra betas from request body and request flags.
|
||||
if len(extraBetas) > 0 || hasClaude1MHeader {
|
||||
existingSet := make(map[string]bool)
|
||||
for _, b := range strings.Split(baseBetas, ",") {
|
||||
existingSet[strings.TrimSpace(b)] = true
|
||||
betaName := strings.TrimSpace(b)
|
||||
if betaName != "" {
|
||||
existingSet[betaName] = true
|
||||
}
|
||||
}
|
||||
for _, beta := range extraBetas {
|
||||
beta = strings.TrimSpace(beta)
|
||||
@@ -721,14 +807,16 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
existingSet[beta] = true
|
||||
}
|
||||
}
|
||||
if hasClaude1MHeader && !existingSet["context-1m-2025-08-07"] {
|
||||
baseBetas += ",context-1m-2025-08-07"
|
||||
}
|
||||
}
|
||||
r.Header.Set("Anthropic-Beta", baseBetas)
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
// Values below match Claude Code 2.1.44 / @anthropic-ai/sdk 0.74.0 (captured 2026-02-17).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Helper-Method", "stream")
|
||||
// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
|
||||
@@ -737,7 +825,18 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.44 (external, sdk-cli)"))
|
||||
// For User-Agent, only forward the client's header if it's already a Claude Code client.
|
||||
// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
|
||||
// to avoid leaking the real client identity during cloaking.
|
||||
clientUA := ""
|
||||
if ginHeaders != nil {
|
||||
clientUA = ginHeaders.Get("User-Agent")
|
||||
}
|
||||
if isClaudeCodeClient(clientUA) {
|
||||
r.Header.Set("User-Agent", clientUA)
|
||||
} else {
|
||||
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
|
||||
}
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
if stream {
|
||||
@@ -771,22 +870,7 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
}
|
||||
|
||||
func checkSystemInstructions(payload []byte) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
||||
if system.IsArray() {
|
||||
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
} else {
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
return payload
|
||||
return checkSystemInstructionsWithMode(payload, false)
|
||||
}
|
||||
|
||||
func isClaudeOAuthToken(apiKey string) bool {
|
||||
@@ -1060,33 +1144,73 @@ func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
||||
return payload
|
||||
}
|
||||
|
||||
// checkSystemInstructionsWithMode injects Claude Code system prompt.
|
||||
// In strict mode, it replaces all user system messages.
|
||||
// In non-strict mode (default), it prepends to existing system messages.
|
||||
// generateBillingHeader builds the text of the x-anthropic-billing-header block
// that is placed at the front of the system prompt array.
//
// Format:
//
//	x-anthropic-billing-header: cc_version=2.1.63.<build>; cc_entrypoint=cli; cch=<hash>;
//
// <hash> is the first 5 hex characters of the SHA-256 digest of payload, so it
// is stable for identical payloads. <build> is 3 hex characters taken from
// crypto/rand, so the returned string as a whole differs between calls.
func generateBillingHeader(payload []byte) string {
	digest := sha256.Sum256(payload)
	cch := hex.EncodeToString(digest[:])[:5]

	// Two random bytes give four hex characters; only the first three are used.
	var build [2]byte
	if _, err := rand.Read(build[:]); err != nil {
		// Best effort: if the system entropy source fails, derive the build
		// field from the tail of the digest instead of silently emitting the
		// constant "000" for every request.
		copy(build[:], digest[len(digest)-len(build):])
	}
	buildHash := hex.EncodeToString(build[:])[:3]

	return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch)
}
|
||||
|
||||
// checkSystemInstructionsWithMode injects Claude Code-style system blocks:
|
||||
//
|
||||
// system[0]: billing header (no cache_control)
|
||||
// system[1]: agent identifier (no cache_control)
|
||||
// system[2..]: user system messages (cache_control added when missing)
|
||||
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
||||
|
||||
billingText := generateBillingHeader(payload)
|
||||
billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText)
|
||||
// No cache_control on the agent block. It is a cloaking artifact with zero cache
|
||||
// value (the last system block is what actually triggers caching of all system content).
|
||||
// Including any cache_control here creates an intra-system TTL ordering violation
|
||||
// when the client's system blocks use ttl='1h' (prompt-caching-scope-2026-01-05 beta
|
||||
// forbids 1h blocks after 5m blocks, and a no-TTL block defaults to 5m).
|
||||
agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK."}`
|
||||
|
||||
if strictMode {
|
||||
// Strict mode: replace all system messages with Claude Code prompt only
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
// Strict mode: billing header + agent identifier only
|
||||
result := "[" + billingBlock + "," + agentBlock + "]"
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(result))
|
||||
return payload
|
||||
}
|
||||
|
||||
// Non-strict mode (default): prepend Claude Code prompt to existing system messages
|
||||
if system.IsArray() {
|
||||
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
} else {
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
// Non-strict mode: billing header + agent identifier + user system messages
|
||||
// Skip if already injected
|
||||
firstText := gjson.GetBytes(payload, "system.0.text").String()
|
||||
if strings.HasPrefix(firstText, "x-anthropic-billing-header:") {
|
||||
return payload
|
||||
}
|
||||
|
||||
result := "[" + billingBlock + "," + agentBlock
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
// Add cache_control to user system messages if not present.
|
||||
// Do NOT add ttl — let it inherit the default (5m) to avoid
|
||||
// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
|
||||
partJSON := part.Raw
|
||||
if !part.Get("cache_control").Exists() {
|
||||
partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
|
||||
}
|
||||
result += "," + partJSON
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
result += "]"
|
||||
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(result))
|
||||
return payload
|
||||
}
|
||||
|
||||
@@ -1224,6 +1348,313 @@ func countCacheControls(payload []byte) int {
|
||||
return count
|
||||
}
|
||||
|
||||
// parsePayloadObject decodes payload into a generic JSON object.
//
// It returns (nil, false) when payload is empty or when json.Unmarshal fails
// (for example, when the top-level value is an array or the input is not valid
// JSON). On success it returns the decoded map and true. Note that a literal
// JSON null decodes without error and therefore yields (nil, true).
func parsePayloadObject(payload []byte) (map[string]any, bool) {
	if len(payload) == 0 {
		return nil, false
	}
	var root map[string]any
	if err := json.Unmarshal(payload, &root); err != nil {
		return nil, false
	}
	return root, true
}
|
||||
|
||||
// marshalPayloadObject serializes root back to JSON.
//
// It is best-effort: when root is nil or json.Marshal fails, the original
// bytes are returned unchanged so the caller always gets a usable payload.
func marshalPayloadObject(original []byte, root map[string]any) []byte {
	if root == nil {
		return original
	}
	out, err := json.Marshal(root)
	if err != nil {
		return original
	}
	return out
}
|
||||
|
||||
// asObject reports whether v is a decoded JSON object (map[string]any) and,
// if so, returns it.
func asObject(v any) (map[string]any, bool) {
	obj, ok := v.(map[string]any)
	return obj, ok
}
|
||||
|
||||
// asArray reports whether v is a decoded JSON array ([]any) and, if so,
// returns it.
func asArray(v any) ([]any, bool) {
	arr, ok := v.([]any)
	return arr, ok
}
|
||||
|
||||
func countCacheControlsMap(root map[string]any) int {
|
||||
count := 0
|
||||
|
||||
if system, ok := asArray(root["system"]); ok {
|
||||
for _, item := range system {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tools, ok := asArray(root["tools"]); ok {
|
||||
for _, item := range tools {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messages, ok := asArray(root["messages"]); ok {
|
||||
for _, msg := range messages {
|
||||
msgObj, ok := asObject(msg)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := asArray(msgObj["content"])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if obj, ok := asObject(item); ok {
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) {
|
||||
ccRaw, exists := obj["cache_control"]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
cc, ok := asObject(ccRaw)
|
||||
if !ok {
|
||||
*seen5m = true
|
||||
return
|
||||
}
|
||||
ttlRaw, ttlExists := cc["ttl"]
|
||||
ttl, ttlIsString := ttlRaw.(string)
|
||||
if !ttlExists || !ttlIsString || ttl != "1h" {
|
||||
*seen5m = true
|
||||
return
|
||||
}
|
||||
if *seen5m {
|
||||
delete(cc, "ttl")
|
||||
}
|
||||
}
|
||||
|
||||
func findLastCacheControlIndex(arr []any) int {
|
||||
last := -1
|
||||
for idx, item := range arr {
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
last = idx
|
||||
}
|
||||
}
|
||||
return last
|
||||
}
|
||||
|
||||
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
|
||||
for idx, item := range arr {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists && idx != preserveIdx {
|
||||
delete(obj, "cache_control")
|
||||
*excess--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stripAllCacheControl(arr []any, excess *int) {
|
||||
for _, item := range arr {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
delete(obj, "cache_control")
|
||||
*excess--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stripMessageCacheControl(messages []any, excess *int) {
|
||||
for _, msg := range messages {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
msgObj, ok := asObject(msg)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := asArray(msgObj["content"])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if *excess <= 0 {
|
||||
return
|
||||
}
|
||||
obj, ok := asObject(item)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := obj["cache_control"]; exists {
|
||||
delete(obj, "cache_control")
|
||||
*excess--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
|
||||
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
|
||||
// appear after a 5m-TTL block anywhere in the evaluation order.
|
||||
//
|
||||
// Anthropic evaluates blocks in order: tools → system (index 0..N) → messages.
|
||||
// Within each section, blocks are evaluated in array order. A 5m (default) block
|
||||
// followed by a 1h block at ANY later position is an error — including within
|
||||
// the same section (e.g. system[1]=5m then system[3]=1h).
|
||||
//
|
||||
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
|
||||
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
|
||||
func normalizeCacheControlTTL(payload []byte) []byte {
|
||||
root, ok := parsePayloadObject(payload)
|
||||
if !ok {
|
||||
return payload
|
||||
}
|
||||
|
||||
seen5m := false
|
||||
|
||||
if tools, ok := asArray(root["tools"]); ok {
|
||||
for _, tool := range tools {
|
||||
if obj, ok := asObject(tool); ok {
|
||||
normalizeTTLForBlock(obj, &seen5m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if system, ok := asArray(root["system"]); ok {
|
||||
for _, item := range system {
|
||||
if obj, ok := asObject(item); ok {
|
||||
normalizeTTLForBlock(obj, &seen5m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messages, ok := asArray(root["messages"]); ok {
|
||||
for _, msg := range messages {
|
||||
msgObj, ok := asObject(msg)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := asArray(msgObj["content"])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, item := range content {
|
||||
if obj, ok := asObject(item); ok {
|
||||
normalizeTTLForBlock(obj, &seen5m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
// enforceCacheControlLimit removes excess cache_control blocks from a payload
|
||||
// so the total does not exceed the Anthropic API limit (currently 4).
|
||||
//
|
||||
// Anthropic evaluates cache breakpoints in order: tools → system → messages.
|
||||
// The most valuable breakpoints are:
|
||||
// 1. Last tool — caches ALL tool definitions
|
||||
// 2. Last system block — caches ALL system content
|
||||
// 3. Recent messages — cache conversation context
|
||||
//
|
||||
// Removal priority (strip lowest-value first):
|
||||
//
|
||||
// Phase 1: system blocks earliest-first, preserving the last one.
|
||||
// Phase 2: tool blocks earliest-first, preserving the last one.
|
||||
// Phase 3: message content blocks earliest-first.
|
||||
// Phase 4: remaining system blocks (last system).
|
||||
// Phase 5: remaining tool blocks (last tool).
|
||||
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
|
||||
root, ok := parsePayloadObject(payload)
|
||||
if !ok {
|
||||
return payload
|
||||
}
|
||||
|
||||
total := countCacheControlsMap(root)
|
||||
if total <= maxBlocks {
|
||||
return payload
|
||||
}
|
||||
|
||||
excess := total - maxBlocks
|
||||
|
||||
var system []any
|
||||
if arr, ok := asArray(root["system"]); ok {
|
||||
system = arr
|
||||
}
|
||||
var tools []any
|
||||
if arr, ok := asArray(root["tools"]); ok {
|
||||
tools = arr
|
||||
}
|
||||
var messages []any
|
||||
if arr, ok := asArray(root["messages"]); ok {
|
||||
messages = arr
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess)
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess)
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
if len(messages) > 0 {
|
||||
stripMessageCacheControl(messages, &excess)
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
stripAllCacheControl(system, &excess)
|
||||
}
|
||||
if excess <= 0 {
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
stripAllCacheControl(tools, &excess)
|
||||
}
|
||||
|
||||
return marshalPayloadObject(payload, root)
|
||||
}
|
||||
|
||||
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
|
||||
// Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache."
|
||||
// This enables caching of conversation history, which is especially beneficial for long multi-turn conversations.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -348,3 +349,237 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
|
||||
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],
|
||||
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
|
||||
}`)
|
||||
|
||||
out := normalizeCacheControlTTL(payload)
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.cache_control.ttl").String(); got != "1h" {
|
||||
t.Fatalf("tools.0.cache_control.ttl = %q, want %q", got, "1h")
|
||||
}
|
||||
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
|
||||
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
{"name":"t1","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t2","cache_control":{"type":"ephemeral"}}
|
||||
],
|
||||
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
|
||||
"messages": [
|
||||
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}}]},
|
||||
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := enforceCacheControlLimit(payload, 4)
|
||||
|
||||
if got := countCacheControls(out); got != 4 {
|
||||
t.Fatalf("cache_control count = %d, want 4", got)
|
||||
}
|
||||
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
|
||||
}
|
||||
if !gjson.GetBytes(out, "tools.1.cache_control").Exists() {
|
||||
t.Fatalf("tools.1.cache_control (last tool) should be preserved")
|
||||
}
|
||||
if !gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists() || !gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists() {
|
||||
t.Fatalf("message cache_control blocks should be preserved when non-last tool removal is enough")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
{"name":"t1","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t2","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t3","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t4","cache_control":{"type":"ephemeral"}},
|
||||
{"name":"t5","cache_control":{"type":"ephemeral"}}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := enforceCacheControlLimit(payload, 4)
|
||||
|
||||
if got := countCacheControls(out); got != 4 {
|
||||
t.Fatalf("cache_control count = %d, want 4", got)
|
||||
}
|
||||
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
|
||||
t.Fatalf("tools.0.cache_control should be removed to satisfy max=4")
|
||||
}
|
||||
if !gjson.GetBytes(out, "tools.4.cache_control").Exists() {
|
||||
t.Fatalf("last tool cache_control should be preserved when possible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_CountTokens_AppliesCacheControlGuards(t *testing.T) {
|
||||
var seenBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
seenBody = bytes.Clone(body)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"input_tokens":42}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
|
||||
payload := []byte(`{
|
||||
"tools": [
|
||||
{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}},
|
||||
{"name":"t2","cache_control":{"type":"ephemeral"}}
|
||||
],
|
||||
"system": [
|
||||
{"type":"text","text":"s1","cache_control":{"type":"ephemeral","ttl":"1h"}},
|
||||
{"type":"text","text":"s2","cache_control":{"type":"ephemeral","ttl":"1h"}}
|
||||
],
|
||||
"messages": [
|
||||
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
|
||||
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral","ttl":"1h"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-haiku-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens error: %v", err)
|
||||
}
|
||||
|
||||
if len(seenBody) == 0 {
|
||||
t.Fatal("expected count_tokens request body to be captured")
|
||||
}
|
||||
if got := countCacheControls(seenBody); got > 4 {
|
||||
t.Fatalf("count_tokens body has %d cache_control blocks, want <= 4", got)
|
||||
}
|
||||
if hasTTLOrderingViolation(seenBody) {
|
||||
t.Fatalf("count_tokens body still has ttl ordering violations: %s", string(seenBody))
|
||||
}
|
||||
}
|
||||
|
||||
func hasTTLOrderingViolation(payload []byte) bool {
|
||||
seen5m := false
|
||||
violates := false
|
||||
|
||||
checkCC := func(cc gjson.Result) {
|
||||
if !cc.Exists() || violates {
|
||||
return
|
||||
}
|
||||
ttl := cc.Get("ttl").String()
|
||||
if ttl != "1h" {
|
||||
seen5m = true
|
||||
return
|
||||
}
|
||||
if seen5m {
|
||||
violates = true
|
||||
}
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(payload, "tools")
|
||||
if tools.IsArray() {
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
checkCC(tool.Get("cache_control"))
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, item gjson.Result) bool {
|
||||
checkCC(item.Get("cache_control"))
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if messages.IsArray() {
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
content := msg.Get("content")
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, item gjson.Result) bool {
|
||||
checkCC(item.Get("cache_control"))
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
return !violates
|
||||
})
|
||||
}
|
||||
|
||||
return violates
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_Execute_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteStream_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||
_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_CountTokens_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
|
||||
testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error {
|
||||
_, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func testClaudeExecutorInvalidCompressedErrorBody(
|
||||
t *testing.T,
|
||||
invoke func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("not-a-valid-gzip-stream"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
err := invoke(executor, auth, payload)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to decode error response body") {
|
||||
t.Fatalf("expected decode failure message, got: %v", err)
|
||||
}
|
||||
if statusProvider, ok := err.(interface{ StatusCode() int }); !ok || statusProvider.StatusCode() != http.StatusBadRequest {
|
||||
t.Fatalf("expected status code 400, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// userIDPattern matches Claude Code format: user_[64-hex]_account_[uuid]_session_[uuid]
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}_session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||
|
||||
// generateFakeUserID generates a fake user ID in Claude Code format.
|
||||
// Format: user_[64-hex-chars]_account__session_[UUID-v4]
|
||||
// Format: user_[64-hex-chars]_account_[UUID-v4]_session_[UUID-v4]
|
||||
func generateFakeUserID() string {
|
||||
hexBytes := make([]byte, 32)
|
||||
_, _ = rand.Read(hexBytes)
|
||||
hexPart := hex.EncodeToString(hexBytes)
|
||||
uuidPart := uuid.New().String()
|
||||
return "user_" + hexPart + "_account__session_" + uuidPart
|
||||
accountUUID := uuid.New().String()
|
||||
sessionUUID := uuid.New().String()
|
||||
return "user_" + hexPart + "_account_" + accountUUID + "_session_" + sessionUUID
|
||||
}
|
||||
|
||||
// isValidUserID checks if a user ID matches Claude Code format.
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
|
||||
//
|
||||
// Kimi models use the OpenAI-compatible reasoning_effort format with discrete levels
|
||||
// (low/medium/high). The provider strips any existing thinking config and applies
|
||||
// the unified ThinkingConfig in OpenAI format.
|
||||
// Kimi models use the OpenAI-compatible reasoning_effort format for enabled thinking
|
||||
// levels, but use thinking.type=disabled when thinking is explicitly turned off.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
@@ -17,8 +16,8 @@ import (
|
||||
// Applier implements thinking.ProviderApplier for Kimi models.
|
||||
//
|
||||
// Kimi-specific behavior:
|
||||
// - Output format: reasoning_effort (string: low/medium/high)
|
||||
// - Uses OpenAI-compatible format
|
||||
// - Enabled thinking: reasoning_effort (string levels)
|
||||
// - Disabled thinking: thinking.type="disabled"
|
||||
// - Supports budget-to-level conversion
|
||||
type Applier struct{}
|
||||
|
||||
@@ -35,11 +34,19 @@ func init() {
|
||||
|
||||
// Apply applies thinking configuration to Kimi request body.
|
||||
//
|
||||
// Expected output format:
|
||||
// Expected output format (enabled):
|
||||
//
|
||||
// {
|
||||
// "reasoning_effort": "high"
|
||||
// }
|
||||
//
|
||||
// Expected output format (disabled):
|
||||
//
|
||||
// {
|
||||
// "thinking": {
|
||||
// "type": "disabled"
|
||||
// }
|
||||
// }
|
||||
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||
if thinking.IsUserDefinedModel(modelInfo) {
|
||||
return applyCompatibleKimi(body, config)
|
||||
@@ -60,8 +67,13 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
// Kimi uses "none" to disable thinking
|
||||
effort = string(thinking.LevelNone)
|
||||
// Respect clamped fallback level for models that cannot disable thinking.
|
||||
if config.Level != "" && config.Level != thinking.LevelNone {
|
||||
effort = string(config.Level)
|
||||
break
|
||||
}
|
||||
// Kimi requires explicit disabled thinking object.
|
||||
return applyDisabledThinking(body)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level using threshold mapping
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
@@ -79,12 +91,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
if effort == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
return applyReasoningEffort(body, effort)
|
||||
}
|
||||
|
||||
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
|
||||
@@ -101,7 +108,9 @@ func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, e
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
effort = string(thinking.LevelNone)
|
||||
if config.Level == "" || config.Level == thinking.LevelNone {
|
||||
return applyDisabledThinking(body)
|
||||
}
|
||||
if config.Level != "" {
|
||||
effort = string(config.Level)
|
||||
}
|
||||
@@ -118,9 +127,33 @@ func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, e
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
return applyReasoningEffort(body, effort)
|
||||
}
|
||||
|
||||
func applyReasoningEffort(body []byte, effort string) ([]byte, error) {
|
||||
result, errDeleteThinking := sjson.DeleteBytes(body, "thinking")
|
||||
if errDeleteThinking != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", errDeleteThinking)
|
||||
}
|
||||
result, errSetEffort := sjson.SetBytes(result, "reasoning_effort", effort)
|
||||
if errSetEffort != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", errSetEffort)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func applyDisabledThinking(body []byte) ([]byte, error) {
|
||||
result, errDeleteThinking := sjson.DeleteBytes(body, "thinking")
|
||||
if errDeleteThinking != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", errDeleteThinking)
|
||||
}
|
||||
result, errDeleteEffort := sjson.DeleteBytes(result, "reasoning_effort")
|
||||
if errDeleteEffort != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to clear reasoning_effort: %w", errDeleteEffort)
|
||||
}
|
||||
result, errSetType := sjson.SetBytes(result, "thinking.type", "disabled")
|
||||
if errSetType != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set thinking.type: %w", errSetType)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
72
internal/thinking/provider/kimi/apply_test.go
Normal file
72
internal/thinking/provider/kimi/apply_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApply_ModeNone_UsesDisabledThinking(t *testing.T) {
|
||||
applier := NewApplier()
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: "kimi-k2.5",
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
}
|
||||
body := []byte(`{"model":"kimi-k2.5","reasoning_effort":"none","thinking":{"type":"enabled","budget_tokens":2048}}`)
|
||||
|
||||
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
|
||||
if errApply != nil {
|
||||
t.Fatalf("Apply() error = %v", errApply)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
|
||||
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
|
||||
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "reasoning_effort").Exists() {
|
||||
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApply_ModeLevel_UsesReasoningEffort(t *testing.T) {
|
||||
applier := NewApplier()
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: "kimi-k2.5",
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
}
|
||||
body := []byte(`{"model":"kimi-k2.5","thinking":{"type":"disabled"}}`)
|
||||
|
||||
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeLevel, Level: thinking.LevelHigh}, modelInfo)
|
||||
if errApply != nil {
|
||||
t.Fatalf("Apply() error = %v", errApply)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "reasoning_effort").String(); got != "high" {
|
||||
t.Fatalf("reasoning_effort = %q, want %q, body=%s", got, "high", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "thinking").Exists() {
|
||||
t.Fatalf("thinking should be removed when reasoning_effort is used, body=%s", string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApply_UserDefinedModeNone_UsesDisabledThinking(t *testing.T) {
|
||||
applier := NewApplier()
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: "custom-kimi-model",
|
||||
UserDefined: true,
|
||||
}
|
||||
body := []byte(`{"model":"custom-kimi-model","reasoning_effort":"none"}`)
|
||||
|
||||
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
|
||||
if errApply != nil {
|
||||
t.Fatalf("Apply() error = %v", errApply)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
|
||||
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
|
||||
}
|
||||
if gjson.GetBytes(out, "reasoning_effort").Exists() {
|
||||
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,11 @@ func StripThinkingConfig(body []byte, provider string) []byte {
|
||||
paths = []string{"request.generationConfig.thinkingConfig"}
|
||||
case "openai":
|
||||
paths = []string{"reasoning_effort"}
|
||||
case "kimi":
|
||||
paths = []string{
|
||||
"reasoning_effort",
|
||||
"thinking",
|
||||
}
|
||||
case "codex":
|
||||
paths = []string{"reasoning.effort"}
|
||||
case "iflow":
|
||||
|
||||
@@ -400,7 +400,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
hasTools := toolDeclCount > 0
|
||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||
thinkingType := thinkingResult.Get("type").String()
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive")
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive" || thinkingType == "auto")
|
||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||
|
||||
if hasTools && hasThinking && isClaudeThinking {
|
||||
@@ -440,8 +440,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
case "adaptive", "auto":
|
||||
// Keep adaptive/auto as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
@@ -46,15 +46,23 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
if systemsResult.IsArray() {
|
||||
systemResults := systemsResult.Array()
|
||||
message := `{"type":"message","role":"developer","content":[]}`
|
||||
contentIndex := 0
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemResult := systemResults[i]
|
||||
systemTypeResult := systemResult.Get("type")
|
||||
if systemTypeResult.String() == "text" {
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", i), "input_text")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", i), systemResult.Get("text").String())
|
||||
text := systemResult.Get("text").String()
|
||||
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
||||
continue
|
||||
}
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||
contentIndex++
|
||||
}
|
||||
}
|
||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||
if contentIndex > 0 {
|
||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||
}
|
||||
}
|
||||
|
||||
// Process messages and transform their contents to appropriate formats.
|
||||
@@ -222,8 +230,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
case "adaptive", "auto":
|
||||
// Claude adaptive/auto means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
reasoningEffort = string(thinking.LevelXHigh)
|
||||
case "disabled":
|
||||
|
||||
@@ -180,8 +180,8 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
case "adaptive", "auto":
|
||||
// Keep adaptive/auto as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
@@ -161,8 +161,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
case "adaptive", "auto":
|
||||
// Keep adaptive/auto as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
@@ -75,8 +75,8 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
case "adaptive", "auto":
|
||||
// Claude adaptive/auto means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh))
|
||||
case "disabled":
|
||||
|
||||
@@ -127,7 +127,8 @@ func (w *Watcher) reloadConfig() bool {
|
||||
}
|
||||
|
||||
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
|
||||
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelAlias, newConfig.OAuthModelAlias))
|
||||
retryConfigChanged := oldConfig != nil && (oldConfig.RequestRetry != newConfig.RequestRetry || oldConfig.MaxRetryInterval != newConfig.MaxRetryInterval || oldConfig.MaxRetryCredentials != newConfig.MaxRetryCredentials)
|
||||
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelAlias, newConfig.OAuthModelAlias) || retryConfigChanged)
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
|
||||
|
||||
@@ -54,6 +54,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.RequestRetry != newCfg.RequestRetry {
|
||||
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
|
||||
}
|
||||
if oldCfg.MaxRetryCredentials != newCfg.MaxRetryCredentials {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-credentials: %d -> %d", oldCfg.MaxRetryCredentials, newCfg.MaxRetryCredentials))
|
||||
}
|
||||
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
|
||||
}
|
||||
|
||||
@@ -223,6 +223,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryCredentials: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
@@ -246,6 +247,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryCredentials: 3,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
@@ -283,6 +285,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
expectContains(t, details, "disable-cooling: false -> true")
|
||||
expectContains(t, details, "request-log: false -> true")
|
||||
expectContains(t, details, "request-retry: 1 -> 2")
|
||||
expectContains(t, details, "max-retry-credentials: 1 -> 3")
|
||||
expectContains(t, details, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, details, "ws-auth: false -> true")
|
||||
@@ -309,6 +312,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryCredentials: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
@@ -361,6 +365,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryCredentials: 3,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
@@ -419,6 +424,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
expectContains(t, changes, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, changes, "disable-cooling: false -> true")
|
||||
expectContains(t, changes, "request-retry: 1 -> 2")
|
||||
expectContains(t, changes, "max-retry-credentials: 1 -> 3")
|
||||
expectContains(t, changes, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, changes, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, changes, "ws-auth: false -> true")
|
||||
|
||||
@@ -1239,6 +1239,67 @@ func TestReloadConfigFiltersAffectedOAuthProviders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReloadConfigTriggersCallbackForMaxRetryCredentialsChange(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authDir := filepath.Join(tmpDir, "auth")
|
||||
if err := os.MkdirAll(authDir, 0o755); err != nil {
|
||||
t.Fatalf("failed to create auth dir: %v", err)
|
||||
}
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
oldCfg := &config.Config{
|
||||
AuthDir: authDir,
|
||||
MaxRetryCredentials: 0,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 5,
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
AuthDir: authDir,
|
||||
MaxRetryCredentials: 2,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 5,
|
||||
}
|
||||
data, errMarshal := yaml.Marshal(newCfg)
|
||||
if errMarshal != nil {
|
||||
t.Fatalf("failed to marshal config: %v", errMarshal)
|
||||
}
|
||||
if errWrite := os.WriteFile(configPath, data, 0o644); errWrite != nil {
|
||||
t.Fatalf("failed to write config: %v", errWrite)
|
||||
}
|
||||
|
||||
callbackCalls := 0
|
||||
callbackMaxRetryCredentials := -1
|
||||
w := &Watcher{
|
||||
configPath: configPath,
|
||||
authDir: authDir,
|
||||
lastAuthHashes: make(map[string]string),
|
||||
reloadCallback: func(cfg *config.Config) {
|
||||
callbackCalls++
|
||||
if cfg != nil {
|
||||
callbackMaxRetryCredentials = cfg.MaxRetryCredentials
|
||||
}
|
||||
},
|
||||
}
|
||||
w.SetConfig(oldCfg)
|
||||
|
||||
if ok := w.reloadConfig(); !ok {
|
||||
t.Fatal("expected reloadConfig to succeed")
|
||||
}
|
||||
|
||||
if callbackCalls != 1 {
|
||||
t.Fatalf("expected reload callback to be called once, got %d", callbackCalls)
|
||||
}
|
||||
if callbackMaxRetryCredentials != 2 {
|
||||
t.Fatalf("expected callback MaxRetryCredentials=2, got %d", callbackMaxRetryCredentials)
|
||||
}
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
if w.config == nil || w.config.MaxRetryCredentials != 2 {
|
||||
t.Fatalf("expected watcher config MaxRetryCredentials=2, got %+v", w.config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartFailsWhenAuthDirMissing(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
@@ -138,8 +138,9 @@ type Manager struct {
|
||||
providerOffsets map[string]int
|
||||
|
||||
// Retry controls request retry behavior.
|
||||
requestRetry atomic.Int32
|
||||
maxRetryInterval atomic.Int64
|
||||
requestRetry atomic.Int32
|
||||
maxRetryCredentials atomic.Int32
|
||||
maxRetryInterval atomic.Int64
|
||||
|
||||
// oauthModelAlias stores global OAuth model alias mappings (alias -> upstream name) keyed by channel.
|
||||
oauthModelAlias atomic.Value
|
||||
@@ -384,18 +385,22 @@ func compileAPIKeyModelAliasForModels[T interface {
|
||||
}
|
||||
}
|
||||
|
||||
// SetRetryConfig updates retry attempts and cooldown wait interval.
|
||||
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration) {
|
||||
// SetRetryConfig updates retry attempts, credential retry limit and cooldown wait interval.
|
||||
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration, maxRetryCredentials int) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if retry < 0 {
|
||||
retry = 0
|
||||
}
|
||||
if maxRetryCredentials < 0 {
|
||||
maxRetryCredentials = 0
|
||||
}
|
||||
if maxRetryInterval < 0 {
|
||||
maxRetryInterval = 0
|
||||
}
|
||||
m.requestRetry.Store(int32(retry))
|
||||
m.maxRetryCredentials.Store(int32(maxRetryCredentials))
|
||||
m.maxRetryInterval.Store(maxRetryInterval.Nanoseconds())
|
||||
}
|
||||
|
||||
@@ -506,11 +511,11 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
|
||||
_, maxWait := m.retrySettings()
|
||||
_, maxRetryCredentials, maxWait := m.retrySettings()
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; ; attempt++ {
|
||||
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
||||
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts, maxRetryCredentials)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
@@ -537,11 +542,11 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
|
||||
_, maxWait := m.retrySettings()
|
||||
_, maxRetryCredentials, maxWait := m.retrySettings()
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; ; attempt++ {
|
||||
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
||||
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts, maxRetryCredentials)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
@@ -568,11 +573,11 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
|
||||
_, maxWait := m.retrySettings()
|
||||
_, maxRetryCredentials, maxWait := m.retrySettings()
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; ; attempt++ {
|
||||
result, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||
result, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts, maxRetryCredentials)
|
||||
if errStream == nil {
|
||||
return result, nil
|
||||
}
|
||||
@@ -591,7 +596,7 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
@@ -600,6 +605,12 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
@@ -647,7 +658,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
@@ -656,6 +667,12 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
@@ -703,7 +720,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (*cliproxyexecutor.StreamResult, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
@@ -712,6 +729,12 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
@@ -1108,11 +1131,11 @@ func (m *Manager) normalizeProviders(providers []string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *Manager) retrySettings() (int, time.Duration) {
|
||||
func (m *Manager) retrySettings() (int, int, time.Duration) {
|
||||
if m == nil {
|
||||
return 0, 0
|
||||
return 0, 0, 0
|
||||
}
|
||||
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
|
||||
return int(m.requestRetry.Load()), int(m.maxRetryCredentials.Load()), time.Duration(m.maxRetryInterval.Load())
|
||||
}
|
||||
|
||||
func (m *Manager) closestCooldownWait(providers []string, model string, attempt int) (time.Duration, bool) {
|
||||
|
||||
@@ -2,13 +2,17 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testing.T) {
|
||||
m := NewManager(nil, nil, nil)
|
||||
m.SetRetryConfig(3, 30*time.Second)
|
||||
m.SetRetryConfig(3, 30*time.Second, 0)
|
||||
|
||||
model := "test-model"
|
||||
next := time.Now().Add(5 * time.Second)
|
||||
@@ -31,7 +35,7 @@ func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testi
|
||||
t.Fatalf("register auth: %v", errRegister)
|
||||
}
|
||||
|
||||
_, maxWait := m.retrySettings()
|
||||
_, _, maxWait := m.retrySettings()
|
||||
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
|
||||
if shouldRetry {
|
||||
t.Fatalf("expected shouldRetry=false for request_retry=0, got true (wait=%v)", wait)
|
||||
@@ -56,6 +60,124 @@ func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
type credentialRetryLimitExecutor struct {
|
||||
id string
|
||||
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
}
|
||||
|
||||
func (e *credentialRetryLimitExecutor) Identifier() string {
|
||||
return e.id
|
||||
}
|
||||
|
||||
func (e *credentialRetryLimitExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
e.recordCall()
|
||||
return cliproxyexecutor.Response{}, &Error{HTTPStatus: 500, Message: "boom"}
|
||||
}
|
||||
|
||||
func (e *credentialRetryLimitExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
e.recordCall()
|
||||
return nil, &Error{HTTPStatus: 500, Message: "boom"}
|
||||
}
|
||||
|
||||
func (e *credentialRetryLimitExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *credentialRetryLimitExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
e.recordCall()
|
||||
return cliproxyexecutor.Response{}, &Error{HTTPStatus: 500, Message: "boom"}
|
||||
}
|
||||
|
||||
func (e *credentialRetryLimitExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e *credentialRetryLimitExecutor) recordCall() {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.calls++
|
||||
}
|
||||
|
||||
func (e *credentialRetryLimitExecutor) Calls() int {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.calls
|
||||
}
|
||||
|
||||
func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) (*Manager, *credentialRetryLimitExecutor) {
|
||||
t.Helper()
|
||||
|
||||
m := NewManager(nil, nil, nil)
|
||||
m.SetRetryConfig(0, 0, maxRetryCredentials)
|
||||
|
||||
executor := &credentialRetryLimitExecutor{id: "claude"}
|
||||
m.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &Auth{ID: "auth-1", Provider: "claude"}
|
||||
auth2 := &Auth{ID: "auth-2", Provider: "claude"}
|
||||
if _, errRegister := m.Register(context.Background(), auth1); errRegister != nil {
|
||||
t.Fatalf("register auth1: %v", errRegister)
|
||||
}
|
||||
if _, errRegister := m.Register(context.Background(), auth2); errRegister != nil {
|
||||
t.Fatalf("register auth2: %v", errRegister)
|
||||
}
|
||||
|
||||
return m, executor
|
||||
}
|
||||
|
||||
func TestManager_MaxRetryCredentials_LimitsCrossCredentialRetries(t *testing.T) {
|
||||
request := cliproxyexecutor.Request{Model: "test-model"}
|
||||
testCases := []struct {
|
||||
name string
|
||||
invoke func(*Manager) error
|
||||
}{
|
||||
{
|
||||
name: "execute",
|
||||
invoke: func(m *Manager) error {
|
||||
_, errExecute := m.Execute(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
|
||||
return errExecute
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "execute_count",
|
||||
invoke: func(m *Manager) error {
|
||||
_, errExecute := m.ExecuteCount(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
|
||||
return errExecute
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "execute_stream",
|
||||
invoke: func(m *Manager) error {
|
||||
_, errExecute := m.ExecuteStream(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
|
||||
return errExecute
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
limitedManager, limitedExecutor := newCredentialRetryLimitTestManager(t, 1)
|
||||
if errInvoke := tc.invoke(limitedManager); errInvoke == nil {
|
||||
t.Fatalf("expected error for limited retry execution")
|
||||
}
|
||||
if calls := limitedExecutor.Calls(); calls != 1 {
|
||||
t.Fatalf("expected 1 call with max-retry-credentials=1, got %d", calls)
|
||||
}
|
||||
|
||||
unlimitedManager, unlimitedExecutor := newCredentialRetryLimitTestManager(t, 0)
|
||||
if errInvoke := tc.invoke(unlimitedManager); errInvoke == nil {
|
||||
t.Fatalf("expected error for unlimited retry execution")
|
||||
}
|
||||
if calls := unlimitedExecutor.Calls(); calls != 2 {
|
||||
t.Fatalf("expected 2 calls with max-retry-credentials=0, got %d", calls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
|
||||
prev := quotaCooldownDisabled.Load()
|
||||
quotaCooldownDisabled.Store(false)
|
||||
|
||||
@@ -347,7 +347,7 @@ func (s *Service) applyRetryConfig(cfg *config.Config) {
|
||||
return
|
||||
}
|
||||
maxInterval := time.Duration(cfg.MaxRetryInterval) * time.Second
|
||||
s.coreManager.SetRetryConfig(cfg.RequestRetry, maxInterval)
|
||||
s.coreManager.SetRetryConfig(cfg.RequestRetry, maxInterval, cfg.MaxRetryCredentials)
|
||||
}
|
||||
|
||||
func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName string, ok bool) {
|
||||
|
||||
Reference in New Issue
Block a user