Compare commits

...

55 Commits

Author SHA1 Message Date
Luis Pater
4607356333 Merge pull request #491 from Ve-ria/main
修复 CodeBuddy 不支持非流式请求的问题
2026-04-07 18:25:21 +08:00
Luis Pater
9a9ed99072 Merge pull request #494 from router-for-me/plus
v6.9.16
2026-04-07 18:23:51 +08:00
Luis Pater
5ae38584b8 Merge branch 'main' into plus 2026-04-07 18:23:31 +08:00
Luis Pater
c8b7e2b8d6 fix(executor): ensure empty stream completions use output_item.done as fallback
Fixed: #2583
2026-04-07 18:21:12 +08:00
Luis Pater
cad45ffa33 Merge pull request #2578 from LemonZuo/feat_socks5h
feat: support socks5h scheme for proxy settings
2026-04-07 09:57:18 +08:00
Luis Pater
6a27bceec0 Merge pull request #2576 from zilianpn/fix/disable-cooling-auth-errors
fix(auth): honor disable-cooling and enrich no-auth errors
2026-04-07 09:56:25 +08:00
Lemon
163d68318f feat: support socks5h scheme for proxy settings 2026-04-07 07:46:11 +08:00
zilianpn
0ea768011b fix(auth): honor disable-cooling and enrich no-auth errors 2026-04-07 01:12:13 +08:00
rensumo
341b4beea1 Update internal/runtime/executor/codebuddy_executor.go
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-04-06 14:16:56 +08:00
rensumo
bea13f9724 fix(executor): support non-stream requests for CodeBuddy 2026-04-06 13:59:06 +08:00
Luis Pater
9f5bdfaa31 Merge pull request #2531 from jamestut/openai-vertex-token-usage-fix
Fix missing `response.completed.usage` for late-usage OpenAI-compatible streams
2026-04-06 09:30:49 +08:00
Luis Pater
9eabdd09db Merge pull request #2522 from aikins01/fix/strip-tool-use-signature
fix(amp): strip signature from tool_use blocks before forwarding to Claude
2026-04-06 09:30:14 +08:00
Luis Pater
c3f8dc362e Merge pull request #2491 from mpfo0106/feature/claude-code-safe-alignment-sentinels
test(claude): add compatibility sentinels and centralize builtin fallback handling
2026-04-06 09:27:08 +08:00
Luis Pater
b85120873b Merge pull request #2332 from RaviTharuma/fix/claude-thinking-signature
fix: preserve Claude thinking signatures in Codex translator
2026-04-06 09:25:06 +08:00
Luis Pater
6f58518c69 docs(readme): remove redundant GITSTORE_GIT_BRANCH description in README files 2026-04-06 09:23:04 +08:00
Luis Pater
000fcb15fa Merge pull request #2298 from snoyiatk/feat/add-gitstore-branch
feat(gitstore): add support for specifying git branch (via GITSTORE_G…
2026-04-06 09:21:03 +08:00
Luis Pater
ea43361492 Merge pull request #2121 from destinoantagonista-wq/main
Reconcile registry model states on auth changes
2026-04-06 09:13:27 +08:00
Luis Pater
c1818f197b Merge pull request #1940 from Blue-B/fix/claude-interleaved-thinking-amp-gzip-budget
fix(claude): enable interleaved-thinking beta, decode AMP error gzip, fix budget 400
2026-04-06 09:08:23 +08:00
Aikins Laryea
b0653cec7b fix(amp): strip signature from tool_use blocks before forwarding to Claude
ensureAmpSignature injects signature:"" into tool_use blocks so the
Amp TUI does not crash on P.signature.length. when Amp sends the
conversation back, Claude rejects the extra field with 400:
  tool_use.signature: Extra inputs are not permitted

strip the proxy-injected signature from tool_use blocks in
SanitizeAmpRequestBody before forwarding to the upstream API.
2026-04-05 12:26:24 +00:00
Luis Pater
22a1a24cf5 feat(executor): add tests for preserving key order in cache control functions
Added comprehensive tests to ensure key order is maintained when modifying payloads in `normalizeCacheControlTTL` and `enforceCacheControlLimit` functions. Removed unused helper functions and refactored implementations for better readability and efficiency.
2026-04-05 17:58:13 +08:00
Luis Pater
7223fee2de Merge branch 'pr-488'
# Conflicts:
#	README.md
#	README_CN.md
#	README_JA.md
2026-04-05 02:08:45 +08:00
Luis Pater
ada8e2905e feat(api): enhance proxy resolution for API key-based auth
Added comprehensive support for resolving proxy URLs from configuration based on API key and provider attributes. Introduced new helper functions and extended the test suite to validate fallback mechanisms and compatibility cases.
2026-04-05 01:56:34 +08:00
Luis Pater
4ba10531da feat(docs): add Poixe AI sponsorship details to README files
Added Poixe AI sponsorship information, including referral bonuses and platform capabilities, to README files in English, Japanese, and Chinese. Updated assets to include Poixe AI logo.
2026-04-05 01:20:50 +08:00
Luis Pater
3774b56e9f feat(misc): add background updater for Antigravity version caching
Introduce `StartAntigravityVersionUpdater` to periodically refresh the cached Antigravity version using a non-blocking background process. Updated main server flow to initialize the updater.
2026-04-04 22:09:11 +08:00
Luis Pater
c2d4137fb9 feat(executor): enhance Qwen system message handling with strict injection and merging rules
Closes: #2537
2026-04-04 21:51:02 +08:00
Luis Pater
2ee938acaf Merge pull request #2535 from rensumo/main
feat: 动态获取 Antigravity User-Agent 版本号
2026-04-04 21:00:47 +08:00
rensumo
8d5e470e1f feat: dynamically fetch antigravity UA version from releases API
Fetch the latest version from the antigravity auto-updater releases
endpoint and cache it for 6 hours. Falls back to 1.21.9 if the API
is unreachable or returns unexpected data.
2026-04-04 14:52:59 +08:00
James
65e9e892a4 Fix missing response.completed.usage for late-usage OpenAI-compatible streams 2026-04-04 05:58:04 +00:00
Luis Pater
3882494878 Merge pull request #486 from router-for-me/plus
v6.9.14
2026-04-04 11:40:13 +08:00
Luis Pater
088c1d07f4 Merge branch 'main' into plus 2026-04-04 11:40:03 +08:00
Luis Pater
8430b28cfa Merge pull request #2526 from rensumo/main
feat: 升级反重力 (antigravity) UA 版本为 1.21.9
2026-04-04 11:32:16 +08:00
rensumo
f3ab8f4bc5 chore: update antigravity UA version to 1.21.9 2026-04-04 07:35:08 +08:00
Luis Pater
0e4f189c2e Merge pull request #1302 from dinhkarate/feat(vertex)/add-prefix-field
Feat(vertex): add prefix field
2026-04-04 04:17:12 +08:00
Luis Pater
98509f615c Merge pull request #485 from kunish/fix/copilot-premium-request-inflation
fix(copilot): reduce premium request inflation, enable thinking, and use dynamic API limits
2026-04-04 02:19:56 +08:00
kunish
87bf0b73d5 fix(copilot): use dynamic API limits to prevent prompt token overflow
The Copilot API enforces per-account prompt token limits (128K individual,
168K business) that differ from the static 200K context length advertised
by the proxy. This mismatch caused Claude Code to accumulate context
beyond the actual limit, triggering "prompt token count exceeds the limit
of 128000" errors.

Changes:
- Extract max_prompt_tokens and max_output_tokens from the Copilot
  /models API response (capabilities.limits) and use them as the
  authoritative ContextLength and MaxCompletionTokens values
- Add CopilotModelLimits struct and Limits() helper to parse limits
  from the existing Capabilities map
- Fix GitLab Duo context-1m beta header not being set when routing
  through the Anthropic gateway (gitlab_duo_force_context_1m attr
  was set but only gin headers were checked)
- Fix flaky parallel tests that shared global model registry state
2026-04-03 23:54:17 +08:00
kunish
b849bf79d6 fix(copilot): address code review — SSE reasoning, multi-choice, agent detection
- Strip SSE `data:` prefix before normalizing reasoning_text→reasoning_content
  in streaming mode; re-wrap afterward for the translator
- Iterate all choices in normalizeGitHubCopilotReasoningField (not just
  choices[0]) to support n>1 requests
- Remove over-broad tool-role fallback in isAgentInitiated that scanned
  all messages for role:"tool", aligning with opencode's approach of only
  detecting active tool loops — genuine user follow-ups after tool use are
  no longer mis-classified as agent-initiated
- Add 5 reasoning normalization tests; update 2 X-Initiator tests to match
  refined semantics
2026-04-03 20:51:19 +08:00
kunish
59af2c57b1 fix(copilot): reduce premium request inflation and enable thinking
This commit addresses three issues with Claude Code through GitHub
Copilot:

1. **Premium request inflation**: Responses API requests were missing
   Openai-Intent headers and proper defaults, causing Copilot to bill
   each tool-loop continuation as a new premium request. Fixed by adding
   isAgentInitiated() heuristic (checks for tool_result content or
   preceding assistant tool_use), applying Responses API defaults
   (store, include, reasoning.summary), and local tiktoken-based token
   counting to avoid extra API calls.

2. **Context overflow**: Claude Code's modelSupports1M() hardcodes
   opus-4-6 as 1M-capable, but Copilot only supports ~128K-200K.
   Fixed by stripping the context-1m-2025-08-07 beta from translated
   request bodies. Also forwards response headers in non-streaming
   Execute() and registers the GET /copilot-quota management API route.

3. **Thinking not working**: Add ThinkingSupport with level-based
   reasoning to Claude models in the static definitions. Normalize
   Copilot's non-standard 'reasoning_text' response field to
   'reasoning_content' before passing to the SDK translator. Use
   caller-provided context in CountTokens instead of Background().
2026-04-03 20:24:30 +08:00
mpfo0106
9b5ce8c64f Keep Claude builtin helpers aligned with the shared helper layout
The review asked for the builtin tool registry helper to live with the rest
of executor support utilities. This moves the registry code into the helps
package, exports the minimal surface executor needs, and keeps behavior tests
with the executor while leaving registry-focused checks with the helper.

Constraint: Requested layout keeps executor helper utilities centralized under internal/runtime/executor/helps
Rejected: Keep the files in executor and reply with rationale | conflicts with requested package organization
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Keep executor behavior tests near applyClaudeToolPrefix and keep pure registry tests in helps
Tested: go test ./internal/runtime/executor/helps ./internal/runtime/executor -run 'Claude|Builtin|Tool'; go test ./test/...; go test ./...
Not-tested: End-to-end Claude Code direct-connect/session runtime behavior
2026-04-03 00:13:02 +09:00
Duong M. CUONG
058793c73a feat(gitstore): honor configured branch and follow live remote default 2026-04-02 14:44:44 +00:00
mpfo0106
da3a498a28 Keep Claude Code compatibility work low-risk and reviewable
This change stops short of broader Claude Code runtime alignment and instead
hardens two safe edges: builtin tool prefix handling and source-informed
sentinel coverage for future drift checks.

Constraint: Must preserve existing default behavior for current users
Rejected: Implement control-plane/session alignment now | too much runtime risk for a first slice
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Treat the new fixtures as compatibility sentinels, not a full Claude Code schema contract
Tested: go test ./test/...; go test ./sdk/translator/...; go test ./internal/runtime/executor -run 'Claude|Builtin|Tool'; go test ./...
Not-tested: End-to-end Claude Code direct-connect/session runtime behavior
2026-04-02 20:35:39 +09:00
Ravi Tharuma
5fc2bd393e fix: retain codex thinking signature until item done 2026-03-28 14:41:25 +01:00
Ravi Tharuma
66eb12294a fix: clear stale thinking signature when no block is open 2026-03-28 14:08:31 +01:00
Ravi Tharuma
73b22ec29b fix: omit empty signature field from thinking blocks
Emit signature only when non-empty in both streaming content_block_start
and non-streaming thinking blocks. Avoids turning 'missing signature'
into 'empty/invalid signature' which Claude clients may reject.
2026-03-28 14:08:31 +01:00
Ravi Tharuma
c31ae2f3b5 fix: retain previously captured thinking signature on new summary part 2026-03-28 14:08:31 +01:00
Ravi Tharuma
76b53d6b5b fix: finalize pending thinking block before next summary part 2026-03-28 14:08:31 +01:00
Ravi Tharuma
a34dfed378 fix: preserve Claude thinking signatures in Codex translator 2026-03-28 14:08:31 +01:00
dinhkarate
36efcc6e28 fix(vertex): include prefix in auth filename and validate at import
Address two blocking issues from PR review:
- Auth file now named vertex-{prefix}-{project}.json so importing the
  same project with different prefixes no longer overwrites credentials
- Prefix containing "/" is rejected at import time instead of being
  silently ignored at runtime
- Add prefix to in-memory metadata map for consistency

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 15:06:04 +07:00
Pham Quang Dinh
a337ecf35c Merge branch 'router-for-me:main' into feat(vertex)/add-prefix-field 2026-03-17 11:48:40 +07:00
destinoantagonista-wq
e08f68ed7c chore(auth): drop reconcile test file from pr 2026-03-14 14:41:26 +00:00
destinoantagonista-wq
f09ed25fd3 fix(auth): tighten registry model reconciliation 2026-03-14 14:40:06 +00:00
destinoantagonista-wq
e166e56249 Reconcile registry model states on auth changes
Add Manager.ReconcileRegistryModelStates to clear stale per-model runtime failures for models currently registered in the global model registry. The method finds models supported for an auth, resets non-clean ModelState entries, updates aggregated availability, persists changes, and pushes a snapshot to the scheduler. Introduce modelStateIsClean helper to determine when a model state needs resetting. Call ReconcileRegistryModelStates from Service paths that register/refresh models (applyCoreAuthAddOrUpdate and refreshModelRegistrationForAuth) to keep the scheduler and global registry aligned after model re-registration.
2026-03-13 19:41:49 +00:00
Blue-B
5f58248016 fix(claude): clamp max_tokens to model limit in normalizeClaudeBudget
When adjustedBudget < minBudget, the previous fix blindly set
max_tokens = budgetTokens+1 which could exceed MaxCompletionTokens.

Now: cap max_tokens at MaxCompletionTokens, recalculate budget, and
disable thinking entirely if constraints are unsatisfiable.

Add unit tests covering raise, clamp, disable, and no-op scenarios.
2026-03-09 22:10:30 +09:00
Blue-B
07d6689d87 fix(claude): add interleaved-thinking beta header, AMP gzip error decoding, normalizeClaudeBudget max_tokens
1. Always include interleaved-thinking-2025-05-14 beta header so that
   thinking blocks are returned correctly for all Claude models.

2. Remove status-code guard in AMP reverse proxy ModifyResponse so that
   error responses (4xx/5xx) with hidden gzip encoding are decoded
   properly — prevents garbled error messages reaching the client.

3. In normalizeClaudeBudget, when the adjusted budget falls below the
   model minimum, set max_tokens = budgetTokens+1 instead of leaving
   the request unchanged (which causes a 400 from the API).
2026-03-07 21:31:10 +09:00
dinhkarate
14cb2b95c6 feat(vertex): add --vertex-import-prefix flag for model namespacing 2026-01-29 13:32:38 +07:00
dinhkarate
fdeef48498 feat(vertex): Add Prefix field to VertexCredentialStorage for per-file model namespacing 2026-01-29 13:32:38 +07:00
49 changed files with 4401 additions and 576 deletions

6
.gitignore vendored
View File

@@ -54,4 +54,10 @@ _bmad-output/*
# macOS
.DS_Store
._*
# Opencode
.beads/
.opencode/
.cli-proxy-api/
.venv/
*.bak

View File

@@ -26,6 +26,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
@@ -188,7 +189,7 @@ func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
httpReq.Close = true
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")
httpReq.Header.Set("User-Agent", misc.AntigravityUserAgent())
httpClient := &http.Client{Timeout: 30 * time.Second}
if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {

View File

@@ -99,6 +99,7 @@ func main() {
var codeBuddyLogin bool
var projectID string
var vertexImport string
var vertexImportPrefix string
var configPath string
var password string
var tuiMode bool
@@ -139,6 +140,7 @@ func main() {
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
@@ -188,6 +190,7 @@ func main() {
gitStoreRemoteURL string
gitStoreUser string
gitStorePassword string
gitStoreBranch string
gitStoreLocalPath string
gitStoreInst *store.GitTokenStore
gitStoreRoot string
@@ -257,6 +260,9 @@ func main() {
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
gitStoreLocalPath = value
}
if value, ok := lookupEnv("GITSTORE_GIT_BRANCH", "gitstore_git_branch"); ok {
gitStoreBranch = value
}
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
useObjectStore = true
objectStoreEndpoint = value
@@ -391,7 +397,7 @@ func main() {
}
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
authDir := filepath.Join(gitStoreRoot, "auths")
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword, gitStoreBranch)
gitStoreInst.SetBaseDir(authDir)
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
log.Errorf("failed to prepare git token store: %v", errRepo)
@@ -510,7 +516,7 @@ func main() {
if vertexImport != "" {
// Handle Vertex service account import
cmd.DoVertexImport(cfg, vertexImport)
cmd.DoVertexImport(cfg, vertexImport, vertexImportPrefix)
} else if login {
// Handle Google/Gemini login
cmd.DoLogin(cfg, projectID, options)
@@ -596,6 +602,7 @@ func main() {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
misc.StartAntigravityVersionUpdater(context.Background())
if !localModel {
registry.StartModelsUpdater(context.Background())
}
@@ -671,6 +678,7 @@ func main() {
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
misc.StartAntigravityVersionUpdater(context.Background())
if !localModel {
registry.StartModelsUpdater(context.Background())
}

View File

@@ -92,6 +92,9 @@ max-retry-credentials: 0
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
max-retry-interval: 30
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
disable-cooling: false
# Quota exceeded behavior
quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded

View File

@@ -13,6 +13,7 @@ import (
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
@@ -700,6 +701,11 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(proxyURLFromAPIKeyConfig(h.cfg, auth)); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
@@ -722,6 +728,123 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
return clone
}
type apiKeyConfigEntry interface {
GetAPIKey() string
GetBaseURL() string
}
func resolveAPIKeyConfig[T apiKeyConfigEntry](entries []T, auth *coreauth.Auth) *T {
if auth == nil || len(entries) == 0 {
return nil
}
attrKey, attrBase := "", ""
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range entries {
entry := &entries[i]
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range entries {
entry := &entries[i]
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
return entry
}
}
}
return nil
}
func proxyURLFromAPIKeyConfig(cfg *config.Config, auth *coreauth.Auth) string {
if cfg == nil || auth == nil {
return ""
}
authKind, authAccount := auth.AccountInfo()
if !strings.EqualFold(strings.TrimSpace(authKind), "api_key") {
return ""
}
attrs := auth.Attributes
compatName := ""
providerKey := ""
if len(attrs) > 0 {
compatName = strings.TrimSpace(attrs["compat_name"])
providerKey = strings.TrimSpace(attrs["provider_key"])
}
if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
return resolveOpenAICompatAPIKeyProxyURL(cfg, auth, strings.TrimSpace(authAccount), providerKey, compatName)
}
switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
case "gemini":
if entry := resolveAPIKeyConfig(cfg.GeminiKey, auth); entry != nil {
return strings.TrimSpace(entry.ProxyURL)
}
case "claude":
if entry := resolveAPIKeyConfig(cfg.ClaudeKey, auth); entry != nil {
return strings.TrimSpace(entry.ProxyURL)
}
case "codex":
if entry := resolveAPIKeyConfig(cfg.CodexKey, auth); entry != nil {
return strings.TrimSpace(entry.ProxyURL)
}
}
return ""
}
func resolveOpenAICompatAPIKeyProxyURL(cfg *config.Config, auth *coreauth.Auth, apiKey, providerKey, compatName string) string {
if cfg == nil || auth == nil {
return ""
}
apiKey = strings.TrimSpace(apiKey)
if apiKey == "" {
return ""
}
candidates := make([]string, 0, 3)
if v := strings.TrimSpace(compatName); v != "" {
candidates = append(candidates, v)
}
if v := strings.TrimSpace(providerKey); v != "" {
candidates = append(candidates, v)
}
if v := strings.TrimSpace(auth.Provider); v != "" {
candidates = append(candidates, v)
}
for i := range cfg.OpenAICompatibility {
compat := &cfg.OpenAICompatibility[i]
for _, candidate := range candidates {
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
for j := range compat.APIKeyEntries {
entry := &compat.APIKeyEntries[j]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), apiKey) {
return strings.TrimSpace(entry.ProxyURL)
}
}
return ""
}
}
}
return ""
}
func buildProxyTransport(proxyStr string) *http.Transport {
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if errBuild != nil {

View File

@@ -58,6 +58,105 @@ func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
}
}
func TestAPICallTransportAPIKeyAuthFallsBackToConfigProxyURL(t *testing.T) {
t.Parallel()
h := &Handler{
cfg: &config.Config{
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
GeminiKey: []config.GeminiKey{{
APIKey: "gemini-key",
ProxyURL: "http://gemini-proxy.example.com:8080",
}},
ClaudeKey: []config.ClaudeKey{{
APIKey: "claude-key",
ProxyURL: "http://claude-proxy.example.com:8080",
}},
CodexKey: []config.CodexKey{{
APIKey: "codex-key",
ProxyURL: "http://codex-proxy.example.com:8080",
}},
OpenAICompatibility: []config.OpenAICompatibility{{
Name: "bohe",
BaseURL: "https://bohe.example.com",
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{
APIKey: "compat-key",
ProxyURL: "http://compat-proxy.example.com:8080",
}},
}},
},
}
cases := []struct {
name string
auth *coreauth.Auth
wantProxy string
}{
{
name: "gemini",
auth: &coreauth.Auth{
Provider: "gemini",
Attributes: map[string]string{"api_key": "gemini-key"},
},
wantProxy: "http://gemini-proxy.example.com:8080",
},
{
name: "claude",
auth: &coreauth.Auth{
Provider: "claude",
Attributes: map[string]string{"api_key": "claude-key"},
},
wantProxy: "http://claude-proxy.example.com:8080",
},
{
name: "codex",
auth: &coreauth.Auth{
Provider: "codex",
Attributes: map[string]string{"api_key": "codex-key"},
},
wantProxy: "http://codex-proxy.example.com:8080",
},
{
name: "openai-compatibility",
auth: &coreauth.Auth{
Provider: "bohe",
Attributes: map[string]string{
"api_key": "compat-key",
"compat_name": "bohe",
"provider_key": "bohe",
},
},
wantProxy: "http://compat-proxy.example.com:8080",
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
transport := h.apiCallTransport(tc.auth)
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
if errRequest != nil {
t.Fatalf("http.NewRequest returned error: %v", errRequest)
}
proxyURL, errProxy := httpTransport.Proxy(req)
if errProxy != nil {
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
}
if proxyURL == nil || proxyURL.String() != tc.wantProxy {
t.Fatalf("proxy URL = %v, want %s", proxyURL, tc.wantProxy)
}
})
}
}
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
t.Parallel()

View File

@@ -2,6 +2,7 @@ package amp
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"strings"
@@ -298,8 +299,10 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
}
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
// from the messages array in a request body before forwarding to the upstream API.
// This prevents 400 errors from the API which requires valid signatures on thinking blocks.
// and strips the proxy-injected "signature" field from tool_use blocks in the messages
// array before forwarding to the upstream API.
// This prevents 400 errors from the API which requires valid signatures on thinking
// blocks and does not accept a signature field on tool_use blocks.
func SanitizeAmpRequestBody(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
@@ -317,21 +320,30 @@ func SanitizeAmpRequestBody(body []byte) []byte {
}
var keepBlocks []interface{}
removedCount := 0
contentModified := false
for _, block := range content.Array() {
blockType := block.Get("type").String()
if blockType == "thinking" {
sig := block.Get("signature")
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
removedCount++
contentModified = true
continue
}
}
keepBlocks = append(keepBlocks, block.Value())
// Use raw JSON to prevent float64 rounding of large integers in tool_use inputs
blockRaw := []byte(block.Raw)
if blockType == "tool_use" && block.Get("signature").Exists() {
blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature")
contentModified = true
}
// sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw
keepBlocks = append(keepBlocks, json.RawMessage(blockRaw))
}
if removedCount > 0 {
if contentModified {
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
var err error
if len(keepBlocks) == 0 {
@@ -340,11 +352,10 @@ func SanitizeAmpRequestBody(body []byte) []byte {
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
}
if err != nil {
log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err)
log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err)
continue
}
modified = true
log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx)
}
}

View File

@@ -145,6 +145,36 @@ func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testi
}
}
func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) {
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
result := SanitizeAmpRequestBody(input)
if contains(result, []byte(`"signature":""`)) {
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
}
if !contains(result, []byte(`"valid-sig"`)) {
t.Fatalf("expected thinking signature to remain, got %s", string(result))
}
if !contains(result, []byte(`"tool_use"`)) {
t.Fatalf("expected tool_use block to remain, got %s", string(result))
}
}
func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) {
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
result := SanitizeAmpRequestBody(input)
if contains(result, []byte("drop-me")) {
t.Fatalf("expected invalid thinking block to be removed, got %s", string(result))
}
if contains(result, []byte(`"signature"`)) {
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
}
if !contains(result, []byte(`"tool_use"`)) {
t.Fatalf("expected tool_use block to remain, got %s", string(result))
}
}
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {

View File

@@ -573,6 +573,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
mgmt.GET("/copilot-quota", s.mgmt.GetCopilotQuota)
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)

View File

@@ -235,6 +235,74 @@ type CopilotModelEntry struct {
Capabilities map[string]any `json:"capabilities,omitempty"`
}
// CopilotModelLimits holds the token limits returned by the Copilot /models API
// under capabilities.limits. These limits vary by account type (individual vs
// business) and are the authoritative source for enforcing prompt size.
type CopilotModelLimits struct {
// MaxContextWindowTokens is the total context window (prompt + output).
MaxContextWindowTokens int
// MaxPromptTokens is the hard limit on input/prompt tokens.
// Exceeding this triggers a 400 error from the Copilot API.
MaxPromptTokens int
// MaxOutputTokens is the maximum number of output/completion tokens.
MaxOutputTokens int
}
// Limits extracts the token limits from the model's capabilities map.
// Returns nil if no limits are available or the structure is unexpected.
//
// Expected Copilot API shape:
//
// "capabilities": {
// "limits": {
// "max_context_window_tokens": 200000,
// "max_prompt_tokens": 168000,
// "max_output_tokens": 32000
// }
// }
func (e *CopilotModelEntry) Limits() *CopilotModelLimits {
if e.Capabilities == nil {
return nil
}
limitsRaw, ok := e.Capabilities["limits"]
if !ok {
return nil
}
limitsMap, ok := limitsRaw.(map[string]any)
if !ok {
return nil
}
result := &CopilotModelLimits{
MaxContextWindowTokens: anyToInt(limitsMap["max_context_window_tokens"]),
MaxPromptTokens: anyToInt(limitsMap["max_prompt_tokens"]),
MaxOutputTokens: anyToInt(limitsMap["max_output_tokens"]),
}
// Only return if at least one field is populated.
if result.MaxContextWindowTokens == 0 && result.MaxPromptTokens == 0 && result.MaxOutputTokens == 0 {
return nil
}
return result
}
// anyToInt converts a JSON-decoded numeric value to int.
// Go's encoding/json decodes numbers into float64 when the target is any/interface{}.
func anyToInt(v any) int {
switch n := v.(type) {
case float64:
return int(n)
case float32:
return int(n)
case int:
return n
case int64:
return int(n)
default:
return 0
}
}
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
type CopilotModelsResponse struct {
Data []CopilotModelEntry `json:"data"`

View File

@@ -30,6 +30,10 @@ type VertexCredentialStorage struct {
// Type is the provider identifier stored alongside credentials. Always "vertex".
Type string `json:"type"`
// Prefix optionally namespaces models for this credential (e.g., "teamA").
// This results in model names like "teamA/gemini-2.0-flash".
Prefix string `json:"prefix,omitempty"`
}
// SaveTokenToFile writes the credential payload to the given file path in JSON format.

View File

@@ -20,7 +20,7 @@ import (
// DoVertexImport imports a Google Cloud service account key JSON and persists
// it as a "vertex" provider credential. The file content is embedded in the auth
// file to allow portable deployment across stores.
func DoVertexImport(cfg *config.Config, keyPath string) {
func DoVertexImport(cfg *config.Config, keyPath string, prefix string) {
if cfg == nil {
cfg = &config.Config{}
}
@@ -62,13 +62,28 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
// Default location if not provided by user. Can be edited in the saved file later.
location := "us-central1"
fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
// Normalize and validate prefix: must be a single segment (no "/" allowed).
prefix = strings.TrimSpace(prefix)
prefix = strings.Trim(prefix, "/")
if prefix != "" && strings.Contains(prefix, "/") {
log.Errorf("vertex-import: prefix must be a single segment (no '/' allowed): %q", prefix)
return
}
// Include prefix in filename so importing the same project with different
// prefixes creates separate credential files instead of overwriting.
baseName := sanitizeFilePart(projectID)
if prefix != "" {
baseName = sanitizeFilePart(prefix) + "-" + baseName
}
fileName := fmt.Sprintf("vertex-%s.json", baseName)
// Build auth record
storage := &vertex.VertexCredentialStorage{
ServiceAccount: sa,
ProjectID: projectID,
Email: email,
Location: location,
Prefix: prefix,
}
metadata := map[string]any{
"service_account": sa,
@@ -76,6 +91,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
"email": email,
"location": location,
"type": "vertex",
"prefix": prefix,
"label": labelForVertex(projectID, email),
}
record := &coreauth.Auth{

View File

@@ -0,0 +1,151 @@
// Package misc provides miscellaneous utility functions for the CLI Proxy API server.
package misc
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
antigravityReleasesURL = "https://antigravity-auto-updater-974169037036.us-central1.run.app/releases"
antigravityFallbackVersion = "1.21.9"
antigravityVersionCacheTTL = 6 * time.Hour
antigravityFetchTimeout = 10 * time.Second
)
type antigravityRelease struct {
Version string `json:"version"`
ExecutionID string `json:"execution_id"`
}
var (
cachedAntigravityVersion = antigravityFallbackVersion
antigravityVersionMu sync.RWMutex
antigravityVersionExpiry time.Time
antigravityUpdaterOnce sync.Once
)
// StartAntigravityVersionUpdater starts a background goroutine that periodically refreshes the cached antigravity version.
// This is intentionally decoupled from request execution to avoid blocking executors on version lookups.
func StartAntigravityVersionUpdater(ctx context.Context) {
antigravityUpdaterOnce.Do(func() {
go runAntigravityVersionUpdater(ctx)
})
}
func runAntigravityVersionUpdater(ctx context.Context) {
if ctx == nil {
ctx = context.Background()
}
ticker := time.NewTicker(antigravityVersionCacheTTL / 2)
defer ticker.Stop()
log.Infof("periodic antigravity version refresh started (interval=%s)", antigravityVersionCacheTTL/2)
refreshAntigravityVersion(ctx)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
refreshAntigravityVersion(ctx)
}
}
}
func refreshAntigravityVersion(ctx context.Context) {
version, errFetch := fetchAntigravityLatestVersion(ctx)
antigravityVersionMu.Lock()
defer antigravityVersionMu.Unlock()
now := time.Now()
if errFetch == nil {
cachedAntigravityVersion = version
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
log.WithField("version", version).Info("fetched latest antigravity version")
return
}
if cachedAntigravityVersion == "" || now.After(antigravityVersionExpiry) {
cachedAntigravityVersion = antigravityFallbackVersion
antigravityVersionExpiry = now.Add(antigravityVersionCacheTTL)
log.WithError(errFetch).Warn("failed to refresh antigravity version, using fallback version")
return
}
log.WithError(errFetch).Debug("failed to refresh antigravity version, keeping cached value")
}
// AntigravityLatestVersion returns the cached antigravity version refreshed by StartAntigravityVersionUpdater.
// It falls back to antigravityFallbackVersion if the cache is empty or stale.
func AntigravityLatestVersion() string {
antigravityVersionMu.RLock()
if cachedAntigravityVersion != "" && time.Now().Before(antigravityVersionExpiry) {
v := cachedAntigravityVersion
antigravityVersionMu.RUnlock()
return v
}
antigravityVersionMu.RUnlock()
return antigravityFallbackVersion
}
// AntigravityUserAgent returns the User-Agent string for antigravity requests
// using the latest version fetched from the releases API.
func AntigravityUserAgent() string {
return fmt.Sprintf("antigravity/%s darwin/arm64", AntigravityLatestVersion())
}
func fetchAntigravityLatestVersion(ctx context.Context) (string, error) {
if ctx == nil {
ctx = context.Background()
}
client := &http.Client{Timeout: antigravityFetchTimeout}
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodGet, antigravityReleasesURL, nil)
if errReq != nil {
return "", fmt.Errorf("build antigravity releases request: %w", errReq)
}
resp, errDo := client.Do(httpReq)
if errDo != nil {
return "", fmt.Errorf("fetch antigravity releases: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.WithError(errClose).Warn("antigravity releases response body close error")
}
}()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("antigravity releases API returned status %d", resp.StatusCode)
}
var releases []antigravityRelease
if errDecode := json.NewDecoder(resp.Body).Decode(&releases); errDecode != nil {
return "", fmt.Errorf("decode antigravity releases response: %w", errDecode)
}
if len(releases) == 0 {
return "", errors.New("antigravity releases API returned empty list")
}
version := releases[0].Version
if version == "" {
return "", errors.New("antigravity releases API returned empty version")
}
return version, nil
}

View File

@@ -549,6 +549,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "claude-opus-4.6",
@@ -561,6 +562,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "claude-sonnet-4",
@@ -573,6 +575,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "claude-sonnet-4.5",
@@ -585,6 +588,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "claude-sonnet-4.6",
@@ -597,6 +601,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gemini-2.5-pro",

View File

@@ -24,6 +24,7 @@ import (
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -45,7 +46,7 @@ const (
antigravityGeneratePath = "/v1internal:generateContent"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
defaultAntigravityAgent = "antigravity/1.21.9 darwin/arm64" // fallback only; overridden at runtime by misc.AntigravityUserAgent()
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
antigravityCreditsRetryTTL = 5 * time.Hour
@@ -1739,7 +1740,7 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
}
}
}
return defaultAntigravityAgent
return misc.AntigravityUserAgent()
}
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {

View File

@@ -8,7 +8,6 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -841,6 +840,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
baseBetas += ",oauth-2025-04-20"
}
}
if !strings.Contains(baseBetas, "interleaved-thinking") {
baseBetas += ",interleaved-thinking-2025-05-14"
}
hasClaude1MHeader := false
if ginHeaders != nil {
@@ -848,6 +850,14 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
hasClaude1MHeader = true
}
}
// Also check auth attributes — GitLab Duo sets gitlab_duo_force_context_1m
// when routing through the Anthropic gateway, but the gin headers won't have
// X-CPA-CLAUDE-1M because the request is internally constructed.
if !hasClaude1MHeader && auth != nil && auth.Attributes != nil {
if auth.Attributes["gitlab_duo_force_context_1m"] == "true" {
hasClaude1MHeader = true
}
}
// Merge extra betas from request body and request flags.
if len(extraBetas) > 0 || hasClaude1MHeader {
@@ -949,12 +959,9 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return body
}
// Collect built-in tool names (those with a non-empty "type" field) so we can
// skip them consistently in both tools and message history.
builtinTools := map[string]bool{}
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
builtinTools[name] = true
}
// Collect built-in tool names from the authoritative fallback seed list and
// augment it with any typed built-ins present in the current request body.
builtinTools := helps.AugmentClaudeBuiltinToolRegistry(body, nil)
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
@@ -1463,182 +1470,6 @@ func countCacheControls(payload []byte) int {
return count
}
func parsePayloadObject(payload []byte) (map[string]any, bool) {
if len(payload) == 0 {
return nil, false
}
var root map[string]any
if err := json.Unmarshal(payload, &root); err != nil {
return nil, false
}
return root, true
}
func marshalPayloadObject(original []byte, root map[string]any) []byte {
if root == nil {
return original
}
out, err := json.Marshal(root)
if err != nil {
return original
}
return out
}
func asObject(v any) (map[string]any, bool) {
obj, ok := v.(map[string]any)
return obj, ok
}
func asArray(v any) ([]any, bool) {
arr, ok := v.([]any)
return arr, ok
}
func countCacheControlsMap(root map[string]any) int {
count := 0
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
if tools, ok := asArray(root["tools"]); ok {
for _, item := range tools {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
if messages, ok := asArray(root["messages"]); ok {
for _, msg := range messages {
msgObj, ok := asObject(msg)
if !ok {
continue
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
}
return count
}
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
ccRaw, exists := obj["cache_control"]
if !exists {
return false
}
cc, ok := asObject(ccRaw)
if !ok {
*seen5m = true
return false
}
ttlRaw, ttlExists := cc["ttl"]
ttl, ttlIsString := ttlRaw.(string)
if !ttlExists || !ttlIsString || ttl != "1h" {
*seen5m = true
return false
}
if *seen5m {
delete(cc, "ttl")
return true
}
return false
}
func findLastCacheControlIndex(arr []any) int {
last := -1
for idx, item := range arr {
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
last = idx
}
}
return last
}
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
for idx, item := range arr {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists && idx != preserveIdx {
delete(obj, "cache_control")
*excess--
}
}
}
func stripAllCacheControl(arr []any, excess *int) {
for _, item := range arr {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
delete(obj, "cache_control")
*excess--
}
}
}
func stripMessageCacheControl(messages []any, excess *int) {
for _, msg := range messages {
if *excess <= 0 {
return
}
msgObj, ok := asObject(msg)
if !ok {
continue
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
delete(obj, "cache_control")
*excess--
}
}
}
}
// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
// appear after a 5m-TTL block anywhere in the evaluation order.
@@ -1651,58 +1482,75 @@ func stripMessageCacheControl(messages []any, excess *int) {
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
func normalizeCacheControlTTL(payload []byte) []byte {
root, ok := parsePayloadObject(payload)
if !ok {
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return payload
}
original := payload
seen5m := false
modified := false
if tools, ok := asArray(root["tools"]); ok {
for _, tool := range tools {
if obj, ok := asObject(tool); ok {
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
processBlock := func(path string, obj gjson.Result) {
cc := obj.Get("cache_control")
if !cc.Exists() {
return
}
if !cc.IsObject() {
seen5m = true
return
}
ttl := cc.Get("ttl")
if ttl.Type != gjson.String || ttl.String() != "1h" {
seen5m = true
return
}
if !seen5m {
return
}
ttlPath := path + ".cache_control.ttl"
updated, errDel := sjson.DeleteBytes(payload, ttlPath)
if errDel != nil {
return
}
payload = updated
modified = true
}
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
tools.ForEach(func(idx, item gjson.Result) bool {
processBlock(fmt.Sprintf("tools.%d", int(idx.Int())), item)
return true
})
}
if messages, ok := asArray(root["messages"]); ok {
for _, msg := range messages {
msgObj, ok := asObject(msg)
if !ok {
continue
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
system.ForEach(func(idx, item gjson.Result) bool {
processBlock(fmt.Sprintf("system.%d", int(idx.Int())), item)
return true
})
}
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
messages.ForEach(func(msgIdx, msg gjson.Result) bool {
content := msg.Get("content")
if !content.IsArray() {
return true
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if obj, ok := asObject(item); ok {
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
content.ForEach(func(itemIdx, item gjson.Result) bool {
processBlock(fmt.Sprintf("messages.%d.content.%d", int(msgIdx.Int()), int(itemIdx.Int())), item)
return true
})
return true
})
}
if !modified {
return payload
return original
}
return marshalPayloadObject(payload, root)
return payload
}
// enforceCacheControlLimit removes excess cache_control blocks from a payload
@@ -1722,64 +1570,166 @@ func normalizeCacheControlTTL(payload []byte) []byte {
// Phase 4: remaining system blocks (last system).
// Phase 5: remaining tool blocks (last tool).
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
root, ok := parsePayloadObject(payload)
if !ok {
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return payload
}
total := countCacheControlsMap(root)
total := countCacheControls(payload)
if total <= maxBlocks {
return payload
}
excess := total - maxBlocks
var system []any
if arr, ok := asArray(root["system"]); ok {
system = arr
}
var tools []any
if arr, ok := asArray(root["tools"]); ok {
tools = arr
}
var messages []any
if arr, ok := asArray(root["messages"]); ok {
messages = arr
}
if len(system) > 0 {
stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess)
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
lastIdx := -1
system.ForEach(func(idx, item gjson.Result) bool {
if item.Get("cache_control").Exists() {
lastIdx = int(idx.Int())
}
return true
})
if lastIdx >= 0 {
system.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
i := int(idx.Int())
if i == lastIdx {
return true
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("system.%d.cache_control", i)
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(tools) > 0 {
stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess)
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
lastIdx := -1
tools.ForEach(func(idx, item gjson.Result) bool {
if item.Get("cache_control").Exists() {
lastIdx = int(idx.Int())
}
return true
})
if lastIdx >= 0 {
tools.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
i := int(idx.Int())
if i == lastIdx {
return true
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("tools.%d.cache_control", i)
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(messages) > 0 {
stripMessageCacheControl(messages, &excess)
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
messages.ForEach(func(msgIdx, msg gjson.Result) bool {
if excess <= 0 {
return false
}
content := msg.Get("content")
if !content.IsArray() {
return true
}
content.ForEach(func(itemIdx, item gjson.Result) bool {
if excess <= 0 {
return false
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.cache_control", int(msgIdx.Int()), int(itemIdx.Int()))
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
return true
})
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(system) > 0 {
stripAllCacheControl(system, &excess)
system = gjson.GetBytes(payload, "system")
if system.IsArray() {
system.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("system.%d.cache_control", int(idx.Int()))
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(tools) > 0 {
stripAllCacheControl(tools, &excess)
tools = gjson.GetBytes(payload, "tools")
if tools.IsArray() {
tools.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("tools.%d.cache_control", int(idx.Int()))
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
return marshalPayloadObject(payload, root)
return payload
}
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.

View File

@@ -739,6 +739,35 @@ func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
}
}
func TestApplyClaudeToolPrefix_KnownFallbackBuiltinsRemainUnprefixed(t *testing.T) {
for _, builtin := range []string{"web_search", "code_execution", "text_editor", "computer"} {
t.Run(builtin, func(t *testing.T) {
input := []byte(fmt.Sprintf(`{
"tools":[{"name":"Read"}],
"tool_choice":{"type":"tool","name":%q},
"messages":[{"role":"assistant","content":[{"type":"tool_use","name":%q,"id":"toolu_1","input":{}},{"type":"tool_reference","tool_name":%q},{"type":"tool_result","tool_use_id":"toolu_1","content":[{"type":"tool_reference","tool_name":%q}]}]}]
}`, builtin, builtin, builtin, builtin))
out := applyClaudeToolPrefix(input, "proxy_")
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != builtin {
t.Fatalf("tool_choice.name = %q, want %q", got, builtin)
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != builtin {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, builtin)
}
if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != builtin {
t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, builtin)
}
if got := gjson.GetBytes(out, "messages.0.content.2.content.0.tool_name").String(); got != builtin {
t.Fatalf("messages.0.content.2.content.0.tool_name = %q, want %q", got, builtin)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
}
})
}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -965,6 +994,28 @@ func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.
}
}
func TestNormalizeCacheControlTTL_PreservesKeyOrderWhenModified(t *testing.T) {
payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
out := normalizeCacheControlTTL(payload)
if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
}
outStr := string(out)
idxModel := strings.Index(outStr, `"model"`)
idxMessages := strings.Index(outStr, `"messages"`)
idxTools := strings.Index(outStr, `"tools"`)
idxSystem := strings.Index(outStr, `"system"`)
if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
t.Fatalf("failed to locate top-level keys in output: %s", outStr)
}
if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
}
}
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
payload := []byte(`{
"tools": [
@@ -994,6 +1045,31 @@ func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T)
}
}
func TestEnforceCacheControlLimit_PreservesKeyOrderWhenModified(t *testing.T) {
payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}},{"name":"t2","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
out := enforceCacheControlLimit(payload, 4)
if got := countCacheControls(out); got != 4 {
t.Fatalf("cache_control count = %d, want 4", got)
}
if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
}
outStr := string(out)
idxModel := strings.Index(outStr, `"model"`)
idxMessages := strings.Index(outStr, `"messages"`)
idxTools := strings.Index(outStr, `"tools"`)
idxSystem := strings.Index(outStr, `"system"`)
if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
t.Fatalf("failed to locate top-level keys in output: %s", outStr)
}
if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
}
}
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
payload := []byte(`{
"tools": [

View File

@@ -4,9 +4,11 @@ import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
@@ -14,8 +16,11 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
@@ -98,10 +103,12 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, _ = sjson.SetBytes(translated, "stream", true)
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
@@ -114,6 +121,8 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return resp, err
}
e.applyHeaders(httpReq, accessToken, userID, domain)
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
var authID, authLabel, authType, authValue string
if auth != nil {
@@ -160,11 +169,16 @@ func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
aggregatedBody, usageDetail, err := aggregateOpenAIChatCompletionStream(body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
reporter.publish(ctx, usageDetail)
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, aggregatedBody, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
@@ -341,3 +355,197 @@ func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID,
req.Header.Set("X-IDE-Version", "2.63.2")
req.Header.Set("X-Requested-With", "XMLHttpRequest")
}
type openAIChatStreamChoiceAccumulator struct {
Role string
ContentParts []string
ReasoningParts []string
FinishReason string
ToolCalls map[int]*openAIChatStreamToolCallAccumulator
ToolCallOrder []int
NativeFinishReason any
}
type openAIChatStreamToolCallAccumulator struct {
ID string
Type string
Name string
Arguments strings.Builder
}
func aggregateOpenAIChatCompletionStream(raw []byte) ([]byte, usage.Detail, error) {
lines := bytes.Split(raw, []byte("\n"))
var (
responseID string
model string
created int64
serviceTier string
systemFP string
usageDetail usage.Detail
choices = map[int]*openAIChatStreamChoiceAccumulator{}
choiceOrder []int
)
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(line[5:])
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
if !gjson.ValidBytes(payload) {
continue
}
root := gjson.ParseBytes(payload)
if responseID == "" {
responseID = root.Get("id").String()
}
if model == "" {
model = root.Get("model").String()
}
if created == 0 {
created = root.Get("created").Int()
}
if serviceTier == "" {
serviceTier = root.Get("service_tier").String()
}
if systemFP == "" {
systemFP = root.Get("system_fingerprint").String()
}
if detail, ok := parseOpenAIStreamUsage(line); ok {
usageDetail = detail
}
for _, choiceResult := range root.Get("choices").Array() {
idx := int(choiceResult.Get("index").Int())
choice := choices[idx]
if choice == nil {
choice = &openAIChatStreamChoiceAccumulator{ToolCalls: map[int]*openAIChatStreamToolCallAccumulator{}}
choices[idx] = choice
choiceOrder = append(choiceOrder, idx)
}
delta := choiceResult.Get("delta")
if role := delta.Get("role").String(); role != "" {
choice.Role = role
}
if content := delta.Get("content").String(); content != "" {
choice.ContentParts = append(choice.ContentParts, content)
}
if reasoning := delta.Get("reasoning_content").String(); reasoning != "" {
choice.ReasoningParts = append(choice.ReasoningParts, reasoning)
}
if finishReason := choiceResult.Get("finish_reason").String(); finishReason != "" {
choice.FinishReason = finishReason
}
if nativeFinishReason := choiceResult.Get("native_finish_reason"); nativeFinishReason.Exists() {
choice.NativeFinishReason = nativeFinishReason.Value()
}
for _, toolCallResult := range delta.Get("tool_calls").Array() {
toolIdx := int(toolCallResult.Get("index").Int())
toolCall := choice.ToolCalls[toolIdx]
if toolCall == nil {
toolCall = &openAIChatStreamToolCallAccumulator{}
choice.ToolCalls[toolIdx] = toolCall
choice.ToolCallOrder = append(choice.ToolCallOrder, toolIdx)
}
if id := toolCallResult.Get("id").String(); id != "" {
toolCall.ID = id
}
if typ := toolCallResult.Get("type").String(); typ != "" {
toolCall.Type = typ
}
if name := toolCallResult.Get("function.name").String(); name != "" {
toolCall.Name = name
}
if args := toolCallResult.Get("function.arguments").String(); args != "" {
toolCall.Arguments.WriteString(args)
}
}
}
}
if responseID == "" && model == "" && len(choiceOrder) == 0 {
return nil, usageDetail, fmt.Errorf("codebuddy: streaming response did not contain any chat completion chunks")
}
response := map[string]any{
"id": responseID,
"object": "chat.completion",
"created": created,
"model": model,
"choices": make([]map[string]any, 0, len(choiceOrder)),
"usage": map[string]any{
"prompt_tokens": usageDetail.InputTokens,
"completion_tokens": usageDetail.OutputTokens,
"total_tokens": usageDetail.TotalTokens,
},
}
if serviceTier != "" {
response["service_tier"] = serviceTier
}
if systemFP != "" {
response["system_fingerprint"] = systemFP
}
for _, idx := range choiceOrder {
choice := choices[idx]
message := map[string]any{
"role": choice.Role,
"content": strings.Join(choice.ContentParts, ""),
}
if message["role"] == "" {
message["role"] = "assistant"
}
if len(choice.ReasoningParts) > 0 {
message["reasoning_content"] = strings.Join(choice.ReasoningParts, "")
}
if len(choice.ToolCallOrder) > 0 {
toolCalls := make([]map[string]any, 0, len(choice.ToolCallOrder))
for _, toolIdx := range choice.ToolCallOrder {
toolCall := choice.ToolCalls[toolIdx]
toolCallType := toolCall.Type
if toolCallType == "" {
toolCallType = "function"
}
arguments := toolCall.Arguments.String()
if arguments == "" {
arguments = "{}"
}
toolCalls = append(toolCalls, map[string]any{
"id": toolCall.ID,
"type": toolCallType,
"function": map[string]any{
"name": toolCall.Name,
"arguments": arguments,
},
})
}
message["tool_calls"] = toolCalls
}
finishReason := choice.FinishReason
if finishReason == "" {
finishReason = "stop"
}
choicePayload := map[string]any{
"index": idx,
"message": message,
"finish_reason": finishReason,
}
if choice.NativeFinishReason != nil {
choicePayload["native_finish_reason"] = choice.NativeFinishReason
}
response["choices"] = append(response["choices"].([]map[string]any), choicePayload)
}
out, err := json.Marshal(response)
if err != nil {
return nil, usageDetail, fmt.Errorf("codebuddy: failed to encode aggregated response: %w", err)
}
return out, usageDetail, nil
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
@@ -167,22 +168,63 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
lines := bytes.Split(data, []byte("\n"))
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for _, line := range lines {
if !bytes.HasPrefix(line, dataTag) {
continue
}
line = bytes.TrimSpace(line[5:])
if gjson.GetBytes(line, "type").String() != "response.completed" {
eventData := bytes.TrimSpace(line[5:])
eventType := gjson.GetBytes(eventData, "type").String()
if eventType == "response.output_item.done" {
itemResult := gjson.GetBytes(eventData, "item")
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
continue
}
outputIndexResult := gjson.GetBytes(eventData, "output_index")
if outputIndexResult.Exists() {
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
} else {
outputItemsFallback = append(outputItemsFallback, []byte(itemResult.Raw))
}
continue
}
if detail, ok := helps.ParseCodexUsage(line); ok {
if eventType != "response.completed" {
continue
}
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
completedData := eventData
outputResult := gjson.GetBytes(completedData, "response.output")
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
if shouldPatchOutput {
completedDataPatched := completedData
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output", []byte(`[]`))
indexes := make([]int64, 0, len(outputItemsByIndex))
for idx := range outputItemsByIndex {
indexes = append(indexes, idx)
}
sort.Slice(indexes, func(i, j int) bool {
return indexes[i] < indexes[j]
})
for _, idx := range indexes {
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", outputItemsByIndex[idx])
}
for _, item := range outputItemsFallback {
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", item)
}
completedData = completedDataPatched
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, completedData, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}

View File

@@ -0,0 +1,46 @@
package executor
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}\n"))
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1775555723,\"status\":\"completed\",\"model\":\"gpt-5.4-mini-2026-03-17\",\"output\":[],\"usage\":{\"input_tokens\":8,\"output_tokens\":28,\"total_tokens\":36}}}\n\n"))
}))
defer server.Close()
executor := NewCodexExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL,
"api_key": "test",
}}
resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gpt-5.4-mini",
Payload: []byte(`{"model":"gpt-5.4-mini","messages":[{"role":"user","content":"Say ok"}]}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
Stream: false,
})
if err != nil {
t.Fatalf("Execute error: %v", err)
}
gotContent := gjson.GetBytes(resp.Payload, "choices.0.message.content").String()
if gotContent != "ok" {
t.Fatalf("choices.0.message.content = %q, want %q; payload=%s", gotContent, "ok", string(resp.Payload))
}
}

View File

@@ -734,7 +734,7 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
}
switch setting.URL.Scheme {
case "socks5":
case "socks5", "socks5h":
var proxyAuth *proxy.Auth
if setting.URL.User != nil {
username := setting.URL.User.Username()

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strings"
"sync"
"time"
@@ -17,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
@@ -40,7 +42,7 @@ const (
copilotEditorVersion = "vscode/1.107.0"
copilotPluginVersion = "copilot-chat/0.35.0"
copilotIntegrationID = "vscode-chat"
copilotOpenAIIntent = "conversation-panel"
copilotOpenAIIntent = "conversation-edits"
copilotGitHubAPIVer = "2025-04-01"
)
@@ -126,6 +128,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body)
body = stripUnsupportedBetas(body)
// Detect vision content before input normalization removes messages
hasVision := detectVisionContent(body)
@@ -142,6 +145,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
if useResponses {
body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body)
body = applyGitHubCopilotResponsesDefaults(body)
} else {
body = normalizeGitHubCopilotChatTools(body)
}
@@ -225,9 +229,10 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
if useResponses && from.String() == "claude" {
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
} else {
data = normalizeGitHubCopilotReasoningField(data)
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
}
resp = cliproxyexecutor.Response{Payload: converted}
resp = cliproxyexecutor.Response{Payload: converted, Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
}
@@ -256,6 +261,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body)
body = stripUnsupportedBetas(body)
// Detect vision content before input normalization removes messages
hasVision := detectVisionContent(body)
@@ -272,6 +278,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
if useResponses {
body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body)
body = applyGitHubCopilotResponsesDefaults(body)
} else {
body = normalizeGitHubCopilotChatTools(body)
}
@@ -378,7 +385,20 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
if useResponses && from.String() == "claude" {
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param)
} else {
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
// Strip SSE "data: " prefix before reasoning field normalization,
// since normalizeGitHubCopilotReasoningField expects pure JSON.
// Re-wrap with the prefix afterward for the translator.
normalizedLine := bytes.Clone(line)
if bytes.HasPrefix(line, dataTag) {
sseData := bytes.TrimSpace(line[len(dataTag):])
if !bytes.Equal(sseData, []byte("[DONE]")) && gjson.ValidBytes(sseData) {
normalized := normalizeGitHubCopilotReasoningField(bytes.Clone(sseData))
if !bytes.Equal(normalized, sseData) {
normalizedLine = append(append([]byte(nil), dataTag...), normalized...)
}
}
}
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, normalizedLine, &param)
}
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: bytes.Clone(chunks[i])}
@@ -400,9 +420,28 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
}, nil
}
// CountTokens is not supported for GitHub Copilot.
func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"}
// CountTokens estimates token count locally using tiktoken, since the GitHub
// Copilot API does not expose a dedicated token counting endpoint.
func (e *GitHubCopilotExecutor) CountTokens(ctx context.Context, _ *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
enc, err := helps.TokenizerForModel(baseModel)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("github copilot executor: tokenizer init failed: %w", err)
}
count, err := helps.CountOpenAIChatTokens(enc, translated)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("github copilot executor: token counting failed: %w", err)
}
usageJSON := helps.BuildOpenAIUsageJSON(count)
translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
return cliproxyexecutor.Response{Payload: translatedUsage}, nil
}
// Refresh validates the GitHub token is still working.
@@ -491,46 +530,127 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, b
r.Header.Set("X-Request-Id", uuid.NewString())
initiator := "user"
if role := detectLastConversationRole(body); role == "assistant" || role == "tool" {
if isAgentInitiated(body) {
initiator = "agent"
}
r.Header.Set("X-Initiator", initiator)
}
func detectLastConversationRole(body []byte) string {
// isAgentInitiated determines whether the current request is agent-initiated
// (tool callbacks, continuations) rather than user-initiated (new user prompt).
//
// GitHub Copilot uses the X-Initiator header for billing:
// - "user" → consumes premium request quota
// - "agent" → free (tool loops, continuations)
//
// The challenge: Claude Code sends tool results as role:"user" messages with
// content type "tool_result". After translation to OpenAI format, the tool_result
// part becomes a separate role:"tool" message, but if the original Claude message
// also contained text content (e.g. skill invocations, attachment descriptions),
// a role:"user" message is emitted AFTER the tool message, making the last message
// appear user-initiated when it's actually part of an agent tool loop.
//
// VSCode Copilot Chat solves this with explicit flags (iterationNumber,
// isContinuation, subAgentInvocationId). Since CPA doesn't have these flags,
// we infer agent status by checking whether the conversation contains prior
// assistant/tool messages — if it does, the current request is a continuation.
//
// References:
// - opencode#8030, opencode#15824: same root cause and fix approach
// - vscode-copilot-chat: toolCallingLoop.ts (iterationNumber === 0)
// - pi-ai: github-copilot-headers.ts (last message role check)
func isAgentInitiated(body []byte) bool {
if len(body) == 0 {
return ""
return false
}
// Chat Completions API: check messages array
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
arr := messages.Array()
if len(arr) == 0 {
return false
}
lastRole := ""
for i := len(arr) - 1; i >= 0; i-- {
if role := arr[i].Get("role").String(); role != "" {
return role
if r := arr[i].Get("role").String(); r != "" {
lastRole = r
break
}
}
// If last message is assistant or tool, clearly agent-initiated.
if lastRole == "assistant" || lastRole == "tool" {
return true
}
// If last message is "user", check whether it contains tool results
// (indicating a tool-loop continuation) or if the preceding message
// is an assistant tool_use. This is more precise than checking for
// any prior assistant message, which would false-positive on genuine
// multi-turn follow-ups.
if lastRole == "user" {
// Check if the last user message contains tool_result content
lastContent := arr[len(arr)-1].Get("content")
if lastContent.Exists() && lastContent.IsArray() {
for _, part := range lastContent.Array() {
if part.Get("type").String() == "tool_result" {
return true
}
}
}
// Check if the second-to-last message is an assistant with tool_use
if len(arr) >= 2 {
prev := arr[len(arr)-2]
if prev.Get("role").String() == "assistant" {
prevContent := prev.Get("content")
if prevContent.Exists() && prevContent.IsArray() {
for _, part := range prevContent.Array() {
if part.Get("type").String() == "tool_use" {
return true
}
}
}
}
}
}
return false
}
// Responses API: check input array
if inputs := gjson.GetBytes(body, "input"); inputs.Exists() && inputs.IsArray() {
arr := inputs.Array()
for i := len(arr) - 1; i >= 0; i-- {
item := arr[i]
if len(arr) == 0 {
return false
}
// Most Responses input items carry a top-level role.
if role := item.Get("role").String(); role != "" {
return role
// Check last item
last := arr[len(arr)-1]
if role := last.Get("role").String(); role == "assistant" {
return true
}
switch last.Get("type").String() {
case "function_call", "function_call_arguments", "computer_call":
return true
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
return true
}
// If last item is user-role, check for prior non-user items
for _, item := range arr {
if role := item.Get("role").String(); role == "assistant" {
return true
}
switch item.Get("type").String() {
case "function_call", "function_call_arguments", "computer_call":
return "assistant"
case "function_call_output", "function_call_response", "tool_result", "computer_call_output":
return "tool"
case "function_call", "function_call_output", "function_call_response",
"function_call_arguments", "computer_call", "computer_call_output":
return true
}
}
}
return ""
return false
}
// detectVisionContent checks if the request body contains vision/image content.
@@ -572,6 +692,85 @@ func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte
return body
}
// copilotUnsupportedBetas lists beta headers that are Anthropic-specific and
// must not be forwarded to GitHub Copilot. The context-1m beta enables 1M
// context on Anthropic's API, but Copilot's Claude models are limited to
// ~128K-200K. Passing it through would not enable 1M on Copilot, but stripping
// it from the translated body avoids confusing downstream translators.
var copilotUnsupportedBetas = []string{
"context-1m-2025-08-07",
}
// stripUnsupportedBetas removes Anthropic-specific beta entries from the
// translated request body. In OpenAI format the betas may appear under
// "metadata.betas" or a top-level "betas" array; in Claude format they sit at
// "betas". This function checks all known locations.
func stripUnsupportedBetas(body []byte) []byte {
betaPaths := []string{"betas", "metadata.betas"}
for _, path := range betaPaths {
arr := gjson.GetBytes(body, path)
if !arr.Exists() || !arr.IsArray() {
continue
}
var filtered []string
changed := false
for _, item := range arr.Array() {
beta := item.String()
if isCopilotUnsupportedBeta(beta) {
changed = true
continue
}
filtered = append(filtered, beta)
}
if !changed {
continue
}
if len(filtered) == 0 {
body, _ = sjson.DeleteBytes(body, path)
} else {
body, _ = sjson.SetBytes(body, path, filtered)
}
}
return body
}
func isCopilotUnsupportedBeta(beta string) bool {
return slices.Contains(copilotUnsupportedBetas, beta)
}
// normalizeGitHubCopilotReasoningField maps Copilot's non-standard
// 'reasoning_text' field to the standard OpenAI 'reasoning_content' field
// that the SDK translator expects. This handles both streaming deltas
// (choices[].delta.reasoning_text) and non-streaming messages
// (choices[].message.reasoning_text). The field is only renamed when
// 'reasoning_content' is absent or null, preserving standard responses.
// All choices are processed to support n>1 requests.
func normalizeGitHubCopilotReasoningField(data []byte) []byte {
choices := gjson.GetBytes(data, "choices")
if !choices.Exists() || !choices.IsArray() {
return data
}
for i := range choices.Array() {
// Non-streaming: choices[i].message.reasoning_text
msgRT := fmt.Sprintf("choices.%d.message.reasoning_text", i)
msgRC := fmt.Sprintf("choices.%d.message.reasoning_content", i)
if rt := gjson.GetBytes(data, msgRT); rt.Exists() && rt.String() != "" {
if rc := gjson.GetBytes(data, msgRC); !rc.Exists() || rc.Type == gjson.Null || rc.String() == "" {
data, _ = sjson.SetBytes(data, msgRC, rt.String())
}
}
// Streaming: choices[i].delta.reasoning_text
deltaRT := fmt.Sprintf("choices.%d.delta.reasoning_text", i)
deltaRC := fmt.Sprintf("choices.%d.delta.reasoning_content", i)
if rt := gjson.GetBytes(data, deltaRT); rt.Exists() && rt.String() != "" {
if rc := gjson.GetBytes(data, deltaRC); !rc.Exists() || rc.Type == gjson.Null || rc.String() == "" {
data, _ = sjson.SetBytes(data, deltaRC, rt.String())
}
}
}
return data
}
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
if sourceFormat.String() == "openai-response" {
return true
@@ -596,12 +795,7 @@ func lookupGitHubCopilotStaticModelInfo(model string) *registry.ModelInfo {
}
func containsEndpoint(endpoints []string, endpoint string) bool {
for _, item := range endpoints {
if item == endpoint {
return true
}
}
return false
return slices.Contains(endpoints, endpoint)
}
// flattenAssistantContent converts assistant message content from array format
@@ -856,6 +1050,32 @@ func stripGitHubCopilotResponsesUnsupportedFields(body []byte) []byte {
return body
}
// applyGitHubCopilotResponsesDefaults sets required fields for the Responses API
// that both vscode-copilot-chat and pi-ai always include.
//
// References:
// - vscode-copilot-chat: src/platform/endpoint/node/responsesApi.ts
// - pi-ai (badlogic/pi-mono): packages/ai/src/providers/openai-responses.ts
func applyGitHubCopilotResponsesDefaults(body []byte) []byte {
// store: false — prevents request/response storage
if !gjson.GetBytes(body, "store").Exists() {
body, _ = sjson.SetBytes(body, "store", false)
}
// include: ["reasoning.encrypted_content"] — enables reasoning content
// reuse across turns, avoiding redundant computation
if !gjson.GetBytes(body, "include").Exists() {
body, _ = sjson.SetRawBytes(body, "include", []byte(`["reasoning.encrypted_content"]`))
}
// If reasoning.effort is set but reasoning.summary is not, default to "auto"
if gjson.GetBytes(body, "reasoning.effort").Exists() && !gjson.GetBytes(body, "reasoning.summary").Exists() {
body, _ = sjson.SetBytes(body, "reasoning.summary", "auto")
}
return body
}
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() {
@@ -1406,6 +1626,21 @@ func FetchGitHubCopilotModels(ctx context.Context, auth *cliproxyauth.Auth, cfg
m.MaxCompletionTokens = defaultCopilotMaxCompletionTokens
}
// Override with real limits from the Copilot API when available.
// The API returns per-account limits (individual vs business) under
// capabilities.limits, which are more accurate than our static
// fallback values. We use max_prompt_tokens as ContextLength because
// that's the hard limit the Copilot API enforces on prompt size —
// exceeding it triggers "prompt token count exceeds the limit" errors.
if limits := entry.Limits(); limits != nil {
if limits.MaxPromptTokens > 0 {
m.ContextLength = limits.MaxPromptTokens
}
if limits.MaxOutputTokens > 0 {
m.MaxCompletionTokens = limits.MaxOutputTokens
}
}
models = append(models, m)
}

View File

@@ -1,11 +1,14 @@
package executor
import (
"context"
"net/http"
"strings"
"testing"
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -72,7 +75,7 @@ func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
}
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
t.Parallel()
// Not parallel: shares global model registry with DynamicRegistryWinsOverStatic.
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
t.Fatal("expected responses-only registry model to use /responses")
}
@@ -82,7 +85,7 @@ func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing
}
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
t.Parallel()
// Not parallel: mutates global model registry, conflicts with RegistryResponsesOnlyModel.
reg := registry.GetGlobalRegistry()
clientID := "github-copilot-test-client"
@@ -251,14 +254,14 @@ func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing
t.Parallel()
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "type").String() != "message" {
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
if gjson.GetBytes(out, "type").String() != "message" {
t.Fatalf("type = %q, want message", gjson.GetBytes(out, "type").String())
}
if gjson.Get(out, "content.0.type").String() != "text" {
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
if gjson.GetBytes(out, "content.0.type").String() != "text" {
t.Fatalf("content.0.type = %q, want text", gjson.GetBytes(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.text").String() != "hello" {
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
if gjson.GetBytes(out, "content.0.text").String() != "hello" {
t.Fatalf("content.0.text = %q, want hello", gjson.GetBytes(out, "content.0.text").String())
}
}
@@ -266,14 +269,14 @@ func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *test
t.Parallel()
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "content.0.type").String() != "tool_use" {
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
if gjson.GetBytes(out, "content.0.type").String() != "tool_use" {
t.Fatalf("content.0.type = %q, want tool_use", gjson.GetBytes(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.name").String() != "sum" {
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
if gjson.GetBytes(out, "content.0.name").String() != "sum" {
t.Fatalf("content.0.name = %q, want sum", gjson.GetBytes(out, "content.0.name").String())
}
if gjson.Get(out, "stop_reason").String() != "tool_use" {
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
if gjson.GetBytes(out, "stop_reason").String() != "tool_use" {
t.Fatalf("stop_reason = %q, want tool_use", gjson.GetBytes(out, "stop_reason").String())
}
}
@@ -282,18 +285,24 @@ func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.
var param any
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), &param)
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
if len(created) == 0 || !strings.Contains(string(created[0]), "message_start") {
t.Fatalf("created events = %#v, want message_start", created)
}
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), &param)
joinedDelta := strings.Join(delta, "")
var joinedDelta string
for _, d := range delta {
joinedDelta += string(d)
}
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
}
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), &param)
joinedCompleted := strings.Join(completed, "")
var joinedCompleted string
for _, c := range completed {
joinedCompleted += string(c)
}
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
}
@@ -312,15 +321,17 @@ func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
}
}
func TestApplyHeaders_XInitiator_UserWhenLastRoleIsUser(t *testing.T) {
func TestApplyHeaders_XInitiator_AgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// Last role governs the initiator decision.
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
// When the last role is "user" and the message contains tool_result content,
// the request is a continuation (e.g. Claude tool result translated to a
// synthetic user message). Should be "agent".
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu1","content":"file contents..."}]}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "user" {
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (last user contains tool_result)", got)
}
}
@@ -328,10 +339,11 @@ func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// When the last message has role "tool", it's clearly agent-initiated.
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
t.Fatalf("X-Initiator = %q, want agent (last role is tool)", got)
}
}
@@ -346,14 +358,15 @@ func TestApplyHeaders_XInitiator_InputArrayLastAssistantMessage(t *testing.T) {
}
}
func TestApplyHeaders_XInitiator_InputArrayLastUserMessage(t *testing.T) {
func TestApplyHeaders_XInitiator_InputArrayAgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// Responses API: last item is user-role but history contains assistant → agent.
body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "user" {
t.Fatalf("X-Initiator = %q, want user (last role is user)", got)
if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (history has assistant)", got)
}
}
@@ -368,6 +381,33 @@ func TestApplyHeaders_XInitiator_InputArrayLastFunctionCallOutput(t *testing.T)
}
}
func TestApplyHeaders_XInitiator_UserInMultiTurnNoTools(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// Genuine multi-turn: user → assistant (plain text) → user follow-up.
// No tool messages → should be "user" (not a false-positive).
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"what is 2+2?"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "user" {
t.Fatalf("X-Initiator = %q, want user (genuine multi-turn, no tools)", got)
}
}
func TestApplyHeaders_XInitiator_UserFollowUpAfterToolHistory(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// User follow-up after a completed tool-use conversation.
// The last message is a genuine user question — should be "user", not "agent".
// This aligns with opencode's behavior: only active tool loops are agent-initiated.
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":[{"type":"tool_use","id":"tu1","name":"Read","input":{}}]},{"role":"tool","tool_call_id":"tu1","content":"file data"},{"role":"assistant","content":"I read the file."},{"role":"user","content":"What did we do so far?"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "user" {
t.Fatalf("X-Initiator = %q, want user (genuine follow-up after tool history)", got)
}
}
// --- Tests for x-github-api-version header (Problem M) ---
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
@@ -414,3 +454,364 @@ func TestDetectVisionContent_NoMessages(t *testing.T) {
t.Fatal("expected no vision content when messages field is absent")
}
}
// --- Tests for applyGitHubCopilotResponsesDefaults ---
func TestApplyGitHubCopilotResponsesDefaults_SetsAllDefaults(t *testing.T) {
t.Parallel()
body := []byte(`{"input":"hello","reasoning":{"effort":"medium"}}`)
got := applyGitHubCopilotResponsesDefaults(body)
if gjson.GetBytes(got, "store").Bool() != false {
t.Fatalf("store = %v, want false", gjson.GetBytes(got, "store").Raw)
}
inc := gjson.GetBytes(got, "include")
if !inc.IsArray() || inc.Array()[0].String() != "reasoning.encrypted_content" {
t.Fatalf("include = %s, want [\"reasoning.encrypted_content\"]", inc.Raw)
}
if gjson.GetBytes(got, "reasoning.summary").String() != "auto" {
t.Fatalf("reasoning.summary = %q, want auto", gjson.GetBytes(got, "reasoning.summary").String())
}
}
func TestApplyGitHubCopilotResponsesDefaults_DoesNotOverrideExisting(t *testing.T) {
t.Parallel()
body := []byte(`{"input":"hello","store":true,"include":["other"],"reasoning":{"effort":"high","summary":"concise"}}`)
got := applyGitHubCopilotResponsesDefaults(body)
if gjson.GetBytes(got, "store").Bool() != true {
t.Fatalf("store should not be overridden, got %s", gjson.GetBytes(got, "store").Raw)
}
if gjson.GetBytes(got, "include").Array()[0].String() != "other" {
t.Fatalf("include should not be overridden, got %s", gjson.GetBytes(got, "include").Raw)
}
if gjson.GetBytes(got, "reasoning.summary").String() != "concise" {
t.Fatalf("reasoning.summary should not be overridden, got %q", gjson.GetBytes(got, "reasoning.summary").String())
}
}
func TestApplyGitHubCopilotResponsesDefaults_NoReasoningEffort(t *testing.T) {
t.Parallel()
body := []byte(`{"input":"hello"}`)
got := applyGitHubCopilotResponsesDefaults(body)
if gjson.GetBytes(got, "store").Bool() != false {
t.Fatalf("store = %v, want false", gjson.GetBytes(got, "store").Raw)
}
// reasoning.summary should NOT be set when reasoning.effort is absent
if gjson.GetBytes(got, "reasoning.summary").Exists() {
t.Fatalf("reasoning.summary should not be set when reasoning.effort is absent, got %q", gjson.GetBytes(got, "reasoning.summary").String())
}
}
// --- Tests for normalizeGitHubCopilotReasoningField ---
func TestNormalizeReasoningField_NonStreaming(t *testing.T) {
t.Parallel()
data := []byte(`{"choices":[{"message":{"content":"hello","reasoning_text":"I think..."}}]}`)
got := normalizeGitHubCopilotReasoningField(data)
rc := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
if rc != "I think..." {
t.Fatalf("reasoning_content = %q, want %q", rc, "I think...")
}
}
func TestNormalizeReasoningField_Streaming(t *testing.T) {
t.Parallel()
data := []byte(`{"choices":[{"delta":{"reasoning_text":"thinking delta"}}]}`)
got := normalizeGitHubCopilotReasoningField(data)
rc := gjson.GetBytes(got, "choices.0.delta.reasoning_content").String()
if rc != "thinking delta" {
t.Fatalf("reasoning_content = %q, want %q", rc, "thinking delta")
}
}
func TestNormalizeReasoningField_PreservesExistingReasoningContent(t *testing.T) {
t.Parallel()
data := []byte(`{"choices":[{"message":{"reasoning_text":"old","reasoning_content":"existing"}}]}`)
got := normalizeGitHubCopilotReasoningField(data)
rc := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
if rc != "existing" {
t.Fatalf("reasoning_content = %q, want %q (should not overwrite)", rc, "existing")
}
}
func TestNormalizeReasoningField_MultiChoice(t *testing.T) {
t.Parallel()
data := []byte(`{"choices":[{"message":{"reasoning_text":"thought-0"}},{"message":{"reasoning_text":"thought-1"}}]}`)
got := normalizeGitHubCopilotReasoningField(data)
rc0 := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
rc1 := gjson.GetBytes(got, "choices.1.message.reasoning_content").String()
if rc0 != "thought-0" {
t.Fatalf("choices[0].reasoning_content = %q, want %q", rc0, "thought-0")
}
if rc1 != "thought-1" {
t.Fatalf("choices[1].reasoning_content = %q, want %q", rc1, "thought-1")
}
}
func TestNormalizeReasoningField_NoChoices(t *testing.T) {
t.Parallel()
data := []byte(`{"id":"chatcmpl-123"}`)
got := normalizeGitHubCopilotReasoningField(data)
if string(got) != string(data) {
t.Fatalf("expected no change, got %s", string(got))
}
}
func TestApplyHeaders_OpenAIIntentValue(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
e.applyHeaders(req, "token", nil)
if got := req.Header.Get("Openai-Intent"); got != "conversation-edits" {
t.Fatalf("Openai-Intent = %q, want conversation-edits", got)
}
}
// --- Tests for CountTokens (local tiktoken estimation) ---
func TestCountTokens_ReturnsPositiveCount(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"Hello, world!"}]}`)
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
Model: "gpt-4o",
Payload: body,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("CountTokens() error: %v", err)
}
if len(resp.Payload) == 0 {
t.Fatal("CountTokens() returned empty payload")
}
// The response should contain a positive token count.
tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
if tokens <= 0 {
t.Fatalf("expected positive token count, got %d", tokens)
}
}
func TestCountTokens_ClaudeSourceFormatTranslates(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
body := []byte(`{"model":"claude-sonnet-4","messages":[{"role":"user","content":"Tell me a joke"}],"max_tokens":1024}`)
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
Model: "claude-sonnet-4",
Payload: body,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("CountTokens() error: %v", err)
}
// Claude source format → should get input_tokens in response
inputTokens := gjson.GetBytes(resp.Payload, "input_tokens").Int()
if inputTokens <= 0 {
// Fallback: check usage.prompt_tokens (depends on translator registration)
promptTokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
if promptTokens <= 0 {
t.Fatalf("expected positive token count, got payload: %s", resp.Payload)
}
}
}
func TestCountTokens_EmptyPayload(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
Model: "gpt-4o",
Payload: []byte(`{"model":"gpt-4o","messages":[]}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
if err != nil {
t.Fatalf("CountTokens() error: %v", err)
}
tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
// Empty messages should return 0 tokens.
if tokens != 0 {
t.Fatalf("expected 0 tokens for empty messages, got %d", tokens)
}
}
func TestStripUnsupportedBetas_RemovesContext1M(t *testing.T) {
t.Parallel()
body := []byte(`{"model":"claude-opus-4.6","betas":["interleaved-thinking-2025-05-14","context-1m-2025-08-07","claude-code-20250219"],"messages":[]}`)
result := stripUnsupportedBetas(body)
betas := gjson.GetBytes(result, "betas")
if !betas.Exists() {
t.Fatal("betas field should still exist after stripping")
}
for _, item := range betas.Array() {
if item.String() == "context-1m-2025-08-07" {
t.Fatal("context-1m-2025-08-07 should have been stripped")
}
}
// Other betas should be preserved
found := false
for _, item := range betas.Array() {
if item.String() == "interleaved-thinking-2025-05-14" {
found = true
}
}
if !found {
t.Fatal("other betas should be preserved")
}
}
func TestStripUnsupportedBetas_NoBetasField(t *testing.T) {
t.Parallel()
body := []byte(`{"model":"gpt-4o","messages":[]}`)
result := stripUnsupportedBetas(body)
// Should be unchanged
if string(result) != string(body) {
t.Fatalf("body should be unchanged when no betas field exists, got %s", string(result))
}
}
func TestStripUnsupportedBetas_MetadataBetas(t *testing.T) {
t.Parallel()
body := []byte(`{"model":"claude-opus-4.6","metadata":{"betas":["context-1m-2025-08-07","other-beta"]},"messages":[]}`)
result := stripUnsupportedBetas(body)
betas := gjson.GetBytes(result, "metadata.betas")
if !betas.Exists() {
t.Fatal("metadata.betas field should still exist after stripping")
}
for _, item := range betas.Array() {
if item.String() == "context-1m-2025-08-07" {
t.Fatal("context-1m-2025-08-07 should have been stripped from metadata.betas")
}
}
if betas.Array()[0].String() != "other-beta" {
t.Fatal("other betas in metadata.betas should be preserved")
}
}
func TestStripUnsupportedBetas_AllBetasStripped(t *testing.T) {
t.Parallel()
body := []byte(`{"model":"claude-opus-4.6","betas":["context-1m-2025-08-07"],"messages":[]}`)
result := stripUnsupportedBetas(body)
betas := gjson.GetBytes(result, "betas")
if betas.Exists() {
t.Fatal("betas field should be deleted when all betas are stripped")
}
}
func TestCopilotModelEntry_Limits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
capabilities map[string]any
wantNil bool
wantPrompt int
wantOutput int
wantContext int
}{
{
name: "nil capabilities",
capabilities: nil,
wantNil: true,
},
{
name: "no limits key",
capabilities: map[string]any{"family": "claude-opus-4.6"},
wantNil: true,
},
{
name: "limits is not a map",
capabilities: map[string]any{"limits": "invalid"},
wantNil: true,
},
{
name: "all zero values",
capabilities: map[string]any{
"limits": map[string]any{
"max_context_window_tokens": float64(0),
"max_prompt_tokens": float64(0),
"max_output_tokens": float64(0),
},
},
wantNil: true,
},
{
name: "individual account limits (128K prompt)",
capabilities: map[string]any{
"limits": map[string]any{
"max_context_window_tokens": float64(144000),
"max_prompt_tokens": float64(128000),
"max_output_tokens": float64(64000),
},
},
wantNil: false,
wantPrompt: 128000,
wantOutput: 64000,
wantContext: 144000,
},
{
name: "business account limits (168K prompt)",
capabilities: map[string]any{
"limits": map[string]any{
"max_context_window_tokens": float64(200000),
"max_prompt_tokens": float64(168000),
"max_output_tokens": float64(32000),
},
},
wantNil: false,
wantPrompt: 168000,
wantOutput: 32000,
wantContext: 200000,
},
{
name: "partial limits (only prompt)",
capabilities: map[string]any{
"limits": map[string]any{
"max_prompt_tokens": float64(128000),
},
},
wantNil: false,
wantPrompt: 128000,
wantOutput: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
entry := copilotauth.CopilotModelEntry{
ID: "claude-opus-4.6",
Capabilities: tt.capabilities,
}
limits := entry.Limits()
if tt.wantNil {
if limits != nil {
t.Fatalf("expected nil limits, got %+v", limits)
}
return
}
if limits == nil {
t.Fatal("expected non-nil limits, got nil")
}
if limits.MaxPromptTokens != tt.wantPrompt {
t.Errorf("MaxPromptTokens = %d, want %d", limits.MaxPromptTokens, tt.wantPrompt)
}
if limits.MaxOutputTokens != tt.wantOutput {
t.Errorf("MaxOutputTokens = %d, want %d", limits.MaxOutputTokens, tt.wantOutput)
}
if tt.wantContext > 0 && limits.MaxContextWindowTokens != tt.wantContext {
t.Errorf("MaxContextWindowTokens = %d, want %d", limits.MaxContextWindowTokens, tt.wantContext)
}
})
}
}

View File

@@ -0,0 +1,38 @@
package helps
import "github.com/tidwall/gjson"
var defaultClaudeBuiltinToolNames = []string{
"web_search",
"code_execution",
"text_editor",
"computer",
}
func newClaudeBuiltinToolRegistry() map[string]bool {
registry := make(map[string]bool, len(defaultClaudeBuiltinToolNames))
for _, name := range defaultClaudeBuiltinToolNames {
registry[name] = true
}
return registry
}
func AugmentClaudeBuiltinToolRegistry(body []byte, registry map[string]bool) map[string]bool {
if registry == nil {
registry = newClaudeBuiltinToolRegistry()
}
tools := gjson.GetBytes(body, "tools")
if !tools.Exists() || !tools.IsArray() {
return registry
}
tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").String() == "" {
return true
}
if name := tool.Get("name").String(); name != "" {
registry[name] = true
}
return true
})
return registry
}

View File

@@ -0,0 +1,32 @@
package helps
import "testing"
func TestClaudeBuiltinToolRegistry_DefaultSeedFallback(t *testing.T) {
registry := AugmentClaudeBuiltinToolRegistry(nil, nil)
for _, name := range defaultClaudeBuiltinToolNames {
if !registry[name] {
t.Fatalf("default builtin %q missing from fallback registry", name)
}
}
}
func TestClaudeBuiltinToolRegistry_AugmentsTypedBuiltinsFromBody(t *testing.T) {
registry := AugmentClaudeBuiltinToolRegistry([]byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search"},
{"type": "custom_builtin_20250401", "name": "special_builtin"},
{"name": "Read"}
]
}`), nil)
if !registry["web_search"] {
t.Fatal("expected default typed builtin web_search in registry")
}
if !registry["special_builtin"] {
t.Fatal("expected typed builtin from body to be added to registry")
}
if registry["Read"] {
t.Fatal("expected untyped custom tool to stay out of builtin registry")
}
}

View File

@@ -298,6 +298,14 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
// In case the upstream close the stream without a terminal [DONE] marker.
// Feed a synthetic done marker through the translator so pending
// response.completed events are still emitted exactly once.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
// Ensure we record the request if no usage chunk was ever seen
reporter.EnsurePublished(ctx)

View File

@@ -172,32 +172,101 @@ func timeUntilNextDay() time.Duration {
return tomorrow.Sub(now)
}
// ensureQwenSystemMessage prepends a default system message if none exists in "messages".
// ensureQwenSystemMessage ensures the request has a single system message at the beginning.
// It always injects the default system prompt and merges any user-provided system messages
// into the injected system message content to satisfy Qwen's strict message ordering rules.
func ensureQwenSystemMessage(payload []byte) ([]byte, error) {
messages := gjson.GetBytes(payload, "messages")
if messages.Exists() && messages.IsArray() {
var buf bytes.Buffer
buf.WriteByte('[')
buf.Write(qwenDefaultSystemMessage)
for _, msg := range messages.Array() {
buf.WriteByte(',')
buf.WriteString(msg.Raw)
isInjectedSystemPart := func(part gjson.Result) bool {
if !part.Exists() || !part.IsObject() {
return false
}
buf.WriteByte(']')
updated, errSet := sjson.SetRawBytes(payload, "messages", buf.Bytes())
if errSet != nil {
return nil, fmt.Errorf("qwen executor: set default system message failed: %w", errSet)
if !strings.EqualFold(part.Get("type").String(), "text") {
return false
}
return updated, nil
if !strings.EqualFold(part.Get("cache_control.type").String(), "ephemeral") {
return false
}
text := part.Get("text").String()
return text == "" || text == "You are Qwen Code."
}
var buf bytes.Buffer
buf.WriteByte('[')
buf.Write(qwenDefaultSystemMessage)
buf.WriteByte(']')
updated, errSet := sjson.SetRawBytes(payload, "messages", buf.Bytes())
defaultParts := gjson.ParseBytes(qwenDefaultSystemMessage).Get("content")
var systemParts []any
if defaultParts.Exists() && defaultParts.IsArray() {
for _, part := range defaultParts.Array() {
systemParts = append(systemParts, part.Value())
}
}
if len(systemParts) == 0 {
systemParts = append(systemParts, map[string]any{
"type": "text",
"text": "You are Qwen Code.",
"cache_control": map[string]any{
"type": "ephemeral",
},
})
}
appendSystemContent := func(content gjson.Result) {
makeTextPart := func(text string) map[string]any {
return map[string]any{
"type": "text",
"text": text,
}
}
if !content.Exists() || content.Type == gjson.Null {
return
}
if content.IsArray() {
for _, part := range content.Array() {
if part.Type == gjson.String {
systemParts = append(systemParts, makeTextPart(part.String()))
continue
}
if isInjectedSystemPart(part) {
continue
}
systemParts = append(systemParts, part.Value())
}
return
}
if content.Type == gjson.String {
systemParts = append(systemParts, makeTextPart(content.String()))
return
}
if content.IsObject() {
if isInjectedSystemPart(content) {
return
}
systemParts = append(systemParts, content.Value())
return
}
systemParts = append(systemParts, makeTextPart(content.String()))
}
messages := gjson.GetBytes(payload, "messages")
var nonSystemMessages []any
if messages.Exists() && messages.IsArray() {
for _, msg := range messages.Array() {
if strings.EqualFold(msg.Get("role").String(), "system") {
appendSystemContent(msg.Get("content"))
continue
}
nonSystemMessages = append(nonSystemMessages, msg.Value())
}
}
newMessages := make([]any, 0, 1+len(nonSystemMessages))
newMessages = append(newMessages, map[string]any{
"role": "system",
"content": systemParts,
})
newMessages = append(newMessages, nonSystemMessages...)
updated, errSet := sjson.SetBytes(payload, "messages", newMessages)
if errSet != nil {
return nil, fmt.Errorf("qwen executor: set default system message failed: %w", errSet)
return nil, fmt.Errorf("qwen executor: set system message failed: %w", errSet)
}
return updated, nil
}

View File

@@ -4,6 +4,7 @@ import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
)
func TestQwenExecutorParseSuffix(t *testing.T) {
@@ -28,3 +29,123 @@ func TestQwenExecutorParseSuffix(t *testing.T) {
})
}
}
func TestEnsureQwenSystemMessage_MergeStringSystem(t *testing.T) {
payload := []byte(`{
"model": "qwen3.6-plus",
"stream": true,
"messages": [
{ "role": "system", "content": "ABCDEFG" },
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
]
}`)
out, err := ensureQwenSystemMessage(payload)
if err != nil {
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
}
msgs := gjson.GetBytes(out, "messages").Array()
if len(msgs) != 2 {
t.Fatalf("messages length = %d, want 2", len(msgs))
}
if msgs[0].Get("role").String() != "system" {
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
}
parts := msgs[0].Get("content").Array()
if len(parts) != 2 {
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
}
if parts[0].Get("text").String() != "You are Qwen Code." || parts[0].Get("cache_control.type").String() != "ephemeral" {
t.Fatalf("messages[0].content[0] = %s, want injected system part", parts[0].Raw)
}
if parts[1].Get("type").String() != "text" || parts[1].Get("text").String() != "ABCDEFG" {
t.Fatalf("messages[0].content[1] = %s, want text part with ABCDEFG", parts[1].Raw)
}
if msgs[1].Get("role").String() != "user" {
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
}
}
func TestEnsureQwenSystemMessage_MergeObjectSystem(t *testing.T) {
payload := []byte(`{
"messages": [
{ "role": "system", "content": { "type": "text", "text": "ABCDEFG" } },
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
]
}`)
out, err := ensureQwenSystemMessage(payload)
if err != nil {
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
}
msgs := gjson.GetBytes(out, "messages").Array()
if len(msgs) != 2 {
t.Fatalf("messages length = %d, want 2", len(msgs))
}
parts := msgs[0].Get("content").Array()
if len(parts) != 2 {
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
}
if parts[1].Get("text").String() != "ABCDEFG" {
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "ABCDEFG")
}
}
func TestEnsureQwenSystemMessage_PrependsWhenMissing(t *testing.T) {
payload := []byte(`{
"messages": [
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
]
}`)
out, err := ensureQwenSystemMessage(payload)
if err != nil {
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
}
msgs := gjson.GetBytes(out, "messages").Array()
if len(msgs) != 2 {
t.Fatalf("messages length = %d, want 2", len(msgs))
}
if msgs[0].Get("role").String() != "system" {
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
}
if !msgs[0].Get("content").IsArray() || len(msgs[0].Get("content").Array()) == 0 {
t.Fatalf("messages[0].content = %s, want non-empty array", msgs[0].Get("content").Raw)
}
if msgs[1].Get("role").String() != "user" {
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
}
}
func TestEnsureQwenSystemMessage_MergesMultipleSystemMessages(t *testing.T) {
payload := []byte(`{
"messages": [
{ "role": "system", "content": "A" },
{ "role": "user", "content": [ { "type": "text", "text": "hi" } ] },
{ "role": "system", "content": "B" }
]
}`)
out, err := ensureQwenSystemMessage(payload)
if err != nil {
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
}
msgs := gjson.GetBytes(out, "messages").Array()
if len(msgs) != 2 {
t.Fatalf("messages length = %d, want 2", len(msgs))
}
parts := msgs[0].Get("content").Array()
if len(parts) != 3 {
t.Fatalf("messages[0].content length = %d, want 3", len(parts))
}
if parts[1].Get("text").String() != "A" {
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "A")
}
if parts[2].Get("text").String() != "B" {
t.Fatalf("messages[0].content[2].text = %q, want %q", parts[2].Get("text").String(), "B")
}
}

View File

@@ -32,16 +32,24 @@ type GitTokenStore struct {
repoDir string
configDir string
remote string
branch string
username string
password string
lastGC time.Time
}
type resolvedRemoteBranch struct {
name plumbing.ReferenceName
hash plumbing.Hash
}
// NewGitTokenStore creates a token store that saves credentials to disk through the
// TokenStorage implementation embedded in the token record.
func NewGitTokenStore(remote, username, password string) *GitTokenStore {
// When branch is non-empty, clone/pull/push operations target that branch instead of the remote default.
func NewGitTokenStore(remote, username, password, branch string) *GitTokenStore {
return &GitTokenStore{
remote: remote,
branch: strings.TrimSpace(branch),
username: username,
password: password,
}
@@ -120,7 +128,11 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock()
return fmt.Errorf("git token store: create repo dir: %w", errMk)
}
if _, errClone := git.PlainClone(repoDir, &git.CloneOptions{Auth: authMethod, URL: s.remote}); errClone != nil {
cloneOpts := &git.CloneOptions{Auth: authMethod, URL: s.remote}
if s.branch != "" {
cloneOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
}
if _, errClone := git.PlainClone(repoDir, cloneOpts); errClone != nil {
if errors.Is(errClone, transport.ErrEmptyRemoteRepository) {
_ = os.RemoveAll(gitDir)
repo, errInit := git.PlainInit(repoDir, false)
@@ -128,6 +140,13 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock()
return fmt.Errorf("git token store: init empty repo: %w", errInit)
}
if s.branch != "" {
headRef := plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(s.branch))
if errHead := repo.Storer.SetReference(headRef); errHead != nil {
s.dirLock.Unlock()
return fmt.Errorf("git token store: set head to branch %s: %w", s.branch, errHead)
}
}
if _, errRemote := repo.Remote("origin"); errRemote != nil {
if _, errCreate := repo.CreateRemote(&config.RemoteConfig{
Name: "origin",
@@ -176,16 +195,39 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock()
return fmt.Errorf("git token store: worktree: %w", errWorktree)
}
if errPull := worktree.Pull(&git.PullOptions{Auth: authMethod, RemoteName: "origin"}); errPull != nil {
if s.branch != "" {
if errCheckout := s.checkoutConfiguredBranch(repo, worktree, authMethod); errCheckout != nil {
s.dirLock.Unlock()
return errCheckout
}
} else {
// When branch is unset, ensure the working tree follows the remote default branch
if err := checkoutRemoteDefaultBranch(repo, worktree, authMethod); err != nil {
if !shouldFallbackToCurrentBranch(repo, err) {
s.dirLock.Unlock()
return fmt.Errorf("git token store: checkout remote default: %w", err)
}
}
}
pullOpts := &git.PullOptions{Auth: authMethod, RemoteName: "origin"}
if s.branch != "" {
pullOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
}
if errPull := worktree.Pull(pullOpts); errPull != nil {
switch {
case errors.Is(errPull, git.NoErrAlreadyUpToDate),
errors.Is(errPull, git.ErrUnstagedChanges),
errors.Is(errPull, git.ErrNonFastForwardUpdate):
// Ignore clean syncs, local edits, and remote divergence—local changes win.
case errors.Is(errPull, transport.ErrAuthenticationRequired),
errors.Is(errPull, plumbing.ErrReferenceNotFound),
errors.Is(errPull, transport.ErrEmptyRemoteRepository):
// Ignore authentication prompts and empty remote references on initial sync.
case errors.Is(errPull, plumbing.ErrReferenceNotFound):
if s.branch != "" {
s.dirLock.Unlock()
return fmt.Errorf("git token store: pull: %w", errPull)
}
// Ignore missing references only when following the remote default branch.
default:
s.dirLock.Unlock()
return fmt.Errorf("git token store: pull: %w", errPull)
@@ -554,6 +596,192 @@ func (s *GitTokenStore) relativeToRepo(path string) (string, error) {
return rel, nil
}
func (s *GitTokenStore) checkoutConfiguredBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
branchRefName := plumbing.NewBranchReferenceName(s.branch)
headRef, errHead := repo.Head()
switch {
case errHead == nil && headRef.Name() == branchRefName:
return nil
case errHead != nil && !errors.Is(errHead, plumbing.ErrReferenceNotFound):
return fmt.Errorf("git token store: get head: %w", errHead)
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err == nil {
return nil
} else if _, errRef := repo.Reference(branchRefName, true); errRef == nil {
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
} else if !errors.Is(errRef, plumbing.ErrReferenceNotFound) {
return fmt.Errorf("git token store: inspect branch %s: %w", s.branch, errRef)
} else if err := s.checkoutConfiguredRemoteTrackingBranch(repo, worktree, branchRefName, authMethod); err != nil {
return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err)
}
return nil
}
func (s *GitTokenStore) checkoutConfiguredRemoteTrackingBranch(repo *git.Repository, worktree *git.Worktree, branchRefName plumbing.ReferenceName, authMethod transport.AuthMethod) error {
remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + s.branch)
remoteRef, err := repo.Reference(remoteRefName, true)
if errors.Is(err, plumbing.ErrReferenceNotFound) {
if errSync := syncRemoteReferences(repo, authMethod); errSync != nil {
return fmt.Errorf("sync remote refs: %w", errSync)
}
remoteRef, err = repo.Reference(remoteRefName, true)
}
if err != nil {
return err
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: remoteRef.Hash()}); err != nil {
return err
}
cfg, err := repo.Config()
if err != nil {
return fmt.Errorf("git token store: repo config: %w", err)
}
if _, ok := cfg.Branches[s.branch]; !ok {
cfg.Branches[s.branch] = &config.Branch{Name: s.branch}
}
cfg.Branches[s.branch].Remote = "origin"
cfg.Branches[s.branch].Merge = branchRefName
if err := repo.SetConfig(cfg); err != nil {
return fmt.Errorf("git token store: set branch config: %w", err)
}
return nil
}
func syncRemoteReferences(repo *git.Repository, authMethod transport.AuthMethod) error {
if err := repo.Fetch(&git.FetchOptions{Auth: authMethod, RemoteName: "origin"}); err != nil && !errors.Is(err, git.NoErrAlreadyUpToDate) {
return err
}
return nil
}
// resolveRemoteDefaultBranch queries the origin remote to determine the remote's default branch
// (the target of HEAD) and returns the corresponding local branch reference name (e.g. refs/heads/master).
func resolveRemoteDefaultBranch(repo *git.Repository, authMethod transport.AuthMethod) (resolvedRemoteBranch, error) {
if err := syncRemoteReferences(repo, authMethod); err != nil {
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: sync remote refs: %w", err)
}
remote, err := repo.Remote("origin")
if err != nil {
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: get remote: %w", err)
}
refs, err := remote.List(&git.ListOptions{Auth: authMethod})
if err != nil {
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
return resolved, nil
}
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: list remote refs: %w", err)
}
for _, r := range refs {
if r.Name() == plumbing.HEAD {
if r.Type() == plumbing.SymbolicReference {
if target, ok := normalizeRemoteBranchReference(r.Target()); ok {
return resolvedRemoteBranch{name: target}, nil
}
}
s := r.String()
if idx := strings.Index(s, "->"); idx != -1 {
if target, ok := normalizeRemoteBranchReference(plumbing.ReferenceName(strings.TrimSpace(s[idx+2:]))); ok {
return resolvedRemoteBranch{name: target}, nil
}
}
}
}
if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
return resolved, nil
}
for _, r := range refs {
if normalized, ok := normalizeRemoteBranchReference(r.Name()); ok {
return resolvedRemoteBranch{name: normalized, hash: r.Hash()}, nil
}
}
return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: remote default branch not found")
}
func resolveRemoteDefaultBranchFromLocal(repo *git.Repository) (resolvedRemoteBranch, bool) {
ref, err := repo.Reference(plumbing.ReferenceName("refs/remotes/origin/HEAD"), true)
if err != nil || ref.Type() != plumbing.SymbolicReference {
return resolvedRemoteBranch{}, false
}
target, ok := normalizeRemoteBranchReference(ref.Target())
if !ok {
return resolvedRemoteBranch{}, false
}
return resolvedRemoteBranch{name: target}, true
}
func normalizeRemoteBranchReference(name plumbing.ReferenceName) (plumbing.ReferenceName, bool) {
switch {
case strings.HasPrefix(name.String(), "refs/heads/"):
return name, true
case strings.HasPrefix(name.String(), "refs/remotes/origin/"):
return plumbing.NewBranchReferenceName(strings.TrimPrefix(name.String(), "refs/remotes/origin/")), true
default:
return "", false
}
}
func shouldFallbackToCurrentBranch(repo *git.Repository, err error) bool {
if !errors.Is(err, transport.ErrAuthenticationRequired) && !errors.Is(err, transport.ErrEmptyRemoteRepository) {
return false
}
_, headErr := repo.Head()
return headErr == nil
}
// checkoutRemoteDefaultBranch ensures the working tree is checked out to the remote's default branch
// (the branch target of origin/HEAD). If the local branch does not exist it will be created to track
// the remote branch.
func checkoutRemoteDefaultBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
resolved, err := resolveRemoteDefaultBranch(repo, authMethod)
if err != nil {
return err
}
branchRefName := resolved.name
// If HEAD already points to the desired branch, nothing to do.
headRef, errHead := repo.Head()
if errHead == nil && headRef.Name() == branchRefName {
return nil
}
// If local branch exists, attempt a checkout
if _, err := repo.Reference(branchRefName, true); err == nil {
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err != nil {
return fmt.Errorf("checkout branch %s: %w", branchRefName.String(), err)
}
return nil
}
// Try to find the corresponding remote tracking ref (refs/remotes/origin/<name>)
branchShort := strings.TrimPrefix(branchRefName.String(), "refs/heads/")
remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + branchShort)
hash := resolved.hash
if remoteRef, err := repo.Reference(remoteRefName, true); err == nil {
hash = remoteRef.Hash()
} else if err != nil && !errors.Is(err, plumbing.ErrReferenceNotFound) {
return fmt.Errorf("checkout remote default: remote ref %s: %w", remoteRefName.String(), err)
}
if hash == plumbing.ZeroHash {
return fmt.Errorf("checkout remote default: remote ref %s not found", remoteRefName.String())
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: hash}); err != nil {
return fmt.Errorf("checkout create branch %s: %w", branchRefName.String(), err)
}
cfg, err := repo.Config()
if err != nil {
return fmt.Errorf("git token store: repo config: %w", err)
}
if _, ok := cfg.Branches[branchShort]; !ok {
cfg.Branches[branchShort] = &config.Branch{Name: branchShort}
}
cfg.Branches[branchShort].Remote = "origin"
cfg.Branches[branchShort].Merge = branchRefName
if err := repo.SetConfig(cfg); err != nil {
return fmt.Errorf("git token store: set branch config: %w", err)
}
return nil
}
func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error {
repoDir := s.repoDirSnapshot()
if repoDir == "" {
@@ -619,7 +847,16 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
return errRewrite
}
s.maybeRunGC(repo)
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
pushOpts := &git.PushOptions{Auth: s.gitAuth(), Force: true}
if s.branch != "" {
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec("refs/heads/" + s.branch + ":refs/heads/" + s.branch)}
} else {
// When branch is unset, pin push to the currently checked-out branch.
if headRef, err := repo.Head(); err == nil {
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec(headRef.Name().String() + ":" + headRef.Name().String())}
}
}
if err = repo.Push(pushOpts); err != nil {
if errors.Is(err, git.NoErrAlreadyUpToDate) {
return nil
}

View File

@@ -0,0 +1,585 @@
package store
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/go-git/go-git/v6"
gitconfig "github.com/go-git/go-git/v6/config"
"github.com/go-git/go-git/v6/plumbing"
"github.com/go-git/go-git/v6/plumbing/object"
)
type testBranchSpec struct {
name string
contents string
}
func TestEnsureRepositoryUsesRemoteDefaultBranchWhenBranchNotConfigured(t *testing.T) {
root := t.TempDir()
remoteDir := setupGitRemoteRepository(t, root, "trunk",
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
testBranchSpec{name: "release/2026", contents: "release branch\n"},
)
store := NewGitTokenStore(remoteDir, "", "", "")
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
if err := store.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch\n")
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
if err := store.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository second call: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch updated\n")
assertRemoteHeadBranch(t, remoteDir, "trunk")
}
func TestEnsureRepositoryUsesConfiguredBranchWhenExplicitlySet(t *testing.T) {
root := t.TempDir()
remoteDir := setupGitRemoteRepository(t, root, "trunk",
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
testBranchSpec{name: "release/2026", contents: "release branch\n"},
)
store := NewGitTokenStore(remoteDir, "", "", "release/2026")
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
if err := store.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
if err := store.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository second call: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch updated\n")
assertRemoteHeadBranch(t, remoteDir, "trunk")
}
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranch(t *testing.T) {
root := t.TempDir()
remoteDir := setupGitRemoteRepository(t, root, "trunk",
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
)
store := NewGitTokenStore(remoteDir, "", "", "missing-branch")
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
err := store.EnsureRepository()
if err == nil {
t.Fatal("EnsureRepository succeeded, want error for nonexistent configured branch")
}
assertRemoteHeadBranch(t, remoteDir, "trunk")
}
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranchOnExistingRepositoryPull(t *testing.T) {
root := t.TempDir()
remoteDir := setupGitRemoteRepository(t, root, "trunk",
testBranchSpec{name: "trunk", contents: "remote default branch\n"},
)
baseDir := filepath.Join(root, "workspace", "auths")
store := NewGitTokenStore(remoteDir, "", "", "")
store.SetBaseDir(baseDir)
if err := store.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository initial clone: %v", err)
}
reopened := NewGitTokenStore(remoteDir, "", "", "missing-branch")
reopened.SetBaseDir(baseDir)
err := reopened.EnsureRepository()
if err == nil {
t.Fatal("EnsureRepository succeeded on reopen, want error for nonexistent configured branch")
}
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "trunk")
assertRemoteHeadBranch(t, remoteDir, "trunk")
}
func TestEnsureRepositoryInitializesEmptyRemoteUsingConfiguredBranch(t *testing.T) {
root := t.TempDir()
remoteDir := filepath.Join(root, "remote.git")
if _, err := git.PlainInit(remoteDir, true); err != nil {
t.Fatalf("init bare remote: %v", err)
}
branch := "feature/gemini-fix"
store := NewGitTokenStore(remoteDir, "", "", branch)
store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
if err := store.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository: %v", err)
}
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), branch)
assertRemoteBranchExistsWithCommit(t, remoteDir, branch)
assertRemoteBranchDoesNotExist(t, remoteDir, "master")
}
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranch(t *testing.T) {
root := t.TempDir()
remoteDir := setupGitRemoteRepository(t, root, "master",
testBranchSpec{name: "master", contents: "remote master branch\n"},
testBranchSpec{name: "develop", contents: "remote develop branch\n"},
)
baseDir := filepath.Join(root, "workspace", "auths")
store := NewGitTokenStore(remoteDir, "", "", "")
store.SetBaseDir(baseDir)
if err := store.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository initial clone: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
reopened := NewGitTokenStore(remoteDir, "", "", "develop")
reopened.SetBaseDir(baseDir)
if err := reopened.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository reopen: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
workspaceDir := filepath.Join(root, "workspace")
if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local develop update\n"), 0o600); err != nil {
t.Fatalf("write local branch marker: %v", err)
}
reopened.mu.Lock()
err := reopened.commitAndPushLocked("Update develop branch marker", "branch.txt")
reopened.mu.Unlock()
if err != nil {
t.Fatalf("commitAndPushLocked: %v", err)
}
assertRepositoryHeadBranch(t, workspaceDir, "develop")
assertRemoteBranchContents(t, remoteDir, "develop", "local develop update\n")
assertRemoteBranchContents(t, remoteDir, "master", "remote master branch\n")
}
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranchCreatedAfterClone(t *testing.T) {
root := t.TempDir()
remoteDir := setupGitRemoteRepository(t, root, "master",
testBranchSpec{name: "master", contents: "remote master branch\n"},
)
baseDir := filepath.Join(root, "workspace", "auths")
store := NewGitTokenStore(remoteDir, "", "", "")
store.SetBaseDir(baseDir)
if err := store.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository initial clone: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
advanceRemoteBranchFromNewBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch\n", "create release")
reopened := NewGitTokenStore(remoteDir, "", "", "release/2026")
reopened.SetBaseDir(baseDir)
if err := reopened.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository reopen: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
}
func TestEnsureRepositoryResetsToRemoteDefaultWhenBranchUnset(t *testing.T) {
root := t.TempDir()
remoteDir := setupGitRemoteRepository(t, root, "master",
testBranchSpec{name: "master", contents: "remote master branch\n"},
testBranchSpec{name: "develop", contents: "remote develop branch\n"},
)
baseDir := filepath.Join(root, "workspace", "auths")
// First store pins to develop and prepares local workspace
storePinned := NewGitTokenStore(remoteDir, "", "", "develop")
storePinned.SetBaseDir(baseDir)
if err := storePinned.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository pinned: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
// Second store has branch unset and should reset local workspace to remote default (master)
storeDefault := NewGitTokenStore(remoteDir, "", "", "")
storeDefault.SetBaseDir(baseDir)
if err := storeDefault.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository default: %v", err)
}
// Local HEAD should now follow remote default (master)
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "master")
// Make a local change and push using the store with branch unset; push should update remote master
workspaceDir := filepath.Join(root, "workspace")
if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local master update\n"), 0o600); err != nil {
t.Fatalf("write local master marker: %v", err)
}
storeDefault.mu.Lock()
if err := storeDefault.commitAndPushLocked("Update master marker", "branch.txt"); err != nil {
storeDefault.mu.Unlock()
t.Fatalf("commitAndPushLocked: %v", err)
}
storeDefault.mu.Unlock()
assertRemoteBranchContents(t, remoteDir, "master", "local master update\n")
}
func TestEnsureRepositoryFollowsRenamedRemoteDefaultBranchWhenAvailable(t *testing.T) {
root := t.TempDir()
remoteDir := setupGitRemoteRepository(t, root, "master",
testBranchSpec{name: "master", contents: "remote master branch\n"},
testBranchSpec{name: "main", contents: "remote main branch\n"},
)
baseDir := filepath.Join(root, "workspace", "auths")
store := NewGitTokenStore(remoteDir, "", "", "")
store.SetBaseDir(baseDir)
if err := store.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository initial clone: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
setRemoteHeadBranch(t, remoteDir, "main")
advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "main", "remote main branch updated\n", "advance main")
reopened := NewGitTokenStore(remoteDir, "", "", "")
reopened.SetBaseDir(baseDir)
if err := reopened.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository after remote default rename: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "main", "remote main branch updated\n")
assertRemoteHeadBranch(t, remoteDir, "main")
}
func TestEnsureRepositoryKeepsCurrentBranchWhenRemoteDefaultCannotBeResolved(t *testing.T) {
root := t.TempDir()
remoteDir := setupGitRemoteRepository(t, root, "master",
testBranchSpec{name: "master", contents: "remote master branch\n"},
testBranchSpec{name: "develop", contents: "remote develop branch\n"},
)
baseDir := filepath.Join(root, "workspace", "auths")
pinned := NewGitTokenStore(remoteDir, "", "", "develop")
pinned.SetBaseDir(baseDir)
if err := pinned.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository pinned: %v", err)
}
assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="git"`)
http.Error(w, "auth required", http.StatusUnauthorized)
}))
defer authServer.Close()
repo, err := git.PlainOpen(filepath.Join(root, "workspace"))
if err != nil {
t.Fatalf("open workspace repo: %v", err)
}
cfg, err := repo.Config()
if err != nil {
t.Fatalf("read repo config: %v", err)
}
cfg.Remotes["origin"].URLs = []string{authServer.URL}
if err := repo.SetConfig(cfg); err != nil {
t.Fatalf("set repo config: %v", err)
}
reopened := NewGitTokenStore(remoteDir, "", "", "")
reopened.SetBaseDir(baseDir)
if err := reopened.EnsureRepository(); err != nil {
t.Fatalf("EnsureRepository default branch fallback: %v", err)
}
assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "develop")
}
func setupGitRemoteRepository(t *testing.T, root, defaultBranch string, branches ...testBranchSpec) string {
t.Helper()
remoteDir := filepath.Join(root, "remote.git")
if _, err := git.PlainInit(remoteDir, true); err != nil {
t.Fatalf("init bare remote: %v", err)
}
seedDir := filepath.Join(root, "seed")
seedRepo, err := git.PlainInit(seedDir, false)
if err != nil {
t.Fatalf("init seed repo: %v", err)
}
if err := seedRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
t.Fatalf("set seed HEAD: %v", err)
}
worktree, err := seedRepo.Worktree()
if err != nil {
t.Fatalf("open seed worktree: %v", err)
}
defaultSpec, ok := findBranchSpec(branches, defaultBranch)
if !ok {
t.Fatalf("missing default branch spec for %q", defaultBranch)
}
commitBranchMarker(t, seedDir, worktree, defaultSpec, "seed default branch")
for _, branch := range branches {
if branch.name == defaultBranch {
continue
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(defaultBranch)}); err != nil {
t.Fatalf("checkout default branch %s: %v", defaultBranch, err)
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch.name), Create: true}); err != nil {
t.Fatalf("create branch %s: %v", branch.name, err)
}
commitBranchMarker(t, seedDir, worktree, branch, "seed branch "+branch.name)
}
if _, err := seedRepo.CreateRemote(&gitconfig.RemoteConfig{Name: "origin", URLs: []string{remoteDir}}); err != nil {
t.Fatalf("create origin remote: %v", err)
}
if err := seedRepo.Push(&git.PushOptions{
RemoteName: "origin",
RefSpecs: []gitconfig.RefSpec{gitconfig.RefSpec("refs/heads/*:refs/heads/*")},
}); err != nil {
t.Fatalf("push seed branches: %v", err)
}
remoteRepo, err := git.PlainOpen(remoteDir)
if err != nil {
t.Fatalf("open remote repo: %v", err)
}
if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
t.Fatalf("set remote HEAD: %v", err)
}
return remoteDir
}
func commitBranchMarker(t *testing.T, seedDir string, worktree *git.Worktree, branch testBranchSpec, message string) {
t.Helper()
if err := os.WriteFile(filepath.Join(seedDir, "branch.txt"), []byte(branch.contents), 0o600); err != nil {
t.Fatalf("write branch marker for %s: %v", branch.name, err)
}
if _, err := worktree.Add("branch.txt"); err != nil {
t.Fatalf("add branch marker for %s: %v", branch.name, err)
}
if _, err := worktree.Commit(message, &git.CommitOptions{
Author: &object.Signature{
Name: "CLIProxyAPI",
Email: "cliproxy@local",
When: time.Unix(1711929600, 0),
},
}); err != nil {
t.Fatalf("commit branch marker for %s: %v", branch.name, err)
}
}
func advanceRemoteBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
t.Helper()
seedRepo, err := git.PlainOpen(seedDir)
if err != nil {
t.Fatalf("open seed repo: %v", err)
}
worktree, err := seedRepo.Worktree()
if err != nil {
t.Fatalf("open seed worktree: %v", err)
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch)}); err != nil {
t.Fatalf("checkout branch %s: %v", branch, err)
}
commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
if err := seedRepo.Push(&git.PushOptions{
RemoteName: "origin",
RefSpecs: []gitconfig.RefSpec{
gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
},
}); err != nil {
t.Fatalf("push branch %s update to %s: %v", branch, remoteDir, err)
}
}
func advanceRemoteBranchFromNewBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
t.Helper()
seedRepo, err := git.PlainOpen(seedDir)
if err != nil {
t.Fatalf("open seed repo: %v", err)
}
worktree, err := seedRepo.Worktree()
if err != nil {
t.Fatalf("open seed worktree: %v", err)
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName("master")}); err != nil {
t.Fatalf("checkout master before creating %s: %v", branch, err)
}
if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch), Create: true}); err != nil {
t.Fatalf("create branch %s: %v", branch, err)
}
commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
if err := seedRepo.Push(&git.PushOptions{
RemoteName: "origin",
RefSpecs: []gitconfig.RefSpec{
gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
},
}); err != nil {
t.Fatalf("push new branch %s update to %s: %v", branch, remoteDir, err)
}
}
func findBranchSpec(branches []testBranchSpec, name string) (testBranchSpec, bool) {
for _, branch := range branches {
if branch.name == name {
return branch, true
}
}
return testBranchSpec{}, false
}
func assertRepositoryBranchAndContents(t *testing.T, repoDir, branch, wantContents string) {
t.Helper()
repo, err := git.PlainOpen(repoDir)
if err != nil {
t.Fatalf("open local repo: %v", err)
}
head, err := repo.Head()
if err != nil {
t.Fatalf("local repo head: %v", err)
}
if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want {
t.Fatalf("local head branch = %s, want %s", got, want)
}
contents, err := os.ReadFile(filepath.Join(repoDir, "branch.txt"))
if err != nil {
t.Fatalf("read branch marker: %v", err)
}
if got := string(contents); got != wantContents {
t.Fatalf("branch marker contents = %q, want %q", got, wantContents)
}
}
func assertRepositoryHeadBranch(t *testing.T, repoDir, branch string) {
t.Helper()
repo, err := git.PlainOpen(repoDir)
if err != nil {
t.Fatalf("open local repo: %v", err)
}
head, err := repo.Head()
if err != nil {
t.Fatalf("local repo head: %v", err)
}
if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want {
t.Fatalf("local head branch = %s, want %s", got, want)
}
}
func assertRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
t.Helper()
remoteRepo, err := git.PlainOpen(remoteDir)
if err != nil {
t.Fatalf("open remote repo: %v", err)
}
head, err := remoteRepo.Reference(plumbing.HEAD, false)
if err != nil {
t.Fatalf("read remote HEAD: %v", err)
}
if got, want := head.Target(), plumbing.NewBranchReferenceName(branch); got != want {
t.Fatalf("remote HEAD target = %s, want %s", got, want)
}
}
func setRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
t.Helper()
remoteRepo, err := git.PlainOpen(remoteDir)
if err != nil {
t.Fatalf("open remote repo: %v", err)
}
if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(branch))); err != nil {
t.Fatalf("set remote HEAD to %s: %v", branch, err)
}
}
func assertRemoteBranchExistsWithCommit(t *testing.T, remoteDir, branch string) {
t.Helper()
remoteRepo, err := git.PlainOpen(remoteDir)
if err != nil {
t.Fatalf("open remote repo: %v", err)
}
ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
if err != nil {
t.Fatalf("read remote branch %s: %v", branch, err)
}
if got := ref.Hash(); got == plumbing.ZeroHash {
t.Fatalf("remote branch %s hash = %s, want non-zero hash", branch, got)
}
}
func assertRemoteBranchDoesNotExist(t *testing.T, remoteDir, branch string) {
t.Helper()
remoteRepo, err := git.PlainOpen(remoteDir)
if err != nil {
t.Fatalf("open remote repo: %v", err)
}
if _, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false); err == nil {
t.Fatalf("remote branch %s exists, want missing", branch)
} else if err != plumbing.ErrReferenceNotFound {
t.Fatalf("read remote branch %s: %v", branch, err)
}
}
func assertRemoteBranchContents(t *testing.T, remoteDir, branch, wantContents string) {
t.Helper()
remoteRepo, err := git.PlainOpen(remoteDir)
if err != nil {
t.Fatalf("open remote repo: %v", err)
}
ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
if err != nil {
t.Fatalf("read remote branch %s: %v", branch, err)
}
commit, err := remoteRepo.CommitObject(ref.Hash())
if err != nil {
t.Fatalf("read remote branch %s commit: %v", branch, err)
}
tree, err := commit.Tree()
if err != nil {
t.Fatalf("read remote branch %s tree: %v", branch, err)
}
file, err := tree.File("branch.txt")
if err != nil {
t.Fatalf("read remote branch %s file: %v", branch, err)
}
contents, err := file.Contents()
if err != nil {
t.Fatalf("read remote branch %s contents: %v", branch, err)
}
if contents != wantContents {
t.Fatalf("remote branch %s contents = %q, want %q", branch, contents, wantContents)
}
}

View File

@@ -174,7 +174,8 @@ func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo
// Ensure the request satisfies Claude constraints:
// 1) Determine effective max_tokens (request overrides model default)
// 2) If budget_tokens >= max_tokens, reduce budget_tokens to max_tokens-1
// 3) If the adjusted budget falls below the model minimum, leave the request unchanged
// 3) If the adjusted budget falls below the model minimum, try raising max_tokens
// (clamped to MaxCompletionTokens); disable thinking if constraints are unsatisfiable
// 4) If max_tokens came from model default, write it back into the request
effectiveMax, setDefaultMax := a.effectiveMaxTokens(body, modelInfo)
@@ -193,8 +194,28 @@ func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo
minBudget = modelInfo.Thinking.Min
}
if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget {
// If enforcing the max_tokens constraint would push the budget below the model minimum,
// leave the request unchanged.
// Enforcing budget_tokens < max_tokens pushed the budget below the model minimum.
// Try raising max_tokens to fit the original budget.
needed := budgetTokens + 1
maxAllowed := 0
if modelInfo != nil {
maxAllowed = modelInfo.MaxCompletionTokens
}
if maxAllowed > 0 && needed > maxAllowed {
// Cannot use original budget; cap max_tokens at model limit.
needed = maxAllowed
}
cappedBudget := needed - 1
if cappedBudget < minBudget {
// Impossible to satisfy both budget >= minBudget and budget < max_tokens
// within the model's completion limit. Disable thinking entirely.
body, _ = sjson.DeleteBytes(body, "thinking")
return body
}
body, _ = sjson.SetBytes(body, "max_tokens", needed)
if cappedBudget != budgetTokens {
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", cappedBudget)
}
return body
}

View File

@@ -0,0 +1,99 @@
package claude
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/tidwall/gjson"
)
func TestNormalizeClaudeBudget_RaisesMaxTokens(t *testing.T) {
a := &Applier{}
modelInfo := &registry.ModelInfo{
MaxCompletionTokens: 64000,
Thinking: &registry.ThinkingSupport{Min: 1024, Max: 128000},
}
body := []byte(`{"max_tokens":1000,"thinking":{"type":"enabled","budget_tokens":5000}}`)
out := a.normalizeClaudeBudget(body, 5000, modelInfo)
maxTok := gjson.GetBytes(out, "max_tokens").Int()
if maxTok != 5001 {
t.Fatalf("max_tokens = %d, want 5001, body=%s", maxTok, string(out))
}
}
func TestNormalizeClaudeBudget_ClampsToModelMax(t *testing.T) {
a := &Applier{}
modelInfo := &registry.ModelInfo{
MaxCompletionTokens: 64000,
Thinking: &registry.ThinkingSupport{Min: 1024, Max: 128000},
}
body := []byte(`{"max_tokens":500,"thinking":{"type":"enabled","budget_tokens":200000}}`)
out := a.normalizeClaudeBudget(body, 200000, modelInfo)
maxTok := gjson.GetBytes(out, "max_tokens").Int()
if maxTok != 64000 {
t.Fatalf("max_tokens = %d, want 64000 (capped to model limit), body=%s", maxTok, string(out))
}
budget := gjson.GetBytes(out, "thinking.budget_tokens").Int()
if budget != 63999 {
t.Fatalf("budget_tokens = %d, want 63999 (max_tokens-1), body=%s", budget, string(out))
}
}
func TestNormalizeClaudeBudget_DisablesThinkingWhenUnsatisfiable(t *testing.T) {
a := &Applier{}
modelInfo := &registry.ModelInfo{
MaxCompletionTokens: 1000,
Thinking: &registry.ThinkingSupport{Min: 1024, Max: 128000},
}
body := []byte(`{"max_tokens":500,"thinking":{"type":"enabled","budget_tokens":2000}}`)
out := a.normalizeClaudeBudget(body, 2000, modelInfo)
if gjson.GetBytes(out, "thinking").Exists() {
t.Fatalf("thinking should be removed when constraints are unsatisfiable, body=%s", string(out))
}
}
func TestNormalizeClaudeBudget_NoClamping(t *testing.T) {
a := &Applier{}
modelInfo := &registry.ModelInfo{
MaxCompletionTokens: 64000,
Thinking: &registry.ThinkingSupport{Min: 1024, Max: 128000},
}
body := []byte(`{"max_tokens":32000,"thinking":{"type":"enabled","budget_tokens":16000}}`)
out := a.normalizeClaudeBudget(body, 16000, modelInfo)
maxTok := gjson.GetBytes(out, "max_tokens").Int()
if maxTok != 32000 {
t.Fatalf("max_tokens should remain 32000, got %d, body=%s", maxTok, string(out))
}
budget := gjson.GetBytes(out, "thinking.budget_tokens").Int()
if budget != 16000 {
t.Fatalf("budget_tokens should remain 16000, got %d, body=%s", budget, string(out))
}
}
func TestNormalizeClaudeBudget_AdjustsBudgetToMaxMinus1(t *testing.T) {
a := &Applier{}
modelInfo := &registry.ModelInfo{
MaxCompletionTokens: 8192,
Thinking: &registry.ThinkingSupport{Min: 1024, Max: 128000},
}
body := []byte(`{"max_tokens":8192,"thinking":{"type":"enabled","budget_tokens":10000}}`)
out := a.normalizeClaudeBudget(body, 10000, modelInfo)
maxTok := gjson.GetBytes(out, "max_tokens").Int()
if maxTok != 8192 {
t.Fatalf("max_tokens = %d, want 8192 (unchanged), body=%s", maxTok, string(out))
}
budget := gjson.GetBytes(out, "thinking.budget_tokens").Int()
if budget != 8191 {
t.Fatalf("budget_tokens = %d, want 8191 (max_tokens-1), body=%s", budget, string(out))
}
}

View File

@@ -26,6 +26,9 @@ type ConvertCodexResponseToClaudeParams struct {
HasToolCall bool
BlockIndex int
HasReceivedArgumentsDelta bool
ThinkingBlockOpen bool
ThinkingStopPending bool
ThinkingSignature string
}
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
@@ -44,7 +47,7 @@ type ConvertCodexResponseToClaudeParams struct {
//
// Returns:
// - [][]byte: A slice of Claude Code-compatible JSON responses
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertCodexResponseToClaudeParams{
HasToolCall: false,
@@ -52,7 +55,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
}
}
// log.Debugf("rawJSON: %s", string(rawJSON))
if !bytes.HasPrefix(rawJSON, dataTag) {
return [][]byte{}
}
@@ -60,9 +62,18 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output := make([]byte, 0, 512)
rootResult := gjson.ParseBytes(rawJSON)
params := (*param).(*ConvertCodexResponseToClaudeParams)
if params.ThinkingBlockOpen && params.ThinkingStopPending {
switch rootResult.Get("type").String() {
case "response.content_part.added", "response.completed":
output = append(output, finalizeCodexThinkingBlock(params)...)
}
}
typeResult := rootResult.Get("type")
typeStr := typeResult.String()
var template []byte
if typeStr == "response.created" {
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
@@ -70,43 +81,46 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
} else if typeStr == "response.reasoning_summary_part.added" {
if params.ThinkingBlockOpen && params.ThinkingStopPending {
output = append(output, finalizeCodexThinkingBlock(params)...)
}
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.ThinkingBlockOpen = true
params.ThinkingStopPending = false
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.reasoning_summary_text.delta" {
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.reasoning_summary_part.done" {
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
params.ThinkingStopPending = true
if params.ThinkingSignature != "" {
output = append(output, finalizeCodexThinkingBlock(params)...)
}
} else if typeStr == "response.content_part.added" {
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.output_text.delta" {
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.content_part.done" {
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.BlockIndex++
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if typeStr == "response.completed" {
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
p := params.HasToolCall
stopReason := rootResult.Get("response.stop_reason").String()
if p {
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
@@ -128,13 +142,13 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
output = append(output, finalizeCodexThinkingBlock(params)...)
params.HasToolCall = true
params.HasReceivedArgumentsDelta = false
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{
// Restore original tool name if shortened
name := itemResult.Get("name").String()
rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
@@ -146,37 +160,43 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if itemType == "reasoning" {
params.ThinkingSignature = itemResult.Get("encrypted_content").String()
if params.ThinkingStopPending {
output = append(output, finalizeCodexThinkingBlock(params)...)
}
}
} else if typeStr == "response.output_item.done" {
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.BlockIndex++
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if itemType == "reasoning" {
if signature := itemResult.Get("encrypted_content").String(); signature != "" {
params.ThinkingSignature = signature
}
output = append(output, finalizeCodexThinkingBlock(params)...)
params.ThinkingSignature = ""
}
} else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
params.HasReceivedArgumentsDelta = true
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
// in a single "done" event without preceding "delta" events.
// Emit the full arguments as a single input_json_delta so the
// downstream Claude client receives the complete tool input.
// When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if !params.HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" {
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", args)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
@@ -191,15 +211,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// This function processes the complete Codex response and transforms it into a single Claude Code-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
// the information into a single response that matches the Claude Code API format.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON response from the Codex API
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - []byte: A Claude Code-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
@@ -230,6 +241,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
switch item.Get("type").String() {
case "reasoning":
thinkingBuilder := strings.Builder{}
signature := item.Get("encrypted_content").String()
if summary := item.Get("summary"); summary.Exists() {
if summary.IsArray() {
summary.ForEach(func(_, part gjson.Result) bool {
@@ -260,9 +272,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
}
}
}
if thinkingBuilder.Len() > 0 {
if thinkingBuilder.Len() > 0 || signature != "" {
block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
if signature != "" {
block, _ = sjson.SetBytes(block, "signature", signature)
}
out, _ = sjson.SetRawBytes(out, "content.-1", block)
}
case "message":
@@ -371,6 +386,30 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
return rev
}
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
func ClaudeTokenCount(_ context.Context, count int64) []byte {
return translatorcommon.ClaudeInputTokensJSON(count)
}
func finalizeCodexThinkingBlock(params *ConvertCodexResponseToClaudeParams) []byte {
if !params.ThinkingBlockOpen {
return nil
}
output := make([]byte, 0, 256)
if params.ThinkingSignature != "" {
signatureDelta := []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":""}}`)
signatureDelta, _ = sjson.SetBytes(signatureDelta, "index", params.BlockIndex)
signatureDelta, _ = sjson.SetBytes(signatureDelta, "delta.signature", params.ThinkingSignature)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", signatureDelta, 2)
}
contentBlockStop := []byte(`{"type":"content_block_stop","index":0}`)
contentBlockStop, _ = sjson.SetBytes(contentBlockStop, "index", params.BlockIndex)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", contentBlockStop, 2)
params.BlockIndex++
params.ThinkingBlockOpen = false
params.ThinkingStopPending = false
return output
}

View File

@@ -0,0 +1,282 @@
package claude
import (
"context"
"strings"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertCodexResponseToClaude_StreamThinkingIncludesSignature(t *testing.T) {
ctx := context.Background()
originalRequest := []byte(`{"messages":[]}`)
var param any
chunks := [][]byte{
[]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_123\",\"model\":\"gpt-5\"}}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_123\"}}"),
}
var outputs [][]byte
for _, chunk := range chunks {
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, &param)...)
}
startFound := false
signatureDeltaFound := false
stopFound := false
for _, out := range outputs {
for _, line := range strings.Split(string(out), "\n") {
if !strings.HasPrefix(line, "data: ") {
continue
}
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
switch data.Get("type").String() {
case "content_block_start":
if data.Get("content_block.type").String() == "thinking" {
startFound = true
if data.Get("content_block.signature").Exists() {
t.Fatalf("thinking start block should NOT have signature field when signature is unknown: %s", line)
}
}
case "content_block_delta":
if data.Get("delta.type").String() == "signature_delta" {
signatureDeltaFound = true
if got := data.Get("delta.signature").String(); got != "enc_sig_123" {
t.Fatalf("unexpected signature delta: %q", got)
}
}
case "content_block_stop":
stopFound = true
}
}
}
if !startFound {
t.Fatal("expected thinking content_block_start event")
}
if !signatureDeltaFound {
t.Fatal("expected signature_delta event for thinking block")
}
if !stopFound {
t.Fatal("expected content_block_stop event for thinking block")
}
}
func TestConvertCodexResponseToClaude_StreamThinkingWithoutReasoningItemStillIncludesSignatureField(t *testing.T) {
ctx := context.Background()
originalRequest := []byte(`{"messages":[]}`)
var param any
chunks := [][]byte{
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
[]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"),
}
var outputs [][]byte
for _, chunk := range chunks {
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, &param)...)
}
thinkingStartFound := false
thinkingStopFound := false
signatureDeltaFound := false
for _, out := range outputs {
for _, line := range strings.Split(string(out), "\n") {
if !strings.HasPrefix(line, "data: ") {
continue
}
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" {
thinkingStartFound = true
if data.Get("content_block.signature").Exists() {
t.Fatalf("thinking start block should NOT have signature field without encrypted_content: %s", line)
}
}
if data.Get("type").String() == "content_block_stop" && data.Get("index").Int() == 0 {
thinkingStopFound = true
}
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
signatureDeltaFound = true
}
}
}
if !thinkingStartFound {
t.Fatal("expected thinking content_block_start event")
}
if !thinkingStopFound {
t.Fatal("expected thinking content_block_stop event")
}
if signatureDeltaFound {
t.Fatal("did not expect signature_delta without encrypted_content")
}
}
func TestConvertCodexResponseToClaude_StreamThinkingFinalizesPendingBlockBeforeNextSummaryPart(t *testing.T) {
ctx := context.Background()
originalRequest := []byte(`{"messages":[]}`)
var param any
chunks := [][]byte{
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
}
var outputs [][]byte
for _, chunk := range chunks {
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, &param)...)
}
startCount := 0
stopCount := 0
for _, out := range outputs {
for _, line := range strings.Split(string(out), "\n") {
if !strings.HasPrefix(line, "data: ") {
continue
}
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" {
startCount++
}
if data.Get("type").String() == "content_block_stop" {
stopCount++
}
}
}
if startCount != 2 {
t.Fatalf("expected 2 thinking block starts, got %d", startCount)
}
if stopCount != 1 {
t.Fatalf("expected pending thinking block to be finalized before second start, got %d stops", stopCount)
}
}
func TestConvertCodexResponseToClaude_StreamThinkingRetainsSignatureAcrossMultipartReasoning(t *testing.T) {
ctx := context.Background()
originalRequest := []byte(`{"messages":[]}`)
var param any
chunks := [][]byte{
[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_multipart\"}}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Second part\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
}
var outputs [][]byte
for _, chunk := range chunks {
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, &param)...)
}
signatureDeltaCount := 0
for _, out := range outputs {
for _, line := range strings.Split(string(out), "\n") {
if !strings.HasPrefix(line, "data: ") {
continue
}
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
signatureDeltaCount++
if got := data.Get("delta.signature").String(); got != "enc_sig_multipart" {
t.Fatalf("unexpected signature delta: %q", got)
}
}
}
}
if signatureDeltaCount != 2 {
t.Fatalf("expected signature_delta for both multipart thinking blocks, got %d", signatureDeltaCount)
}
}
func TestConvertCodexResponseToClaude_StreamThinkingUsesEarlyCapturedSignatureWhenDoneOmitsIt(t *testing.T) {
ctx := context.Background()
originalRequest := []byte(`{"messages":[]}`)
var param any
chunks := [][]byte{
[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_early\"}}"),
[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
}
var outputs [][]byte
for _, chunk := range chunks {
outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, &param)...)
}
signatureDeltaCount := 0
for _, out := range outputs {
for _, line := range strings.Split(string(out), "\n") {
if !strings.HasPrefix(line, "data: ") {
continue
}
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" {
signatureDeltaCount++
if got := data.Get("delta.signature").String(); got != "enc_sig_early" {
t.Fatalf("unexpected signature delta: %q", got)
}
}
}
}
if signatureDeltaCount != 1 {
t.Fatalf("expected signature_delta from early-captured signature, got %d", signatureDeltaCount)
}
}
func TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature(t *testing.T) {
ctx := context.Background()
originalRequest := []byte(`{"messages":[]}`)
response := []byte(`{
"type":"response.completed",
"response":{
"id":"resp_123",
"model":"gpt-5",
"usage":{"input_tokens":10,"output_tokens":20},
"output":[
{
"type":"reasoning",
"encrypted_content":"enc_sig_nonstream",
"summary":[{"type":"summary_text","text":"internal reasoning"}]
},
{
"type":"message",
"content":[{"type":"output_text","text":"final answer"}]
}
]
}
}`)
out := ConvertCodexResponseToClaudeNonStream(ctx, "", originalRequest, nil, response, nil)
parsed := gjson.ParseBytes(out)
thinking := parsed.Get("content.0")
if thinking.Get("type").String() != "thinking" {
t.Fatalf("expected first content block to be thinking, got %s", thinking.Raw)
}
if got := thinking.Get("signature").String(); got != "enc_sig_nonstream" {
t.Fatalf("expected signature to be preserved, got %q", got)
}
if got := thinking.Get("thinking").String(); got != "internal reasoning" {
t.Fatalf("unexpected thinking text: %q", got)
}
}

View File

@@ -20,12 +20,14 @@ type oaiToResponsesStateReasoning struct {
OutputIndex int
}
type oaiToResponsesState struct {
Seq int
ResponseID string
Created int64
Started bool
ReasoningID string
ReasoningIndex int
Seq int
ResponseID string
Created int64
Started bool
CompletionPending bool
CompletedEmitted bool
ReasoningID string
ReasoningIndex int
// aggregation buffers for response.output
// Per-output message text buffers by index
MsgTextBuf map[int]*strings.Builder
@@ -60,6 +62,141 @@ func emitRespEvent(event string, payload []byte) []byte {
return translatorcommon.SSEEventData(event, payload)
}
func buildResponsesCompletedEvent(st *oaiToResponsesState, requestRawJSON []byte, nextSeq func() int) []byte {
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
// Inject original request fields into response as per docs/response.completed.json
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
}
}
outputsWrapper := []byte(`{"arr":[]}`)
type completedOutputItem struct {
index int
raw []byte
}
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
if len(st.Reasonings) > 0 {
for _, r := range st.Reasonings {
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
}
}
if len(st.MsgItemAdded) > 0 {
for i := range st.MsgItemAdded {
txt := ""
if b := st.MsgTextBuf[i]; b != nil {
txt = b.String()
}
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
item, _ = sjson.SetBytes(item, "content.0.text", txt)
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
}
}
if len(st.FuncArgsBuf) > 0 {
for key := range st.FuncArgsBuf {
args := ""
if b := st.FuncArgsBuf[key]; b != nil {
args = b.String()
}
callID := st.FuncCallIDs[key]
name := st.FuncNames[key]
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", callID)
item, _ = sjson.SetBytes(item, "name", name)
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
}
}
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
for _, item := range outputItems {
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
}
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
if st.UsageSeen {
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
if st.ReasoningTokens > 0 {
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
}
total := st.TotalTokens
if total == 0 {
total = st.PromptTokens + st.CompletionTokens
}
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
}
return emitRespEvent("response.completed", completed)
}
// ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
// to OpenAI Responses SSE events (response.*).
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
@@ -90,6 +227,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
return [][]byte{}
}
if bytes.Equal(rawJSON, []byte("[DONE]")) {
if st.CompletionPending && !st.CompletedEmitted {
st.CompletedEmitted = true
return [][]byte{buildResponsesCompletedEvent(st, requestRawJSON, func() int { st.Seq++; return st.Seq })}
}
return [][]byte{}
}
@@ -165,6 +306,8 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
st.TotalTokens = 0
st.ReasoningTokens = 0
st.UsageSeen = false
st.CompletionPending = false
st.CompletedEmitted = false
// response.created
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
@@ -374,8 +517,9 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
}
}
// finish_reason triggers finalization, including text done/content done/item done,
// reasoning done/part.done, function args done/item done, and completed
// finish_reason triggers item-level finalization. response.completed is
// deferred until the terminal [DONE] marker so late usage-only chunks can
// still populate response.usage.
if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
// Emit message done events for all indices that started a message
if len(st.MsgItemAdded) > 0 {
@@ -464,138 +608,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
st.FuncArgsDone[key] = true
}
}
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
// Inject original request fields into response as per docs/response.completed.json
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
}
}
// Build response.output using aggregated buffers
outputsWrapper := []byte(`{"arr":[]}`)
type completedOutputItem struct {
index int
raw []byte
}
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
if len(st.Reasonings) > 0 {
for _, r := range st.Reasonings {
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
}
}
if len(st.MsgItemAdded) > 0 {
for i := range st.MsgItemAdded {
txt := ""
if b := st.MsgTextBuf[i]; b != nil {
txt = b.String()
}
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
item, _ = sjson.SetBytes(item, "content.0.text", txt)
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
}
}
if len(st.FuncArgsBuf) > 0 {
for key := range st.FuncArgsBuf {
args := ""
if b := st.FuncArgsBuf[key]; b != nil {
args = b.String()
}
callID := st.FuncCallIDs[key]
name := st.FuncNames[key]
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", callID)
item, _ = sjson.SetBytes(item, "name", name)
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
}
}
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
for _, item := range outputItems {
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
}
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
if st.UsageSeen {
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
if st.ReasoningTokens > 0 {
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
}
total := st.TotalTokens
if total == 0 {
total = st.PromptTokens + st.CompletionTokens
}
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
}
out = append(out, emitRespEvent("response.completed", completed))
st.CompletionPending = true
}
return true

View File

@@ -24,6 +24,120 @@ func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Res
return event, gjson.Parse(dataLine)
}
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_ResponseCompletedWaitsForDone(t *testing.T) {
t.Parallel()
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
tests := []struct {
name string
in []string
doneInputIndex int // Index in tt.in where the terminal [DONE] chunk arrives and response.completed must be emitted.
hasUsage bool
inputTokens int64
outputTokens int64
totalTokens int64
}{
{
// A provider may send finish_reason first and only attach usage in a later chunk (e.g. Vertex AI),
// so response.completed must wait for [DONE] to include that usage.
name: "late usage after finish reason",
in: []string{
`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_late_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[],"usage":{"prompt_tokens":11,"completion_tokens":7,"total_tokens":18}}`,
`data: [DONE]`,
},
doneInputIndex: 3,
hasUsage: true,
inputTokens: 11,
outputTokens: 7,
totalTokens: 18,
},
{
// When usage arrives on the same chunk as finish_reason, we still expect a
// single response.completed event and it should remain deferred until [DONE].
name: "usage on finish reason chunk",
in: []string{
`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_usage_same_chunk","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":13,"completion_tokens":5,"total_tokens":18}}`,
`data: [DONE]`,
},
doneInputIndex: 2,
hasUsage: true,
inputTokens: 13,
outputTokens: 5,
totalTokens: 18,
},
{
// An OpenAI-compatible streams from a buggy server might never send usage, so response.completed should
// still wait for [DONE] but omit the usage object entirely.
name: "no usage chunk",
in: []string{
`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_no_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
`data: [DONE]`,
},
doneInputIndex: 2,
hasUsage: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
completedCount := 0
completedInputIndex := -1
var completedData gjson.Result
// Reuse converter state across input lines to simulate one streaming response.
var param any
for i, line := range tt.in {
// One upstream chunk can emit multiple downstream SSE events.
for _, chunk := range ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param) {
event, data := parseOpenAIResponsesSSEEvent(t, chunk)
if event != "response.completed" {
continue
}
completedCount++
completedInputIndex = i
completedData = data
if i < tt.doneInputIndex {
t.Fatalf("unexpected early response.completed on input index %d", i)
}
}
}
if completedCount != 1 {
t.Fatalf("expected exactly 1 response.completed event, got %d", completedCount)
}
if completedInputIndex != tt.doneInputIndex {
t.Fatalf("expected response.completed on terminal [DONE] chunk at input index %d, got %d", tt.doneInputIndex, completedInputIndex)
}
// Missing upstream usage should stay omitted in the final completed event.
if !tt.hasUsage {
if completedData.Get("response.usage").Exists() {
t.Fatalf("expected response.completed to omit usage when none was provided, got %s", completedData.Get("response.usage").Raw)
}
return
}
// When usage is present, the final response.completed event must preserve the usage values.
if got := completedData.Get("response.usage.input_tokens").Int(); got != tt.inputTokens {
t.Fatalf("unexpected response.usage.input_tokens: got %d want %d", got, tt.inputTokens)
}
if got := completedData.Get("response.usage.output_tokens").Int(); got != tt.outputTokens {
t.Fatalf("unexpected response.usage.output_tokens: got %d want %d", got, tt.outputTokens)
}
if got := completedData.Get("response.usage.total_tokens").Int(); got != tt.totalTokens {
t.Fatalf("unexpected response.usage.total_tokens: got %d want %d", got, tt.totalTokens)
}
})
}
}
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
in := []string{
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
@@ -31,6 +145,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCalls
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
`data: [DONE]`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
@@ -131,6 +246,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCa
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
`data: [DONE]`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
@@ -213,6 +329,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndTo
in := []string{
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
`data: [DONE]`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
@@ -261,6 +378,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneA
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
`data: [DONE]`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

View File

@@ -6,6 +6,7 @@ package handlers
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
@@ -493,6 +494,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
opts.Metadata = reqMeta
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
if err != nil {
err = enrichAuthSelectionError(err, providers, normalizedModel)
status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
@@ -539,6 +541,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
opts.Metadata = reqMeta
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
if err != nil {
err = enrichAuthSelectionError(err, providers, normalizedModel)
status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
@@ -589,6 +592,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
opts.Metadata = reqMeta
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil {
err = enrichAuthSelectionError(err, providers, normalizedModel)
errChan := make(chan *interfaces.ErrorMessage, 1)
status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
@@ -698,7 +702,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
chunks = retryResult.Chunks
continue outer
}
streamErr = retryErr
streamErr = enrichAuthSelectionError(retryErr, providers, normalizedModel)
}
}
@@ -841,6 +845,54 @@ func replaceHeader(dst http.Header, src http.Header) {
}
}
func enrichAuthSelectionError(err error, providers []string, model string) error {
if err == nil {
return nil
}
var authErr *coreauth.Error
if !errors.As(err, &authErr) || authErr == nil {
return err
}
code := strings.TrimSpace(authErr.Code)
if code != "auth_not_found" && code != "auth_unavailable" {
return err
}
providerText := strings.Join(providers, ",")
if providerText == "" {
providerText = "unknown"
}
modelText := strings.TrimSpace(model)
if modelText == "" {
modelText = "unknown"
}
baseMessage := strings.TrimSpace(authErr.Message)
if baseMessage == "" {
baseMessage = "no auth available"
}
detail := fmt.Sprintf("%s (providers=%s, model=%s)", baseMessage, providerText, modelText)
// Clarify the most common alias confusion between Anthropic route names and internal provider keys.
if strings.Contains(","+providerText+",", ",claude,") {
detail += "; check Claude auth/key session and cooldown state via /v0/management/auth-files"
}
status := authErr.HTTPStatus
if status <= 0 {
status = http.StatusServiceUnavailable
}
return &coreauth.Error{
Code: authErr.Code,
Message: detail,
Retryable: authErr.Retryable,
HTTPStatus: status,
}
}
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
status := http.StatusInternalServerError

View File

@@ -5,10 +5,12 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
@@ -66,3 +68,46 @@ func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) {
t.Fatalf("X-Request-Id = %#v, want %#v", got, []string{"new-1", "new-2"})
}
}
func TestEnrichAuthSelectionError_DefaultsTo503WithContext(t *testing.T) {
in := &coreauth.Error{Code: "auth_not_found", Message: "no auth available"}
out := enrichAuthSelectionError(in, []string{"claude"}, "claude-sonnet-4-6")
var got *coreauth.Error
if !errors.As(out, &got) || got == nil {
t.Fatalf("expected coreauth.Error, got %T", out)
}
if got.StatusCode() != http.StatusServiceUnavailable {
t.Fatalf("status = %d, want %d", got.StatusCode(), http.StatusServiceUnavailable)
}
if !strings.Contains(got.Message, "providers=claude") {
t.Fatalf("message missing provider context: %q", got.Message)
}
if !strings.Contains(got.Message, "model=claude-sonnet-4-6") {
t.Fatalf("message missing model context: %q", got.Message)
}
if !strings.Contains(got.Message, "/v0/management/auth-files") {
t.Fatalf("message missing management hint: %q", got.Message)
}
}
func TestEnrichAuthSelectionError_PreservesExplicitStatus(t *testing.T) {
in := &coreauth.Error{Code: "auth_unavailable", Message: "no auth available", HTTPStatus: http.StatusTooManyRequests}
out := enrichAuthSelectionError(in, []string{"gemini"}, "gemini-2.5-pro")
var got *coreauth.Error
if !errors.As(out, &got) || got == nil {
t.Fatalf("expected coreauth.Error, got %T", out)
}
if got.StatusCode() != http.StatusTooManyRequests {
t.Fatalf("status = %d, want %d", got.StatusCode(), http.StatusTooManyRequests)
}
}
func TestEnrichAuthSelectionError_IgnoresOtherErrors(t *testing.T) {
in := errors.New("boom")
out := enrichAuthSelectionError(in, []string{"claude"}, "claude-sonnet-4-6")
if out != in {
t.Fatalf("expected original error to be returned unchanged")
}
}

View File

@@ -2,10 +2,13 @@ package handlers
import (
"context"
"errors"
"net/http"
"strings"
"sync"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -463,6 +466,76 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
}
}
func TestExecuteStreamWithAuthManager_EnrichesBootstrapRetryAuthUnavailableError(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
if len(got) != 0 {
t.Fatalf("expected empty payload, got %q", string(got))
}
var gotErr *interfaces.ErrorMessage
for msg := range errChan {
if msg != nil {
gotErr = msg
}
}
if gotErr == nil {
t.Fatalf("expected terminal error")
}
if gotErr.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("status = %d, want %d", gotErr.StatusCode, http.StatusServiceUnavailable)
}
var authErr *coreauth.Error
if !errors.As(gotErr.Error, &authErr) || authErr == nil {
t.Fatalf("expected coreauth.Error, got %T", gotErr.Error)
}
if authErr.Code != "auth_unavailable" {
t.Fatalf("code = %q, want %q", authErr.Code, "auth_unavailable")
}
if !strings.Contains(authErr.Message, "providers=codex") {
t.Fatalf("message missing provider context: %q", authErr.Message)
}
if !strings.Contains(authErr.Message, "model=test-model") {
t.Fatalf("message missing model context: %q", authErr.Message)
}
if executor.Calls() != 1 {
t.Fatalf("expected exactly one upstream call before retry path selection failure, got %d", executor.Calls())
}
}
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
executor := &authAwareStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)

View File

@@ -234,6 +234,84 @@ func (m *Manager) RefreshSchedulerEntry(authID string) {
m.scheduler.upsertAuth(snapshot)
}
// ReconcileRegistryModelStates aligns per-model runtime state with the current
// registry snapshot for one auth.
//
// Supported models are reset to a clean state because re-registration already
// cleared the registry-side cooldown/suspension snapshot. ModelStates for
// models that are no longer present in the registry are pruned entirely so
// renamed/removed models cannot keep auth-level status stale.
func (m *Manager) ReconcileRegistryModelStates(ctx context.Context, authID string) {
if m == nil || authID == "" {
return
}
supportedModels := registry.GetGlobalRegistry().GetModelsForClient(authID)
supported := make(map[string]struct{}, len(supportedModels))
for _, model := range supportedModels {
if model == nil {
continue
}
modelKey := canonicalModelKey(model.ID)
if modelKey == "" {
continue
}
supported[modelKey] = struct{}{}
}
var snapshot *Auth
now := time.Now()
m.mu.Lock()
auth, ok := m.auths[authID]
if ok && auth != nil && len(auth.ModelStates) > 0 {
changed := false
for modelKey, state := range auth.ModelStates {
baseModel := canonicalModelKey(modelKey)
if baseModel == "" {
baseModel = strings.TrimSpace(modelKey)
}
if _, supportedModel := supported[baseModel]; !supportedModel {
// Drop state for models that disappeared from the current registry
// snapshot. Keeping them around leaks stale errors into auth-level
// status, management output, and websocket fallback checks.
delete(auth.ModelStates, modelKey)
changed = true
continue
}
if state == nil {
continue
}
if modelStateIsClean(state) {
continue
}
resetModelState(state, now)
changed = true
}
if len(auth.ModelStates) == 0 {
auth.ModelStates = nil
}
if changed {
updateAggregatedAvailability(auth, now)
if !hasModelError(auth, now) {
auth.LastError = nil
auth.StatusMessage = ""
auth.Status = StatusActive
}
auth.UpdatedAt = now
if errPersist := m.persist(ctx, auth); errPersist != nil {
logEntryWithRequestID(ctx).WithField("auth_id", auth.ID).Warnf("failed to persist auth changes during model state reconciliation: %v", errPersist)
}
snapshot = auth.Clone()
}
}
m.mu.Unlock()
if m.scheduler != nil && snapshot != nil {
m.scheduler.upsertAuth(snapshot)
}
}
func (m *Manager) SetSelector(selector Selector) {
if m == nil {
return
@@ -1838,6 +1916,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
} else {
if result.Model != "" {
if !isRequestScopedNotFoundResultError(result.Error) {
disableCooling := quotaCooldownDisabledForAuth(auth)
state := ensureModelState(auth, result.Model)
state.Unavailable = true
state.Status = StatusError
@@ -1858,31 +1937,45 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
} else {
switch statusCode {
case 401:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "unauthorized"
shouldSuspendModel = true
if disableCooling {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "unauthorized"
shouldSuspendModel = true
}
case 402, 403:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
shouldSuspendModel = true
if disableCooling {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
shouldSuspendModel = true
}
case 404:
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "not_found"
shouldSuspendModel = true
if disableCooling {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "not_found"
shouldSuspendModel = true
}
case 429:
var next time.Time
backoffLevel := state.Quota.BackoffLevel
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
if !disableCooling {
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, disableCooling)
if cooldown > 0 {
next = now.Add(cooldown)
}
backoffLevel = nextLevel
}
backoffLevel = nextLevel
}
state.NextRetryAfter = next
state.Quota = QuotaState{
@@ -1891,11 +1984,13 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
NextRecoverAt: next,
BackoffLevel: backoffLevel,
}
suspendReason = "quota"
shouldSuspendModel = true
setModelQuota = true
if !disableCooling {
suspendReason = "quota"
shouldSuspendModel = true
setModelQuota = true
}
case 408, 500, 502, 503, 504:
if quotaCooldownDisabledForAuth(auth) {
if disableCooling {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(1 * time.Minute)
@@ -1966,8 +2061,28 @@ func resetModelState(state *ModelState, now time.Time) {
state.UpdatedAt = now
}
func modelStateIsClean(state *ModelState) bool {
if state == nil {
return true
}
if state.Status != StatusActive {
return false
}
if state.Unavailable || state.StatusMessage != "" || !state.NextRetryAfter.IsZero() || state.LastError != nil {
return false
}
if state.Quota.Exceeded || state.Quota.Reason != "" || !state.Quota.NextRecoverAt.IsZero() || state.Quota.BackoffLevel != 0 {
return false
}
return true
}
func updateAggregatedAvailability(auth *Auth, now time.Time) {
if auth == nil || len(auth.ModelStates) == 0 {
if auth == nil {
return
}
if len(auth.ModelStates) == 0 {
clearAggregatedAvailability(auth)
return
}
allUnavailable := true
@@ -1975,10 +2090,12 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
quotaExceeded := false
quotaRecover := time.Time{}
maxBackoffLevel := 0
hasState := false
for _, state := range auth.ModelStates {
if state == nil {
continue
}
hasState = true
stateUnavailable := false
if state.Status == StatusDisabled {
stateUnavailable = true
@@ -2008,6 +2125,10 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
}
}
}
if !hasState {
clearAggregatedAvailability(auth)
return
}
auth.Unavailable = allUnavailable
if allUnavailable {
auth.NextRetryAfter = earliestRetry
@@ -2027,6 +2148,15 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
}
}
func clearAggregatedAvailability(auth *Auth) {
if auth == nil {
return
}
auth.Unavailable = false
auth.NextRetryAfter = time.Time{}
auth.Quota = QuotaState{}
}
func hasModelError(auth *Auth, now time.Time) bool {
if auth == nil || len(auth.ModelStates) == 0 {
return false
@@ -2211,6 +2341,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
if isRequestScopedNotFoundResultError(resultErr) {
return
}
disableCooling := quotaCooldownDisabledForAuth(auth)
auth.Unavailable = true
auth.Status = StatusError
auth.UpdatedAt = now
@@ -2224,32 +2355,46 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
switch statusCode {
case 401:
auth.StatusMessage = "unauthorized"
auth.NextRetryAfter = now.Add(30 * time.Minute)
if disableCooling {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(30 * time.Minute)
}
case 402, 403:
auth.StatusMessage = "payment_required"
auth.NextRetryAfter = now.Add(30 * time.Minute)
if disableCooling {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(30 * time.Minute)
}
case 404:
auth.StatusMessage = "not_found"
auth.NextRetryAfter = now.Add(12 * time.Hour)
if disableCooling {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(12 * time.Hour)
}
case 429:
auth.StatusMessage = "quota exhausted"
auth.Quota.Exceeded = true
auth.Quota.Reason = "quota"
var next time.Time
if retryAfter != nil {
next = now.Add(*retryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
if !disableCooling {
if retryAfter != nil {
next = now.Add(*retryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, disableCooling)
if cooldown > 0 {
next = now.Add(cooldown)
}
auth.Quota.BackoffLevel = nextLevel
}
auth.Quota.BackoffLevel = nextLevel
}
auth.Quota.NextRecoverAt = next
auth.NextRetryAfter = next
case 408, 500, 502, 503, 504:
auth.StatusMessage = "transient upstream error"
if quotaCooldownDisabledForAuth(auth) {
if disableCooling {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(1 * time.Minute)

View File

@@ -180,6 +180,34 @@ func (e *authFallbackExecutor) StreamCalls() []string {
return out
}
type retryAfterStatusError struct {
status int
message string
retryAfter time.Duration
}
func (e *retryAfterStatusError) Error() string {
if e == nil {
return ""
}
return e.message
}
func (e *retryAfterStatusError) StatusCode() int {
if e == nil {
return 0
}
return e.status
}
func (e *retryAfterStatusError) RetryAfter() *time.Duration {
if e == nil {
return nil
}
d := e.retryAfter
return &d
}
func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) (*Manager, *credentialRetryLimitExecutor) {
t.Helper()
@@ -450,6 +478,174 @@ func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
}
}
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride_On403(t *testing.T) {
prev := quotaCooldownDisabled.Load()
quotaCooldownDisabled.Store(false)
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
m := NewManager(nil, nil, nil)
auth := &Auth{
ID: "auth-403",
Provider: "claude",
Metadata: map[string]any{
"disable_cooling": true,
},
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
model := "test-model-403"
reg := registry.GetGlobalRegistry()
reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}})
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
m.MarkResult(context.Background(), Result{
AuthID: auth.ID,
Provider: "claude",
Model: model,
Success: false,
Error: &Error{HTTPStatus: http.StatusForbidden, Message: "forbidden"},
})
updated, ok := m.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth to be present")
}
state := updated.ModelStates[model]
if state == nil {
t.Fatalf("expected model state to be present")
}
if !state.NextRetryAfter.IsZero() {
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
}
if count := reg.GetModelCount(model); count <= 0 {
t.Fatalf("expected model count > 0 when disable_cooling=true, got %d", count)
}
}
func TestManager_Execute_DisableCooling_DoesNotBlackoutAfter403(t *testing.T) {
prev := quotaCooldownDisabled.Load()
quotaCooldownDisabled.Store(false)
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
m := NewManager(nil, nil, nil)
executor := &authFallbackExecutor{
id: "claude",
executeErrors: map[string]error{
"auth-403-exec": &Error{
HTTPStatus: http.StatusForbidden,
Message: "forbidden",
},
},
}
m.RegisterExecutor(executor)
auth := &Auth{
ID: "auth-403-exec",
Provider: "claude",
Metadata: map[string]any{
"disable_cooling": true,
},
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
model := "test-model-403-exec"
reg := registry.GetGlobalRegistry()
reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}})
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
req := cliproxyexecutor.Request{Model: model}
_, errExecute1 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
if errExecute1 == nil {
t.Fatal("expected first execute error")
}
if statusCodeFromError(errExecute1) != http.StatusForbidden {
t.Fatalf("first execute status = %d, want %d", statusCodeFromError(errExecute1), http.StatusForbidden)
}
_, errExecute2 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
if errExecute2 == nil {
t.Fatal("expected second execute error")
}
if statusCodeFromError(errExecute2) != http.StatusForbidden {
t.Fatalf("second execute status = %d, want %d", statusCodeFromError(errExecute2), http.StatusForbidden)
}
}
func TestManager_Execute_DisableCooling_DoesNotBlackoutAfter429RetryAfter(t *testing.T) {
prev := quotaCooldownDisabled.Load()
quotaCooldownDisabled.Store(false)
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
m := NewManager(nil, nil, nil)
executor := &authFallbackExecutor{
id: "claude",
executeErrors: map[string]error{
"auth-429-exec": &retryAfterStatusError{
status: http.StatusTooManyRequests,
message: "quota exhausted",
retryAfter: 2 * time.Minute,
},
},
}
m.RegisterExecutor(executor)
auth := &Auth{
ID: "auth-429-exec",
Provider: "claude",
Metadata: map[string]any{
"disable_cooling": true,
},
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
model := "test-model-429-exec"
reg := registry.GetGlobalRegistry()
reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}})
t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
req := cliproxyexecutor.Request{Model: model}
_, errExecute1 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
if errExecute1 == nil {
t.Fatal("expected first execute error")
}
if statusCodeFromError(errExecute1) != http.StatusTooManyRequests {
t.Fatalf("first execute status = %d, want %d", statusCodeFromError(errExecute1), http.StatusTooManyRequests)
}
_, errExecute2 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
if errExecute2 == nil {
t.Fatal("expected second execute error")
}
if statusCodeFromError(errExecute2) != http.StatusTooManyRequests {
t.Fatalf("second execute status = %d, want %d", statusCodeFromError(errExecute2), http.StatusTooManyRequests)
}
calls := executor.ExecuteCalls()
if len(calls) != 2 {
t.Fatalf("execute calls = %d, want 2", len(calls))
}
updated, ok := m.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth to be present")
}
state := updated.ModelStates[model]
if state == nil {
t.Fatalf("expected model state to be present")
}
if !state.NextRetryAfter.IsZero() {
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
}
}
func TestManager_MarkResult_RequestScopedNotFoundDoesNotCooldownAuth(t *testing.T) {
m := NewManager(nil, nil, nil)

View File

@@ -324,6 +324,7 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
// This operation may block on network calls, but the auth configuration
// is already effective at this point.
s.registerModelsForAuth(auth)
s.coreManager.ReconcileRegistryModelStates(ctx, auth.ID)
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
// from the now-populated global model registry. Without this, newly added auths
@@ -1085,6 +1086,7 @@ func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
s.ensureExecutorsForAuth(current)
}
s.registerModelsForAuth(current)
s.coreManager.ReconcileRegistryModelStates(context.Background(), current.ID)
latest, ok := s.latestAuthForModelRegistration(current.ID)
if !ok || latest.Disabled {
@@ -1098,6 +1100,7 @@ func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
// no auth fields changed, but keeps the refresh path simple and correct.
s.ensureExecutorsForAuth(latest)
s.registerModelsForAuth(latest)
s.coreManager.ReconcileRegistryModelStates(context.Background(), latest.ID)
s.coreManager.RefreshSchedulerEntry(current.ID)
return true
}

View File

@@ -58,7 +58,7 @@ func Parse(raw string) (Setting, error) {
}
switch parsedURL.Scheme {
case "socks5", "http", "https":
case "socks5", "socks5h", "http", "https":
setting.Mode = ModeProxy
setting.URL = parsedURL
return setting, nil
@@ -95,7 +95,7 @@ func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
case ModeDirect:
return NewDirectTransport(), setting.Mode, nil
case ModeProxy:
if setting.URL.Scheme == "socks5" {
if setting.URL.Scheme == "socks5" || setting.URL.Scheme == "socks5h" {
var proxyAuth *proxy.Auth
if setting.URL.User != nil {
username := setting.URL.User.Username()

View File

@@ -30,6 +30,7 @@ func TestParse(t *testing.T) {
{name: "http", input: "http://proxy.example.com:8080", want: ModeProxy},
{name: "https", input: "https://proxy.example.com:8443", want: ModeProxy},
{name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy},
{name: "socks5h", input: "socks5h://proxy.example.com:1080", want: ModeProxy},
{name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true},
}
@@ -137,3 +138,24 @@ func TestBuildHTTPTransportSOCKS5ProxyInheritsDefaultTransportSettings(t *testin
t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
}
}
func TestBuildHTTPTransportSOCKS5HProxy(t *testing.T) {
t.Parallel()
transport, mode, errBuild := BuildHTTPTransport("socks5h://proxy.example.com:1080")
if errBuild != nil {
t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
}
if mode != ModeProxy {
t.Fatalf("mode = %d, want %d", mode, ModeProxy)
}
if transport == nil {
t.Fatal("expected transport, got nil")
}
if transport.Proxy != nil {
t.Fatal("expected SOCKS5H transport to bypass http proxy function")
}
if transport.DialContext == nil {
t.Fatal("expected SOCKS5H transport to have custom DialContext")
}
}

View File

@@ -0,0 +1,106 @@
package test
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
type jsonObject = map[string]any
func loadClaudeCodeSentinelFixture(t *testing.T, name string) jsonObject {
t.Helper()
path := filepath.Join("testdata", "claude_code_sentinels", name)
data := mustReadFile(t, path)
var payload jsonObject
if err := json.Unmarshal(data, &payload); err != nil {
t.Fatalf("unmarshal %s: %v", name, err)
}
return payload
}
func mustReadFile(t *testing.T, path string) []byte {
t.Helper()
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("read %s: %v", path, err)
}
return data
}
func requireStringField(t *testing.T, obj jsonObject, key string) string {
t.Helper()
value, ok := obj[key].(string)
if !ok || value == "" {
t.Fatalf("field %q missing or empty: %#v", key, obj[key])
}
return value
}
func TestClaudeCodeSentinel_ToolProgressShape(t *testing.T) {
payload := loadClaudeCodeSentinelFixture(t, "tool_progress.json")
if got := requireStringField(t, payload, "type"); got != "tool_progress" {
t.Fatalf("type = %q, want tool_progress", got)
}
requireStringField(t, payload, "tool_use_id")
requireStringField(t, payload, "tool_name")
requireStringField(t, payload, "session_id")
if _, ok := payload["elapsed_time_seconds"].(float64); !ok {
t.Fatalf("elapsed_time_seconds missing or non-number: %#v", payload["elapsed_time_seconds"])
}
}
func TestClaudeCodeSentinel_SessionStateShape(t *testing.T) {
payload := loadClaudeCodeSentinelFixture(t, "session_state_changed.json")
if got := requireStringField(t, payload, "type"); got != "system" {
t.Fatalf("type = %q, want system", got)
}
if got := requireStringField(t, payload, "subtype"); got != "session_state_changed" {
t.Fatalf("subtype = %q, want session_state_changed", got)
}
state := requireStringField(t, payload, "state")
switch state {
case "idle", "running", "requires_action":
default:
t.Fatalf("unexpected session state %q", state)
}
requireStringField(t, payload, "session_id")
}
func TestClaudeCodeSentinel_ToolUseSummaryShape(t *testing.T) {
payload := loadClaudeCodeSentinelFixture(t, "tool_use_summary.json")
if got := requireStringField(t, payload, "type"); got != "tool_use_summary" {
t.Fatalf("type = %q, want tool_use_summary", got)
}
requireStringField(t, payload, "summary")
rawIDs, ok := payload["preceding_tool_use_ids"].([]any)
if !ok || len(rawIDs) == 0 {
t.Fatalf("preceding_tool_use_ids missing or empty: %#v", payload["preceding_tool_use_ids"])
}
for i, raw := range rawIDs {
if id, ok := raw.(string); !ok || id == "" {
t.Fatalf("preceding_tool_use_ids[%d] invalid: %#v", i, raw)
}
}
}
func TestClaudeCodeSentinel_ControlRequestCanUseToolShape(t *testing.T) {
payload := loadClaudeCodeSentinelFixture(t, "control_request_can_use_tool.json")
if got := requireStringField(t, payload, "type"); got != "control_request" {
t.Fatalf("type = %q, want control_request", got)
}
requireStringField(t, payload, "request_id")
request, ok := payload["request"].(map[string]any)
if !ok {
t.Fatalf("request missing or invalid: %#v", payload["request"])
}
if got := requireStringField(t, request, "subtype"); got != "can_use_tool" {
t.Fatalf("request.subtype = %q, want can_use_tool", got)
}
requireStringField(t, request, "tool_name")
requireStringField(t, request, "tool_use_id")
if input, ok := request["input"].(map[string]any); !ok || len(input) == 0 {
t.Fatalf("request.input missing or empty: %#v", request["input"])
}
}

View File

@@ -0,0 +1,11 @@
{
"type": "control_request",
"request_id": "req_123",
"request": {
"subtype": "can_use_tool",
"tool_name": "Bash",
"input": {"command": "npm test"},
"tool_use_id": "toolu_123",
"description": "Running npm test"
}
}

View File

@@ -0,0 +1,7 @@
{
"type": "system",
"subtype": "session_state_changed",
"state": "requires_action",
"uuid": "22222222-2222-4222-8222-222222222222",
"session_id": "sess_123"
}

View File

@@ -0,0 +1,10 @@
{
"type": "tool_progress",
"tool_use_id": "toolu_123",
"tool_name": "Bash",
"parent_tool_use_id": null,
"elapsed_time_seconds": 2.5,
"task_id": "task_123",
"uuid": "11111111-1111-4111-8111-111111111111",
"session_id": "sess_123"
}

View File

@@ -0,0 +1,7 @@
{
"type": "tool_use_summary",
"summary": "Searched in auth/",
"preceding_tool_use_ids": ["toolu_1", "toolu_2"],
"uuid": "33333333-3333-4333-8333-333333333333",
"session_id": "sess_123"
}